Skip to content

Add more MLX dispatches #5327

Add more MLX dispatches

Add more MLX dispatches #5327

Workflow file for this run

# Test workflow: runs for pushes to main and for pull requests targeting main.
name: Tests

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

# A newer run in the same concurrency group cancels the run still in progress.
concurrency:
  # Pull-request runs are grouped by workflow name + head branch, so pushing to
  # a PR supersedes its previous run. Any other event is grouped by commit SHA.
  group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.head_ref || github.sha }}
  cancel-in-progress: true
jobs:
changes:
  # Detects whether files relevant to the test suite were modified. The result
  # of the `src` filter is re-exported as the job output `changes`, which the
  # downstream jobs read through `needs.changes.outputs.changes`.
  name: Check for changes
  runs-on: ubuntu-latest
  outputs:
    changes: ${{ steps.changes.outputs.src }}
  steps:
    - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8  # v5.0.0
      with:
        fetch-depth: 0
        persist-credentials: false
    - id: changes
      uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36  # v3.0.2
      with:
        # The `python` filter is anchored (&python) so that `src` can pull in
        # the same glob list through the *python alias.
        filters: |
          python: &python
          - 'pytensor/**/*.py'
          - 'tests/**/*.py'
          - 'pytensor/**/*.pyx'
          - 'tests/**/*.pyx'
          - '*.py'
          src:
          - *python
          - 'pytensor/**/*.c'
          - 'tests/**/*.c'
          - 'pytensor/**/*.h'
          - 'tests/**/*.h'
          - '.github/workflows/*.yml'
          - 'setup.cfg'
          - 'requirements.txt'
          - '.pre-commit-config.yaml'
style:
  # Runs the pre-commit hooks on each listed Python version. Only scheduled
  # when the `changes` job reported relevant modifications.
  name: Check code style
  needs: changes
  if: ${{ needs.changes.outputs.changes == 'true' }}
  runs-on: ubuntu-latest
  strategy:
    matrix:
      python-version:
        - "3.11"
        - "3.13"
  steps:
    - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8  # v5.0.0
      with:
        persist-credentials: false
    - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c  # v6.0.0
      with:
        python-version: ${{ matrix.python-version }}
    - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd  # v3.0.1
test:
  # Main test matrix. Runs only when relevant files changed and the style job
  # succeeded. Each matrix entry runs one slice ("part") of the test suite and
  # uploads its own coverage report as an artifact.
  name: "${{ matrix.os }} test py${{ matrix.python-version }} : fast-compile ${{ matrix.fast-compile }} : float32 ${{ matrix.float32 }} : ${{ matrix.part }}"
  needs: [changes, style]
  if: ${{ needs.changes.outputs.changes == 'true' && needs.style.result == 'success' }}
  runs-on: ${{ matrix.os }}
  strategy:
    fail-fast: false
    matrix:
      os: ["ubuntu-latest"]
      python-version: ["3.11", "3.13"]
      fast-compile: [0, 1]
      float32: [0, 1]
      # Optional backends are off in the base matrix; the `include` entries
      # below switch them on for dedicated jobs.
      install-numba: [0]
      install-jax: [0]
      install-torch: [0]
      install-mlx: [0]
      install-xarray: [0]
      # Each part is passed to pytest as its argument list (paths + options).
      part:
        - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/xtensor"
        - "tests/scan"
        - "tests/tensor --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/conv --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math_scipy.py"
        - "tests/tensor/rewriting"
        - "tests/tensor/test_math.py"
        - "tests/tensor/test_basic.py tests/tensor/test_inplace.py tests/tensor/conv"
        - "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py"
      # Python 3.11 only runs the default configuration (neither fast-compile
      # nor float32), and fast-compile is never combined with float32.
      exclude:
        - python-version: "3.11"
          fast-compile: 1
        - python-version: "3.11"
          float32: 1
        - fast-compile: 1
          float32: 1
      include:
        # Doctests of the package itself.
        - os: "ubuntu-latest"
          part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link --ignore=pytensor/ipython.py"
          python-version: "3.12"
          fast-compile: 0
          float32: 0
          install-numba: 0
          install-jax: 0
          install-torch: 0
          install-mlx: 0
          install-xarray: 0
        # Numba backend.
        - install-numba: 1
          os: "ubuntu-latest"
          python-version: "3.11"
          fast-compile: 0
          float32: 0
          part: "tests/link/numba --ignore=tests/link/numba/test_slinalg.py"
        - install-numba: 1
          os: "ubuntu-latest"
          python-version: "3.13"
          fast-compile: 0
          float32: 0
          part: "tests/link/numba --ignore=tests/link/numba/test_slinalg.py"
        - install-numba: 1
          os: "ubuntu-latest"
          python-version: "3.13"
          fast-compile: 0
          float32: 0
          part: "tests/link/numba/test_slinalg.py"
        # JAX backend.
        - install-jax: 1
          os: "ubuntu-latest"
          python-version: "3.11"
          fast-compile: 0
          float32: 0
          part: "tests/link/jax"
        - install-jax: 1
          os: "ubuntu-latest"
          python-version: "3.13"
          fast-compile: 0
          float32: 0
          part: "tests/link/jax"
        # PyTorch backend.
        - install-torch: 1
          os: "ubuntu-latest"
          python-version: "3.11"
          fast-compile: 0
          float32: 0
          part: "tests/link/pytorch"
        # xtensor tests (need xarray).
        - install-xarray: 1
          os: "ubuntu-latest"
          python-version: "3.13"
          fast-compile: 0
          float32: 0
          part: "tests/xtensor"
        # macOS: MLX backend tests.
        - os: "macos-15"
          python-version: "3.11"
          fast-compile: 0
          float32: 0
          install-mlx: 1
          install-numba: 0
          install-jax: 0
          install-torch: 0
          part: "tests/link/mlx"
        # macOS: BLAS / elemwise / scipy-math slice. Flags not listed in an
        # entry are simply absent from its matrix context.
        - os: "macos-15"
          python-version: "3.13"
          fast-compile: 0
          float32: 0
          install-numba: 0
          install-jax: 0
          install-torch: 0
          part: "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py"
  steps:
    - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8  # v5.0.0
      with:
        fetch-depth: 0
        persist-credentials: false

    - name: Set up Python ${{ matrix.python-version }}
      uses: mamba-org/setup-micromamba@add3a49764cedee8ee24e82dfde87f5bc2914462  # v2.0.7
      with:
        environment-name: pytensor-test
        micromamba-version: "1.5.10-0"  # until https://github.com/mamba-org/setup-micromamba/issues/225 is resolved
        init-shell: bash
        post-cleanup: "all"
        create-args: python=${{ matrix.python-version }}

    # Derives a short identifier from the JSON dump of the matrix entry: the
    # first 32 hex characters of its SHA-256. It is used below to give every
    # job a distinct coverage file and artifact name.
    - name: Create matrix id
      id: matrix-id
      env:
        MATRIX_CONTEXT: ${{ toJson(matrix) }}
      run: |
        echo $MATRIX_CONTEXT
        export MATRIX_ID=`echo $MATRIX_CONTEXT | sha256sum | cut -c 1-32`
        echo $MATRIX_ID
        echo "id=$MATRIX_ID" >> $GITHUB_OUTPUT

    # Installs the base scientific stack, then the optional backend selected by
    # the INSTALL_* flags, then the package itself in editable mode. Finishes
    # with a check of the configured BLAS flags.
    # NOTE(review): NUMPY_VERSION is not defined in this step's `env`; unless it
    # is provided elsewhere, "numpy${NUMPY_VERSION}" expands to plain "numpy".
    - name: Install dependencies
      shell: micromamba-shell {0}
      env:
        PYTHON_VERSION: ${{ matrix.python-version }}
        INSTALL_NUMBA: ${{ matrix.install-numba }}
        INSTALL_JAX: ${{ matrix.install-jax }}
        INSTALL_TORCH: ${{ matrix.install-torch }}
        INSTALL_XARRAY: ${{ matrix.install-xarray }}
        INSTALL_MLX: ${{ matrix.install-mlx }}
        OS: ${{ matrix.os }}
      run: |
        if [[ $OS == "macos-15" ]]; then
          micromamba install --yes -q "python~=${PYTHON_VERSION}" "numpy${NUMPY_VERSION}" scipy pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx libblas=*=*accelerate;
        else
          micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx;
        fi
        if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
        if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi
        if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
        if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" mlx; fi
        if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi
        pip install -e ./
        micromamba list && pip freeze
        python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
        if [[ $OS == "macos-15" ]]; then
          python -c 'import pytensor; assert pytensor.config.blas__ldflags.startswith("-framework Accelerate"), "Blas flags are not set to MacOS Accelerate"';
        else
          python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"';
        fi

    # Appends the mode/floatX overrides requested by the matrix entry to
    # PYTENSOR_FLAGS and runs pytest with coverage. $PART is expanded unquoted,
    # so a part holding several paths or options is word-split into separate
    # pytest arguments.
    - name: Run tests
      shell: micromamba-shell {0}
      env:
        MATRIX_ID: ${{ steps.matrix-id.outputs.id }}
        MKL_THREADING_LAYER: GNU
        MKL_NUM_THREADS: 1
        OMP_NUM_THREADS: 1
        PART: ${{ matrix.part }}
        FAST_COMPILE: ${{ matrix.fast-compile }}
        FLOAT32: ${{ matrix.float32 }}
      run: |
        if [[ $FAST_COMPILE == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,mode=FAST_COMPILE; fi
        if [[ $FLOAT32 == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,floatX=float32; fi
        export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe
        python -m pytest -r A --verbose --runslow --durations=50 --cov=pytensor/ --cov-report=xml:coverage/coverage-${MATRIX_ID}.xml --no-cov-on-fail $PART --benchmark-skip

    - name: Upload coverage file
      uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02  # v4.6.2
      with:
        name: coverage-${{ steps.matrix-id.outputs.id }}
        path: coverage/coverage-${{ steps.matrix-id.outputs.id }}.xml
benchmarks:
  # Runs the pytest-benchmark suite on Ubuntu and compares the result with the
  # benchmark data restored from the Actions cache. Scheduled only when
  # relevant files changed and the style job succeeded.
  name: "Benchmarks"
  needs:
    - changes
    - style
  runs-on: ubuntu-latest
  if: ${{ needs.changes.outputs.changes == 'true' && needs.style.result == 'success' }}
  strategy:
    # NOTE(review): this job defines no matrix, and fail-fast is documented as
    # a matrix setting, so this looks like a no-op; confirm before removing.
    fail-fast: false
  steps:
    - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8  # v5.0.0
      with:
        fetch-depth: 0
        persist-credentials: false

    - name: Set up Python 3.11
      uses: mamba-org/setup-micromamba@add3a49764cedee8ee24e82dfde87f5bc2914462  # v2.0.7
      with:
        environment-name: pytensor-test
        micromamba-version: "1.5.10-0"  # until https://github.com/mamba-org/setup-micromamba/issues/225 is resolved
        init-shell: bash
        post-cleanup: "all"

    # Python itself is installed here, pinned through PYTHON_VERSION, together
    # with the numba and jax backends and pytest-benchmark.
    - name: Install dependencies
      shell: micromamba-shell {0}
      run: |
        micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark
        pip install -e ./
        micromamba list && pip freeze
        python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
        python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"'
      env:
        # Quoted on purpose: an unquoted version is a YAML float, so a value
        # such as 3.10 would silently turn into 3.1.
        PYTHON_VERSION: "3.11"

    - name: Download previous benchmark data
      uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830  # v4.3.0
      with:
        path: ./cache
        key: ${{ runner.os }}-benchmark

    - name: Run benchmarks
      shell: micromamba-shell {0}
      run: |
        export PYTENSOR_FLAGS=mode=FAST_COMPILE,warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe
        python -m pytest --runslow --benchmark-only --benchmark-json output.json

    # Compares output.json with the cached data file. Alerts are configured at
    # a 200% threshold but neither comment on the PR nor fail the job, and
    # nothing is pushed back to the repository.
    - name: Store benchmark result
      uses: benchmark-action/github-action-benchmark@4bdcce38c94cec68da58d012ac24b7b1155efe8b  # v1.20.7
      with:
        name: Python Benchmark with pytest-benchmark
        tool: "pytest"
        output-file-path: output.json
        external-data-json-path: ./cache/benchmark-data.json
        alert-threshold: "200%"
        github-token: ${{ secrets.GITHUB_TOKEN }}
        comment-on-alert: false
        fail-on-alert: false
        auto-push: false
all-checks:
  # Aggregation gate: always runs, and turns the outcome of the change
  # detection, style and test jobs into a single pass/fail status.
  if: ${{ always() }}
  runs-on: ubuntu-latest
  name: "All tests"
  needs: [changes, style, test]
  steps:
    - name: Check build matrix status
      # Fail in two situations:
      #   1. the `changes` job itself did not succeed. Its output is then empty,
      #      style/test are skipped, and without this clause the gate would
      #      pass although nothing was checked;
      #   2. relevant files changed and style or test did not succeed.
      # When `changes` succeeds and reports no relevant changes, the step is
      # skipped and the gate passes.
      if: ${{ needs.changes.result != 'success' || (needs.changes.outputs.changes == 'true' && (needs.style.result != 'success' || needs.test.result != 'success')) }}
      run: exit 1
upload-coverage:
  # Collects the per-job coverage XML artifacts into ./coverage and uploads
  # them to Codecov. Runs only when relevant files changed and the
  # `all-checks` gate succeeded.
  runs-on: ubuntu-latest
  name: "Upload coverage"
  needs: [changes, all-checks]
  if: ${{ needs.changes.outputs.changes == 'true' && needs.all-checks.result == 'success' }}
  steps:
    - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8  # v5.0.0
      with:
        persist-credentials: false

    - name: Set up Python
      uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c  # v6.0.0
      with:
        python-version: "3.13"

    - name: Install dependencies
      # The requirement must be quoted: unquoted, the shell treats `>` as an
      # output redirect, which drops the version constraint and writes pip's
      # output to a file named "=5.1".
      run: |
        python -m pip install -U "coverage>=5.1" coveralls

    - name: Download coverage file
      uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0  # v5.0.0
      with:
        # Every artifact named coverage-* is merged into the one directory.
        pattern: coverage-*
        path: coverage
        merge-multiple: true

    - name: Upload coverage to Codecov
      uses: codecov/codecov-action@5a1091511ad55cbe89839c7260b706298ca349f7  # v5.5.1
      with:
        directory: ./coverage/
        fail_ci_if_error: true
        token: ${{ secrets.CODECOV_TOKEN }}