# Composite action shared by the NPU upstream test jobs.
#
# It checks out this repository, restores the pip cache, downloads the
# prebuilt torch / torch_npu wheels and the PyTorch source artifact, installs
# them together with the test dependencies, and finally prints whether the
# NPU devices are visible to the freshly installed torch_npu.
name: 'Setup NPU Test Environment'
description: 'Common environment setup for NPU upstream tests - checkout, cache, install PyTorch/torch_npu/triton-ascend, test dependencies'

# Action inputs are always passed as strings. The `type` key belongs to
# workflow_call / workflow_dispatch inputs and is not part of the action
# metadata schema, so it is not declared here.
inputs:
  python_version:
    required: true
    description: Python version to use
  torch_wheel_artifact:
    required: true
    description: Name of the torch wheel artifact
  torch_npu_wheel_artifact:
    required: true
    description: Name of the torch_npu wheel artifact
  pytorch_src_artifact:
    required: true
    description: Name of the PyTorch source artifact

# A top-level `env:` block is not a valid key in action metadata, so values
# that used to live there are declared on the steps that consume them.
runs:
  using: 'composite'
  steps:
    - name: Checkout repository
      uses: actions/checkout@v4
      with:
        repository: ${{ github.repository }}
        ref: ${{ github.ref }}
        fetch-depth: 1
        path: ascend_pytorch

    # Every `run` step of a composite action must declare its shell.
    - name: Setup cache directories
      shell: bash
      run: |
        mkdir -p /github/home/.cache/pip
        chmod -R 777 /github/home/.cache

    - name: Cache pip
      uses: actions/cache@v4
      with:
        path: /github/home/.cache/pip
        key: pip-arm-collect-py${{ inputs.python_version }}
        restore-keys: |
          pip-arm-collect-py${{ inputs.python_version }}-
          pip-arm-collect-

    - name: Download built torch wheel
      uses: actions/download-artifact@v4
      with:
        name: ${{ inputs.torch_wheel_artifact }}
        path: torch-wheel-artifact

    - name: Download built torch_npu wheel
      uses: actions/download-artifact@v4
      with:
        name: ${{ inputs.torch_npu_wheel_artifact }}
        path: torch-npu-wheel-artifact

    - name: Download PyTorch source and test code
      uses: actions/download-artifact@v4
      with:
        name: ${{ inputs.pytorch_src_artifact }}
        path: pytorch-src-artifact

    - name: Extract PyTorch source
      shell: bash
      run: |
        tar -xzf pytorch-src-artifact/pytorch-src.tar.gz

    - name: Install built PyTorch and torch_npu
      shell: bash
      env:
        # Inputs are handed over through the environment instead of being
        # interpolated into the script text.
        PYTHON_VERSION: ${{ inputs.python_version }}
        # PyPI cache URL (speeds up pip downloads inside the cluster).
        PYPI_CACHE_URL: 'http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple'
        PIP_CACHE_DIR: /github/home/.cache/pip
      run: |
        source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true
        source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true

        PIP="pip${PYTHON_VERSION}"

        # Configure pip to use the PyPI cache for faster downloads
        if [ -n "${PYPI_CACHE_URL}" ]; then
          $PIP config set global.index-url "${PYPI_CACHE_URL}"
          $PIP config set global.trusted-host "cache-service.nginx-pypi-cache.svc.cluster.local"
          echo "pip index-url configured: ${PYPI_CACHE_URL}"
        fi

        $PIP install --upgrade pip

        # Resolve the wheels with a glob so a missing artifact fails with a
        # clear message instead of running `pip install ""`.
        shopt -s nullglob
        torch_wheels=(torch-wheel-artifact/*.whl)
        torch_npu_wheels=(torch-npu-wheel-artifact/*.whl)
        shopt -u nullglob

        if [ "${#torch_wheels[@]}" -eq 0 ]; then
          echo "::error::No torch wheel found in torch-wheel-artifact/"
          exit 1
        fi
        if [ "${#torch_npu_wheels[@]}" -eq 0 ]; then
          echo "::error::No torch_npu wheel found in torch-npu-wheel-artifact/"
          exit 1
        fi

        # Install built torch wheel
        TORCH_WHL="${torch_wheels[0]}"
        $PIP install "${TORCH_WHL}"

        # Install built torch_npu wheel
        TORCH_NPU_WHL="${torch_npu_wheels[0]}"
        $PIP install "${TORCH_NPU_WHL}"

        echo "Installed PyTorch and torch_npu from built wheels"
        echo "torch: ${TORCH_WHL}"
        echo "torch_npu: ${TORCH_NPU_WHL}"

    - name: Install test dependencies
      shell: bash
      env:
        PYTHON_VERSION: ${{ inputs.python_version }}
        PIP_CACHE_DIR: /github/home/.cache/pip
      run: |
        PIP="pip${PYTHON_VERSION}"
        cd pytorch-src

        # Core test dependencies
        $PIP install pytest pytest-timeout pytest-xdist hypothesis zstandard pyyaml
        $PIP install pytest-rerunfailures pytest-flakefinder
        $PIP install 'pytest-subtests==0.13.1' 'xdoctest==1.1.0' 'pulp>=2.9'

        # Optional dependencies for ONNX tests
        # These are not in PyTorch requirements.txt but needed by specific tests
        $PIP install onnxruntime onnxscript onnx-ir ml-dtypes || true

        # torchvision for ONNX model tests (install without deps to bypass torch version check)
        # PyPI torchvision requires exact torch version (torch==2.11.0), but we have dev build
        # Use --no-deps to skip torch dependency, we already have our compiled torch installed
        $PIP install numpy pillow || true
        $PIP install torchvision --no-deps || true

        # Other optional dependencies (best effort)
        $PIP install parameterized pandas || true
        $PIP install opencv-python || true

        # PyTorch requirements (if exists)
        if [ -f requirements.txt ]; then
          $PIP install -r requirements.txt || true
        fi

    - name: Verify NPU availability
      shell: bash
      env:
        PYTHON_VERSION: ${{ inputs.python_version }}
      run: |
        source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true
        source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true

        PYTHON="python${PYTHON_VERSION}"
        $PYTHON -c "
        import torch
        print(f'torch: {torch.__version__}')
        import torch_npu
        print(f'torch_npu: {torch_npu.__version__}')
        print(f'NPU available: {torch.npu.is_available()}')
        print(f'NPU count: {torch.npu.device_count()}')
        "
# The manylinux image already ships cp310-cp310, cp311-cp311, cp312-cp312 and cp313-cp313.
# Python 3.11 is the default; another version is selected by adjusting PATH.

ENV DEFAULT_PYTHON_VERSION=3.11
ENV PATH=/opt/python/cp311-cp311/bin:$PATH

# Create the Python version switch script (meant to be sourced, hence `return`).
RUN printf '#!/bin/bash\n\
# Python 版本切换辅助脚本\n\
# 使用方法: source /usr/local/bin/switch_python.sh 3.11\n\
\n\
PYTHON_VERSION="${1:-3.11}"\n\
\n\
case "$PYTHON_VERSION" in\n\
    3.10) PYTHON_DIR="cp310-cp310" ;;\n\
    3.11) PYTHON_DIR="cp311-cp311" ;;\n\
    3.12) PYTHON_DIR="cp312-cp312" ;;\n\
    3.13) PYTHON_DIR="cp313-cp313" ;;\n\
    *) echo "Unsupported Python version: $PYTHON_VERSION"; return 1 ;;\n\
esac\n\
\n\
export PATH=/opt/python/$PYTHON_DIR/bin:$PATH\n\
echo "Switched to Python $PYTHON_VERSION ($(python --version))"\n\
' > /usr/local/bin/switch_python.sh && \
    chmod +x /usr/local/bin/switch_python.sh

# Upgrade the packaging tools for every bundled Python version.
# `set -e` makes a failure for any interpreter fail the build; without it only
# the exit status of the last loop iteration would count.
RUN set -e; \
    for py_dir in cp310-cp310 cp311-cp311 cp312-cp312 cp313-cp313; do \
        /opt/python/$py_dir/bin/pip install --upgrade pip setuptools wheel; \
    done

# ============================================================
# Install CANN (package URLs are supplied as build args)
# ============================================================

WORKDIR /root

# - The `:?` expansions abort early when a build arg was not supplied.
# - curl -f fails on HTTP errors instead of saving the error page as a .run
#   file; -L follows redirects.
# - The download directory is removed through its absolute path after leaving
#   it. The previous `rm -rf cann` ran from inside /root/cann, matched nothing,
#   and left the installers in the image layer.
RUN : "${CANN_TOOLKIT_URL:?CANN_TOOLKIT_URL build-arg is required}" && \
    : "${CANN_A3OPS_URL:?CANN_A3OPS_URL build-arg is required}" && \
    : "${CANN_NNAL_URL:?CANN_NNAL_URL build-arg is required}" && \
    mkdir -p /root/cann && cd /root/cann && \
    curl -fSL -O "${CANN_TOOLKIT_URL}" && \
    curl -fSL -O "${CANN_A3OPS_URL}" && \
    curl -fSL -O "${CANN_NNAL_URL}" && \
    chmod +x Ascend-cann*.run && \
    ./Ascend-cann-toolkit*.run --full --quiet --install-path=/usr/local/Ascend && \
    ./Ascend-cann-A3*.run --install --quiet --install-path=/usr/local/Ascend && \
    . /usr/local/Ascend/cann/set_env.sh && \
    ./Ascend-cann-nnal*.run --install --quiet --install-path=/usr/local/Ascend && \
    cd /root && rm -rf /root/cann

# Environment variables describing the CANN installation
ENV CANN_PATH=/usr/local/Ascend/cann
ENV NNAL_PATH=/usr/local/Ascend/nnal
ENV ASCEND_HOME=/usr/local/Ascend
ENV CANN_VERSION=${CANN_VERSION}
/usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true\n\ +' > /etc/profile.d/cann_env.sh && \ + chmod +x /etc/profile.d/cann_env.sh + +# ============================================================ +# 预安装 pytest 等测试依赖(为所有 Python 版本) +# ============================================================ + +RUN for py_dir in cp310-cp310 cp311-cp311 cp312-cp312 cp313-cp313; do \ + /opt/python/$py_dir/bin/pip install pytest pytest-timeout pytest-xdist hypothesis pyyaml zstandard cmake ninja; \ + done + +# ============================================================ +# 设置工作目录和默认命令 +# ============================================================ + +WORKDIR /workspace + +# 创建 welcome 消息 +RUN printf '\n\ +========================================\n\ +PyTorch NPU Builder Image\n\ +========================================\n\ +CANN Version: %s\n\ +Python Versions: 3.10, 3.11, 3.12, 3.13 (default: 3.11)\n\ +\n\ +To switch Python version:\n\ + source /usr/local/bin/switch_python.sh 3.12\n\ +\n\ +To setup CANN environment:\n\ + source /etc/profile.d/cann_env.sh\n\ +========================================\n\ +\n' "${CANN_VERSION}" > /etc/motd + +CMD ["bash"] \ No newline at end of file diff --git a/.github/scripts/build_image.sh b/.github/scripts/build_image.sh new file mode 100755 index 0000000000..25763c688e --- /dev/null +++ b/.github/scripts/build_image.sh @@ -0,0 +1,486 @@ +#!/bin/bash +# +# build_image.sh - 构建 PyTorch NPU Docker 镜像 +# +# 功能:按 CANN 版本构建镜像,镜像预装多 Python 版本,通过环境变量切换 +# +# 使用方式: +# ./build_image.sh --cann-version 9.0 +# ./build_image.sh --cann-version 9.0.0-beta.2 --push +# ./build_image.sh --list-versions # 查看支持的 CANN 版本 +# + +set -euo pipefail + +# ============================================================ +# CANN 版本映射表 +# 每个版本对应三个包的下载 URL +# ============================================================ + +declare -A CANN_VERSIONS=( + # 版本号 -> toolkit|a3_ops|nnal 的 URL + # 注意:OBS 上当前只有 9.0.0-beta.2 版本的包 + 
["9.0"]="https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/cann-package/20260330/Ascend-cann-toolkit_9.0.0-beta.2_linux-aarch64.run|https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/cann-package/20260330/Ascend-cann-A3-ops_9.0.0-beta.2_linux-aarch64.run|https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/cann-package/20260330/Ascend-cann-nnal_9.0.0-beta.2_linux-aarch64.run" + + ["9.0.0-beta.2"]="https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/cann-package/20260330/Ascend-cann-toolkit_9.0.0-beta.2_linux-aarch64.run|https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/cann-package/20260330/Ascend-cann-A3-ops_9.0.0-beta.2_linux-aarch64.run|https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/cann-package/20260330/Ascend-cann-nnal_9.0.0-beta.2_linux-aarch64.run" + + ["8.0"]="https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/cann-package/20250101/Ascend-cann-toolkit_8.0.RC3_linux-aarch64.run|https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/cann-package/20250101/Ascend-cann-A3-ops_8.0.RC3_linux-aarch64.run|https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/cann-package/20250101/Ascend-cann-nnal_8.0.RC3_linux-aarch64.run" +) + +# Stable 版本标记(用于 latest 标签) +CANN_STABLE="9.0" + +# 预装的 Python 版本列表 +PYTHON_VERSIONS=("3.10" "3.11" "3.12" "3.13") + +# manylinux 对应的 Python 目录名映射 +declare -A PYTHON_DIR_MAP=( + ["3.10"]="cp310-cp310" + ["3.11"]="cp311-cp311" + ["3.12"]="cp312-cp312" + ["3.13"]="cp313-cp313" +) + +# ============================================================ +# 默认配置 +# ============================================================ + +DEFAULT_REGISTRY="quay.io" +DEFAULT_QUAY_ORG="kerer" +DEFAULT_IMAGE_NAME="pytorch" + +# 参数变量 +CANN_VERSION_INPUT="" +REGISTRY="" +QUAY_ORG="" +IMAGE_NAME="" +PUSH_IMAGE=false +FORCE_BUILD=false +VERBOSE=false +LIST_VERSIONS=false + +# ============================================================ +# 日志函数 +# 
============================================================ + +log_info() { + echo "[INFO] $1" +} + +log_error() { + echo "[ERROR] $1" >&2 +} + +log_verbose() { + if [[ "$VERBOSE" == "true" ]]; then + echo "[VERBOSE] $1" + fi +} + +# ============================================================ +# 显示帮助信息 +# ============================================================ + +show_help() { + cat << EOF +用法: $0 [OPTIONS] + +构建支持不同 CANN 版本的 PyTorch NPU Docker 镜像。 + +镜像特性: + - 预装多个 Python 版本 (3.10/3.11/3.12/3.13) + - 通过环境变量切换 Python 版本 + - 按 CANN 版本构建镜像 + +CANN 参数: + --cann-version VERSION CANN 版本号(支持简化版或完整版) + 简化版: 9.0, 8.0 + 完整版: 9.0.0-beta.2 + --list-versions 显示支持的 CANN 版本列表 + +镜像参数: + --registry REGISTRY Docker registry 地址 (默认: quay.io) + --quay-org ORG Quay.io 组织名 (默认: kerer) + --image-name NAME 镜像名称 (默认: pytorch) + +构建选项: + --push 构建后推送镜像到 registry + --force 强制构建,即使镜像已存在 + --verbose 显示详细日志 + +Python 版本切换(运行时): + 镜像预装多个 Python 版本,使用时通过环境变量切换: + export PATH=/opt/python/cp311-cp311/bin:\$PATH # 使用 Python 3.11 + export PATH=/opt/python/cp312-cp312/bin:\$PATH # 使用 Python 3.12 + +示例: + $0 --cann-version 9.0 + $0 --cann-version 9.0.0-beta.2 --push + $0 --list-versions + +支持的 CANN 版本: +$(show_supported_versions) + +EOF +} + +show_supported_versions() { + echo "简化版本 完整版本" + echo "----------- ----------------" + for version in "${!CANN_VERSIONS[@]}"; do + if [[ ! "$version" =~ -beta ]] && [[ ! 
# Parse the command line into the global option variables.
#
# Sets CANN_VERSION_INPUT, REGISTRY, QUAY_ORG, IMAGE_NAME, PUSH_IMAGE,
# FORCE_BUILD, VERBOSE and LIST_VERSIONS, then applies the DEFAULT_* fallbacks.
# Exits 0 after printing help (-h/--help) or the version list
# (--list-versions); exits 1 on an unknown option, an option that is missing
# its value, or when --cann-version was not supplied.
parse_args() {
    while [[ $# -gt 0 ]]; do
        case "$1" in
            --cann-version|--registry|--quay-org|--image-name)
                # These options take a value. Under `set -u` a bare "$2"
                # would abort with "unbound variable", so check explicitly
                # and print a proper usage error instead.
                if [[ $# -lt 2 ]]; then
                    log_error "参数 $1 缺少取值"
                    show_help
                    exit 1
                fi
                case "$1" in
                    --cann-version) CANN_VERSION_INPUT="$2" ;;
                    --registry)     REGISTRY="$2" ;;
                    --quay-org)     QUAY_ORG="$2" ;;
                    --image-name)   IMAGE_NAME="$2" ;;
                esac
                shift 2
                ;;
            --push)
                PUSH_IMAGE=true
                shift
                ;;
            --force)
                FORCE_BUILD=true
                shift
                ;;
            --verbose)
                VERBOSE=true
                shift
                ;;
            --list-versions)
                LIST_VERSIONS=true
                shift
                ;;
            -h|--help)
                show_help
                exit 0
                ;;
            *)
                log_error "未知参数: $1"
                show_help
                exit 1
                ;;
        esac
    done

    # Apply defaults for anything not given on the command line
    REGISTRY="${REGISTRY:-$DEFAULT_REGISTRY}"
    QUAY_ORG="${QUAY_ORG:-$DEFAULT_QUAY_ORG}"
    IMAGE_NAME="${IMAGE_NAME:-$DEFAULT_IMAGE_NAME}"

    # --list-versions: print the mapping-table keys and stop
    if [[ "$LIST_VERSIONS" == "true" ]]; then
        echo "支持的 CANN 版本:"
        echo ""
        for version in "${!CANN_VERSIONS[@]}"; do
            echo "  - $version"
        done
        echo ""
        echo "Stable 版本(用于 latest 标签): $CANN_STABLE"
        exit 0
    fi

    # A CANN version is mandatory for an actual build
    if [[ -z "$CANN_VERSION_INPUT" ]]; then
        log_error "必须指定 --cann-version 或使用 --list-versions"
        show_help
        exit 1
    fi
}
# Print the image tags for the selected CANN version, one per line.
#
# Reads CANN_VERSION_FULL and IS_STABLE (both set by parse_cann_version).
# Output order:
#   1. cann<full>-<YYYYMMDD>   traceable, timestamped tag
#   2. cann<full>              only when <full> differs from <major.minor>
#   3. cann<major.minor>       short tag, e.g. 9.0.0-beta.2 -> cann9.0
#   4. latest, cann-latest, cann<major.minor>-latest   (stable versions only)
generate_tags() {
    # Declared separately from the assignment so a failing `date` is not
    # hidden behind the exit status of `local`.
    local timestamp
    timestamp=$(date +%Y%m%d)
    local -a tags=()

    # Leading "<digits>.<digits>" of the version. A bash regex avoids the
    # dependency on GNU grep's PCRE mode (-P); if the version does not start
    # that way it is used unchanged.
    local cann_major
    if [[ "$CANN_VERSION_FULL" =~ ^([0-9]+\.[0-9]+) ]]; then
        cann_major="${BASH_REMATCH[1]}"
    else
        cann_major="$CANN_VERSION_FULL"
    fi

    # 1. Full version tag with timestamp - for traceability
    tags+=("cann${CANN_VERSION_FULL}-${timestamp}")

    # 2. Full version tag without timestamp - skipped when it would
    #    duplicate the short tag below
    if [[ "$CANN_VERSION_FULL" != "$cann_major" ]]; then
        tags+=("cann${CANN_VERSION_FULL}")
    fi

    # 3. Short major.minor tag
    tags+=("cann${cann_major}")

    # 4. latest tags (stable versions only)
    if [[ "$IS_STABLE" == "true" ]]; then
        tags+=("latest")
        tags+=("cann-latest")
        tags+=("cann${cann_major}-latest")
    fi

    printf '%s\n' "${tags[@]}"
}

# ============================================================
# Build the image
# ============================================================

# Build (and optionally push) the image for the parsed CANN version.
#
# Reads the globals set by parse_args / parse_cann_version. When pushing
# without --force, the build is skipped if the timestamped tag can already be
# pulled. Returns 1 when the build fails; exits 1 when the Dockerfile is
# missing.
build_image() {
    log_info "=========================================="
    log_info "构建镜像: CANN $CANN_VERSION_FULL"
    log_info "=========================================="

    log_info "预装 Python 版本: ${PYTHON_VERSIONS[*]}"

    local image_ref="${REGISTRY}/${QUAY_ORG}/${IMAGE_NAME}"

    # Generate the tags. Plain assignment (not `local x=$(...)`) lets a
    # failure inside generate_tags propagate under `set -e`.
    local tags_output
    tags_output=$(generate_tags)
    local -a tag_list=()
    mapfile -t tag_list <<< "$tags_output"

    local -a tag_args=()
    local tag
    for tag in "${tag_list[@]}"; do
        tag_args+=(--tag "${image_ref}:${tag}")
    done

    log_info "镜像标签:"
    for tag in "${tag_list[@]}"; do
        log_info "  - ${image_ref}:${tag}"
    done

    # Skip the build if the image already exists (unless forced)
    local first_tag="${tag_list[0]}"
    if [[ "$FORCE_BUILD" == "false" && "$PUSH_IMAGE" == "true" ]]; then
        if docker pull "${image_ref}:${first_tag}" &>/dev/null; then
            log_info "镜像已存在,跳过构建: ${image_ref}:${first_tag}"
            return 0
        fi
    fi

    # A build is needed: log in if we are going to push.
    # SKIP_DOCKER_LOGIN=true skips this (CI already logged in via login-action).
    if [[ "$PUSH_IMAGE" == "true" ]]; then
        if [[ "${SKIP_DOCKER_LOGIN:-false}" != "true" ]]; then
            login_registry
        else
            log_verbose "跳过登录(SKIP_DOCKER_LOGIN=true)"
        fi
    fi

    # Locate the project root. Ask git about the script's own directory (not
    # the caller's cwd, which may be a different repository).
    local script_dir
    script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
    local project_root
    if git -C "$script_dir" rev-parse --show-toplevel &>/dev/null; then
        project_root="$(git -C "$script_dir" rev-parse --show-toplevel)"
    else
        # Not inside a git checkout: this script lives in .github/scripts, so
        # the project root is two levels up. A single ".." would resolve to
        # .github and produce a .github/.github/docker Dockerfile path.
        project_root="$(cd "${script_dir}/../.." && pwd)"
    fi

    local dockerfile_dir="${project_root}/.github/docker"
    local dockerfile="${dockerfile_dir}/pytorch-npu-builder.Dockerfile"

    log_verbose "Script dir: ${script_dir}"
    log_verbose "Project root: ${project_root}"
    log_verbose "Dockerfile dir: ${dockerfile_dir}"

    if [[ ! -f "$dockerfile" ]]; then
        log_error "Dockerfile 不存在: $dockerfile"
        exit 1
    fi

    log_verbose "Dockerfile: $dockerfile"

    # Assemble the command as an array so URLs and paths reach docker as
    # single arguments; no string building plus eval, no re-parsing by the
    # shell.
    local -a build_cmd=(
        docker buildx build
        --build-arg "CANN_TOOLKIT_URL=${CANN_TOOLKIT_URL}"
        --build-arg "CANN_A3OPS_URL=${CANN_A3OPS_URL}"
        --build-arg "CANN_NNAL_URL=${CANN_NNAL_URL}"
        --build-arg "CANN_VERSION=${CANN_VERSION_FULL}"
        "${tag_args[@]}"
        --file "$dockerfile"
        --platform linux/arm64
    )

    if [[ "$PUSH_IMAGE" == "true" ]]; then
        build_cmd+=(--push)
    else
        build_cmd+=(--load)
    fi
    build_cmd+=("$dockerfile_dir")

    log_verbose "构建命令: ${build_cmd[*]}"

    # Run the build
    log_info "开始构建..."
    if ! "${build_cmd[@]}"; then
        log_error "构建失败"
        return 1
    fi

    log_info "构建成功"

    # Build summary
    echo ""
    log_info "构建信息:"
    log_info "  CANN 版本: $CANN_VERSION_FULL"
    log_info "  CANN 大版本: $CANN_VERSION_MAJOR"
    log_info "  Stable: $IS_STABLE"
    log_info "  预装 Python: ${PYTHON_VERSIONS[*]}"
    log_info "  镜像地址:"
    for tag in "${tag_list[@]}"; do
        log_info "    ${image_ref}:${tag}"
    done

    echo ""
    log_info "使用方法:"
    # Point at a tag that was really produced. cann${CANN_VERSION_MAJOR} is
    # not in the generated list for inputs such as 9.0.0-beta.2.
    log_info "  docker run -it ${image_ref}:${first_tag} bash"
    log_info "  # 切换 Python 版本:"
    log_info "  export PATH=/opt/python/cp311-cp311/bin:\$PATH  # Python 3.11"
    log_info "  export PATH=/opt/python/cp312-cp312/bin:\$PATH  # Python 3.12"
    echo ""

    return 0
}
command -v docker &>/dev/null; then + log_error "未安装 docker" + exit 1 + fi + + # 检查 docker buildx + if ! docker buildx version &>/dev/null; then + log_error "docker buildx 不可用" + exit 1 + fi + + log_verbose "依赖检查通过" +} + +# ============================================================ +# 登录 registry +# ============================================================ + +login_registry() { + if [[ "$PUSH_IMAGE" == "true" ]]; then + log_info "登录 Registry: $REGISTRY" + + case "$REGISTRY" in + quay.io) + if [[ -z "${QUAY_USERNAME:-}" || -z "${QUAY_PASSWORD:-}" ]]; then + log_error "需要设置环境变量 QUAY_USERNAME 和 QUAY_PASSWORD" + exit 1 + fi + docker login quay.io -u "$QUAY_USERNAME" --password-stdin <<< "$QUAY_PASSWORD" + ;; + ghcr.io) + if [[ -z "${GITHUB_TOKEN:-}" ]]; then + log_error "需要设置环境变量 GITHUB_TOKEN" + exit 1 + fi + echo "$GITHUB_TOKEN" | docker login ghcr.io -u "${GITHUB_ACTOR:-}" --password-stdin + ;; + *) + log_error "不支持的 registry: $REGISTRY" + exit 1 + ;; + esac + + log_info "登录成功" + fi +} + +# ============================================================ +# 主函数 +# ============================================================ + +main() { + parse_args "$@" + check_dependencies + parse_cann_version + build_image +} + +# 执行主函数 +main "$@" \ No newline at end of file diff --git a/.github/scripts/collect_all_cases.py b/.github/scripts/collect_all_cases.py new file mode 100644 index 0000000000..6773817dd9 --- /dev/null +++ b/.github/scripts/collect_all_cases.py @@ -0,0 +1,452 @@ +#!/usr/bin/env python3 +""" +Collect all test cases and split into shards. + +This script runs in prepare job (once) to: +1. Discover test files by type (distributed/regular) +2. Collect all test cases via pytest --collect-only +3. Split cases evenly into N shards +4. Output shard JSON files for each type +5. 
def _normalize_test_file_path(test_file: str) -> str:
    """
    Remove a leading 'test/' from a test file path if present.

    Args:
        test_file: Test file path (e.g., "test/distributed/pipelining/test_backward.py")

    Returns:
        The path relative to the test directory (unchanged if it has no
        'test/' prefix).
    """
    prefix = "test/"
    if test_file.startswith(prefix):
        return test_file[len(prefix):]
    return test_file


def get_test_file_parent_dir(test_file: str, test_dir: Path) -> Path:
    """
    Get the directory that contains a test file.

    This directory is put on PYTHONPATH during collection so that the test
    file can import sibling modules (e.g., model_registry.py).

    Args:
        test_file: Test file path (e.g., "test/distributed/pipelining/test_backward.py")
        test_dir: Path to PyTorch test directory

    Returns:
        Path to the test file's parent directory, below ``test_dir``
    """
    return test_dir / Path(_normalize_test_file_path(test_file)).parent


def _parse_collected_nodeids(stdout: str) -> List[str]:
    """
    Extract test nodeids from ``pytest --collect-only --quiet`` output.

    A line is kept when it is non-empty, is not a separator line (starting
    with "=") and contains ".py::". Summary lines such as
    "12 tests collected in 0.5s" never contain ".py::", so no extra keyword
    filter is needed. Filtering on the words "collected"/"selected" would
    wrongly drop real tests whose names contain them
    (e.g. ``test_x.py::T::test_selected_rows``).
    """
    nodeids = []
    for line in stdout.splitlines():
        stripped = line.strip()
        if not stripped or stripped.startswith("="):
            continue
        if ".py::" in stripped:
            nodeids.append(stripped)
    return nodeids


def collect_cases_for_file(
    test_file: str,
    test_dir: Path,
    timeout: int = 120,
) -> Tuple[str, str, List[str], bool, str]:
    """
    Collect test cases from a single file via ``pytest --collect-only``.

    The test file's parent directory is prepended to PYTHONPATH so that
    imports of sibling modules work
    (e.g. 'from model_registry import MLPModule').

    Args:
        test_file: Test file path, with or without the 'test/' prefix
        test_dir: Path to PyTorch test directory (used as working directory)
        timeout: Seconds to wait for the pytest subprocess (default 120)

    Returns:
        Tuple of (test_file, display_name, nodeids, success, error_message)
        - test_file: Original test file path
        - display_name: Short name for logging (no test/ prefix, no .py suffix)
        - nodeids: List of collected test case nodeids
        - success: True if pytest exited with code 0 or 3
        - error_message: Error details if collection failed, empty string otherwise

    This function does not raise: timeouts and unexpected exceptions are
    reported through ``success=False`` and ``error_message``.
    """
    test_file_rel = _normalize_test_file_path(test_file)

    # Display name: relative path without the .py suffix
    display_name = test_file_rel[:-3] if test_file_rel.endswith(".py") else test_file_rel

    # Prepend the test file's directory to PYTHONPATH. os.pathsep keeps the
    # separator correct on every platform (":" on POSIX).
    env = os.environ.copy()
    pythonpath_parts = [str(get_test_file_parent_dir(test_file, test_dir))]
    existing_pythonpath = env.get("PYTHONPATH", "")
    if existing_pythonpath:
        pythonpath_parts.append(existing_pythonpath)
    env["PYTHONPATH"] = os.pathsep.join(pythonpath_parts)

    command = [
        sys.executable,
        "-m",
        "pytest",
        "--collect-only",
        "--quiet",
        test_file_rel,
    ]

    try:
        result = subprocess.run(
            command,
            cwd=str(test_dir),
            env=env,
            capture_output=True,
            text=True,
            encoding="utf-8",
            errors="replace",
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return (test_file, display_name, [], False,
                f"TIMEOUT: Collection took >{timeout}s for {display_name}")
    except Exception as e:
        return (test_file, display_name, [], False, f"ERROR: {e}")

    nodeids = _parse_collected_nodeids(result.stdout)

    # Exit codes 0 and 3 are accepted as a successful collection. Every other
    # code is reported as a failure, including 5 (no tests collected): a file
    # that was selected for execution is expected to contain cases.
    if result.returncode in (0, 3):
        return (test_file, display_name, nodeids, True, "")

    # Keep whatever nodeids were parsed; the caller decides what to do with
    # partial results.
    error_msg = result.stdout.strip()
    if result.stderr.strip():
        error_msg += "\n--- stderr ---\n" + result.stderr.strip()
    return (test_file, display_name, nodeids, False, error_msg)
+ + Args: + test_files: List of test file paths + test_dir: Path to PyTorch test directory + error_log_dir: Directory to save error logs for failed collections + parallel: Number of parallel workers + + Returns: + List of dicts with nodeid and file for each collected case + """ + all_cases = [] + failed_files = [] # Track files with collection errors for logging + + print(f"Collecting cases from {len(test_files)} files with {parallel} workers...") + print("=" * 60) + + # Create error log directory + error_log_dir.mkdir(parents=True, exist_ok=True) + + with ThreadPoolExecutor(max_workers=parallel) as executor: + futures = { + executor.submit(collect_cases_for_file, f, test_dir): f + for f in test_files + } + + completed = 0 + successful_count = 0 + failed_count = 0 + total_cases = 0 + + for future in as_completed(futures): + test_file, display_name, nodeids, success, error_msg = future.result() + completed += 1 + + if success: + successful_count += 1 + # Print concise log for successful files + print(f" {display_name}: {len(nodeids)} cases") + for nodeid in nodeids: + all_cases.append({ + "nodeid": nodeid, + "file": test_file, + }) + else: + failed_count += 1 + # Print concise log for failed files + print(f" [FAILED] {display_name}: {len(nodeids)} cases") + # Save error details to log file + failed_files.append({ + "file": display_name, + "error": error_msg, + "cases": len(nodeids), + "test_file": test_file, + }) + # Still add any cases that were collected despite errors + for nodeid in nodeids: + all_cases.append({ + "nodeid": nodeid, + "file": test_file, + }) + + # Update total cases count for progress display + total_cases += len(nodeids) + + # Print progress summary every 100 files + if completed % 100 == 0: + print(f" [Progress: {completed}/{len(test_files)} files, {successful_count} ok, {failed_count} failed, {total_cases} cases]") + + print("=" * 60) + + # Save error logs to files + if failed_files: + save_error_logs(failed_files, error_log_dir) + + # Final 
def save_error_logs(failed_files: List[Dict], error_log_dir: Path) -> None:
    """
    Save collection error logs to individual files and create a summary.

    One ``<name>.log`` file is written per failed file, where ``<name>`` is
    the display name with "/" replaced by "_", plus a
    ``collection_errors_summary.json`` index.

    Args:
        failed_files: List of dicts with keys ``file`` (display name),
            ``error`` (error text), ``cases`` (number of cases collected
            despite the error) and ``test_file`` (original path)
        error_log_dir: Directory to save error logs (created if missing)
    """
    print(f"Saving error logs for {len(failed_files)} failed files...")

    # Do not rely on the caller having created the directory.
    error_log_dir.mkdir(parents=True, exist_ok=True)

    separator = "=" * 80
    summary_entries = []
    for failed in failed_files:
        # Flat, filesystem-safe log name derived from the display name;
        # computed once and reused for the summary entry.
        log_name = f"{failed['file'].replace('/', '_')}.log"

        log_text = (
            f"File: {failed['file']}\n"
            f"Cases collected: {failed['cases']}\n"
            f"Test file path: {failed['test_file']}\n"
            f"{separator}\n"
            "Collection Error:\n"
            f"{separator}\n"
            f"{failed['error']}\n"
        )
        (error_log_dir / log_name).write_text(log_text, encoding="utf-8")

        summary_entries.append({
            "file": failed["file"],
            "cases": failed["cases"],
            "test_file": failed["test_file"],
            "log_file": log_name,
        })

    # Summary JSON listing every failed file and its log
    summary_file = error_log_dir / "collection_errors_summary.json"
    summary_data = {
        "total_failed": len(failed_files),
        "failed_files": summary_entries,
    }
    summary_file.write_text(json.dumps(summary_data, indent=2), encoding="utf-8")

    print(f"  Error logs saved to {error_log_dir}")
    print(f"  Summary: {summary_file}")


def split_cases_into_shards(cases: List[Dict], num_shards: int) -> List[List[Dict]]:
    """
    Split cases into ``num_shards`` contiguous, near-equal shards.

    The first ``len(cases) % num_shards`` shards receive one extra case, so
    shard sizes differ by at most one and the input order is preserved.

    Args:
        cases: Collected cases, in the order they should be distributed
        num_shards: Number of shards to produce (must be >= 1)

    Returns:
        A list of exactly ``num_shards`` lists (some may be empty)

    Raises:
        ValueError: If ``num_shards`` is smaller than 1
    """
    # Without this check 0 fails with a bare ZeroDivisionError and negative
    # values silently return no shards at all.
    if num_shards < 1:
        raise ValueError(f"num_shards must be >= 1, got {num_shards}")

    base_size, remainder = divmod(len(cases), num_shards)

    shards = []
    start = 0
    for index in range(num_shards):
        end = start + base_size + (1 if index < remainder else 0)
        shards.append(cases[start:end])
        start = end
    return shards


def save_shards(
    cases: List[Dict],
    num_shards: int,
    test_type: str,
    output_dir: Path,
) -> Dict:
    """
    Write one JSON file per shard and return a summary.

    Files are named ``<test_type>_cases_shard_<i>.json`` with ``i`` starting
    at 1.

    Args:
        cases: All collected cases for this test type
        num_shards: Number of shards to split into (must be >= 1)
        test_type: Label used in file names and payloads (e.g. "regular")
        output_dir: Directory for the shard files (created if missing)

    Returns:
        Dict with ``test_type``, ``num_shards``, ``total_cases`` and
        ``shard_sizes``

    Raises:
        ValueError: If ``num_shards`` is smaller than 1
    """
    shards = split_cases_into_shards(cases, num_shards)

    # Do not rely on the caller having created the directory.
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nSaving {test_type} shards...")
    for shard_no, shard_cases in enumerate(shards, 1):
        shard_file = output_dir / f"{test_type}_cases_shard_{shard_no}.json"
        shard_data = {
            "shard": shard_no,
            "num_shards": num_shards,
            "test_type": test_type,
            "total_cases": len(shard_cases),
            "cases": shard_cases,
        }
        shard_file.write_text(json.dumps(shard_data, indent=2), encoding="utf-8")
        print(f"  Shard {shard_no}: {len(shard_cases)} cases -> {shard_file}")

    return {
        "test_type": test_type,
        "num_shards": num_shards,
        "total_cases": len(cases),
        "shard_sizes": [len(s) for s in shards],
    }
summaries.append(dist_summary) + + # ======================================== + # Step 2: Collect regular test cases + # ======================================== + print("\n" + "=" * 80) + print("Collecting regular test cases") + print("=" * 80) + + reg_files, reg_meta = discover_test_files.discover_test_files( + test_dir=test_dir, + test_type="regular", + case_paths_config=args.case_paths_config, + ) + print(f"Found {len(reg_files)} regular test files") + + reg_cases = collect_all_cases(reg_files, test_dir, error_log_dir / "regular", args.parallel) + print(f"Total regular cases: {len(reg_cases)}") + + reg_summary = save_shards(reg_cases, args.regular_shards, "regular", output_dir) + summaries.append(reg_summary) + + # ======================================== + # Step 3: Save overall summary + # ======================================== + # Calculate file counts (distributed + regular = total_files, no overlap) + dist_total = dist_meta.get("total_files", 0) + dist_selected = dist_meta.get("type_selected", 0) + reg_total = reg_meta.get("total_files", 0) + reg_selected = reg_meta.get("type_selected", 0) + # total_files is same for both (all test_*.py files), use one value + total_files = dist_total + + overall_summary = { + "distributed": { + "cases_summary": dist_summary, + "discovery_metadata": dist_meta, + }, + "regular": { + "cases_summary": reg_summary, + "discovery_metadata": reg_meta, + }, + "total_cases": len(dist_cases) + len(reg_cases), + "total_files_scanned": total_files, + "distributed_files": dist_selected, + "regular_files": reg_selected, + } + summary_file = output_dir / "cases_collection_summary.json" + summary_file.write_text(json.dumps(overall_summary, indent=2), encoding="utf-8") + print(f"\nOverall summary saved to {summary_file}") + + print("\n" + "=" * 80) + print("Collection Complete") + print("=" * 80) + print(f"Distributed: {len(dist_cases)} cases -> {args.distributed_shards} shards (serial execution)") + print(f"Regular: {len(reg_cases)} 
cases -> {args.regular_shards} shards (parallel execution)") + print(f"Total: {len(dist_cases) + len(reg_cases)} cases") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Collect and shard test cases") + parser.add_argument("--test-dir", required=True, help="PyTorch test directory") + parser.add_argument("--case-paths-config", help="case_paths_ci.yml path") + parser.add_argument("--distributed-shards", type=int, default=2, help="Distributed test shards") + parser.add_argument("--regular-shards", type=int, default=5, help="Regular test shards") + parser.add_argument("--output-dir", required=True, help="Output directory for shard JSONs") + parser.add_argument("--error-log-dir", help="Output directory for collection error logs (default: output-dir/collection_errors)") + parser.add_argument("--parallel", type=int, default=16, help="Parallel collection workers") + return parser.parse_args() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/.github/scripts/discover_test_files.py b/.github/scripts/discover_test_files.py new file mode 100644 index 0000000000..0b6e6167f3 --- /dev/null +++ b/.github/scripts/discover_test_files.py @@ -0,0 +1,341 @@ +#!/usr/bin/env python3 +""" +Discover test files for PyTorch NPU testing. 
+ +This script integrates 3 steps: + Step 1: Test file discovery (scan all test_*.py) + Step 2: Shard type filtering (distributed/regular) + Step 3: Whitelist/blacklist filtering (case_paths_ci.yml) + +Output: Sorted list of test file paths (with 'test/' prefix) + +Usage: + python discover_test_files.py \ + --test-dir /path/to/pytorch/test \ + --test-type distributed \ + --case-paths-config /path/to/case_paths_ci.yml \ + --output /path/to/output_file.txt + + # Or output to stdout: + python discover_test_files.py \ + --test-dir /path/to/pytorch/test \ + --test-type regular \ + --case-paths-config /path/to/case_paths_ci.yml +""" + +import argparse +import json +import sys +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +try: + import yaml # type: ignore +except ImportError: + yaml = None # type: ignore + + +# ============================================================================== +# Path Normalization Functions +# ============================================================================== + + +def normalize_path(value: str) -> str: + """Normalize path: convert backslashes, remove ./ prefix.""" + normalized = value.replace("\\", "/").strip() + while normalized.startswith("./"): + normalized = normalized[2:] + return normalized.strip("/") + + +def normalize_rule_path(rule: str) -> str: + """Normalize rule path: ensure it has 'test/' prefix.""" + normalized = normalize_path(rule) + if not normalized: + return "" + if normalized == "test" or normalized.startswith("test/"): + return normalized.rstrip("/") + return f"test/{normalized}".rstrip("/") + + +# ============================================================================== +# YAML Parsing Functions +# ============================================================================== + + +def parse_simple_yaml_lists(raw_text: str) -> Dict[str, List[str]]: + """Parse YAML file for whitelist/blacklist without yaml library.""" + parsed = {"whitelist": [], "blacklist": []} + 
current_key = None + + for raw_line in raw_text.splitlines(): + without_comment = raw_line.split("#", 1)[0].rstrip() + if not without_comment.strip(): + continue + + stripped = without_comment.lstrip() + if not raw_line.startswith((" ", "\t")) and stripped.endswith(":"): + key = stripped[:-1].strip() + current_key = key if key in parsed else None + continue + + if current_key and stripped.startswith("- "): + value = stripped[2:].strip().strip("\"'") + if value: + parsed[current_key].append(value) + + return parsed + + +def coerce_rule_list(value, key: str) -> List[str]: + """Validate and normalize rule list.""" + if value is None: + return [] + if not isinstance(value, list): + raise ValueError(f"Expected '{key}' to be a list, got {type(value).__name__}") + + normalized_values = [] + for item in value: + if not isinstance(item, str): + raise ValueError(f"Expected every '{key}' entry to be a string, got {type(item).__name__}") + normalized = normalize_rule_path(item) + if normalized: + normalized_values.append(normalized) + return normalized_values + + +def load_case_path_rules(config_file: Optional[str]) -> Tuple[str, List[str], List[str]]: + """Load whitelist/blacklist rules from case_paths_ci.yml.""" + if not config_file: + return "", [], [] + + config_path = Path(config_file).resolve() + if not config_path.exists(): + raise FileNotFoundError(f"case_paths_ci config not found: {config_path}") + + raw_text = config_path.read_text(encoding="utf-8") + + if yaml is not None: + payload = yaml.safe_load(raw_text) or {} + else: + payload = parse_simple_yaml_lists(raw_text) + + if not isinstance(payload, dict): + raise ValueError(f"Expected a YAML object in {config_path}, got {type(payload).__name__}") + + whitelist = coerce_rule_list(payload.get("whitelist"), "whitelist") + blacklist = coerce_rule_list(payload.get("blacklist"), "blacklist") + return str(config_path), whitelist, blacklist + + +# 
# ==============================================================================
# Test File Discovery (Step 1)
# ==============================================================================


def discover_raw_test_files(test_dir: Path) -> List[str]:
    """
    Recursively collect every ``test_*.py`` file below ``test_dir``.

    Args:
        test_dir: Root of the PyTorch test tree.

    Returns:
        Sorted list of POSIX-style paths, each prefixed with ``test/`` and
        relative to ``test_dir`` (e.g. ``test/distributed/test_c10d.py``).
    """
    return sorted(
        f"test/{candidate.relative_to(test_dir).as_posix()}"
        for candidate in test_dir.rglob("test_*.py")
        # rglob matches on the name only, so a directory called "test_x.py"
        # would otherwise be reported as a test file.
        if candidate.is_file()
    )


# ==============================================================================
# Type Filtering (Step 2)
# ==============================================================================


def filter_tests_by_type(test_files: List[str], test_type: str) -> Tuple[List[str], List[str]]:
    """
    Split test files into the requested shard type and everything else.

    A file is "distributed" when its path starts with ``test/distributed/``.
    Any ``test_type`` other than ``"distributed"`` is treated as regular.

    Args:
        test_files: Paths as produced by ``discover_raw_test_files``.
        test_type: ``"distributed"`` or ``"regular"``.

    Returns:
        Tuple of (selected, excluded); both keep the input order.
    """
    distributed: List[str] = []
    regular: List[str] = []
    # Single pass: each file lands in exactly one bucket.
    for test_file in test_files:
        bucket = distributed if test_file.startswith("test/distributed/") else regular
        bucket.append(test_file)

    if test_type == "distributed":
        return distributed, regular
    return regular, distributed


# ==============================================================================
# Path Rules Filtering (Step 3)
# ==============================================================================

# Characters that turn a rule into a glob pattern instead of a literal path.
_GLOB_CHARS = frozenset("*?[]")


def _matches_normalized_rule(normalized_path: str, normalized_rule: str) -> bool:
    """Match an already-normalized path against an already-normalized rule."""
    if not normalized_rule:
        return False

    if _GLOB_CHARS.intersection(normalized_rule):
        # Local import mirrors the original function-scope import.
        from fnmatch import fnmatchcase

        # fnmatchcase keeps matching case-sensitive on every platform.
        # The second pattern gives glob rules the same "directory covers its
        # contents" behaviour that literal rules have below.
        return fnmatchcase(normalized_path, normalized_rule) or fnmatchcase(
            normalized_path, f"{normalized_rule}/*"
        )

    return normalized_path == normalized_rule or normalized_path.startswith(f"{normalized_rule}/")


def path_matches_rule(test_path: str, rule: str) -> bool:
    """
    Check whether a test path is covered by a whitelist/blacklist rule.

    A rule without glob characters matches the exact path or any path below
    it. A rule containing ``*``, ``?``, ``[`` or ``]`` is matched with
    case-sensitive fnmatch semantics and likewise covers paths below a match.

    Args:
        test_path: Test file path; normalized with ``normalize_path``.
        rule: Rule path or pattern; normalized with ``normalize_rule_path``.

    Returns:
        True when the rule covers the path; False for an empty rule.
    """
    return _matches_normalized_rule(normalize_path(test_path), normalize_rule_path(rule))


def apply_case_path_rules(
    test_files: List[str], whitelist: List[str], blacklist: List[str]
) -> Tuple[List[str], List[str]]:
    """
    Apply whitelist and then blacklist rules to a list of test files.

    Args:
        test_files: Candidate test file paths.
        whitelist: Rules a file must match to be kept; empty keeps all files.
        blacklist: Rules that remove a file even if it was whitelisted.

    Returns:
        Tuple of (selected, excluded); both keep the order of ``test_files``.
    """
    # Normalize each path and each rule once instead of once per pair.
    candidates = [(path, normalize_path(path)) for path in test_files]

    # The emptiness checks use the raw lists so that a whitelist made only of
    # blank entries still selects nothing, as before.
    if whitelist:
        allow_rules = [normalize_rule_path(rule) for rule in whitelist]
        candidates = [
            (path, norm)
            for path, norm in candidates
            if any(_matches_normalized_rule(norm, rule) for rule in allow_rules)
        ]

    if blacklist:
        deny_rules = [normalize_rule_path(rule) for rule in blacklist]
        candidates = [
            (path, norm)
            for path, norm in candidates
            if not any(_matches_normalized_rule(norm, rule) for rule in deny_rules)
        ]

    selected = [path for path, _ in candidates]
    selected_set = set(selected)
    excluded = [path for path in test_files if path not in selected_set]
    return selected, excluded


# ==============================================================================
# Main Discovery Function
# ==============================================================================


def discover_test_files(
    test_dir: Path,
    test_type: str,
    case_paths_config: Optional[str],
) -> Tuple[List[str], Dict]:
    """
    Execute all 3 steps to discover test files.

    Args:
        test_dir: Root of the PyTorch test tree.
        test_type: ``"distributed"`` or ``"regular"``.
        case_paths_config: Optional path to the whitelist/blacklist config,
            passed to ``load_case_path_rules``.

    Returns:
        Tuple of (selected_files, metadata_dict). The metadata holds the
        counts of each filtering stage and the resolved config path.

    Raises:
        FileNotFoundError: If ``test_dir`` is not an existing directory.
    """
    # A missing directory would otherwise scan as "zero test files" and let a
    # misconfigured job pass without running anything.
    if not test_dir.is_dir():
        raise FileNotFoundError(f"Test directory not found: {test_dir}")

    # Step 1: Discover all test files
    all_test_files = discover_raw_test_files(test_dir)

    # Step 2: Filter by test type
    type_selected, type_excluded = filter_tests_by_type(all_test_files, test_type)

    # Step 3: Apply whitelist/blacklist rules
    config_path, whitelist, blacklist = load_case_path_rules(case_paths_config)
    rules_selected, rules_excluded = apply_case_path_rules(type_selected, whitelist, blacklist)

    # Metadata for reporting
    metadata = {
        "test_dir": str(test_dir),
        "test_type": test_type,
        "total_files": len(all_test_files),
        "type_selected": len(type_selected),
        "type_excluded": len(type_excluded),
        "whitelist_entries": len(whitelist),
        "blacklist_entries": len(blacklist),
        "rules_selected": len(rules_selected),
        "rules_excluded": len(rules_excluded),
        "case_paths_config": config_path,
    }

    return rules_selected, metadata


#
============================================================================== +# CLI Interface +# ============================================================================== + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Discover test files for PyTorch NPU testing", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "--test-dir", + type=str, + required=True, + help="Path to the PyTorch test directory", + ) + parser.add_argument( + "--test-type", + type=str, + choices=["distributed", "regular"], + default="regular", + help="Test type: 'distributed' for distributed tests, 'regular' for other tests", + ) + parser.add_argument( + "--case-paths-config", + type=str, + help="Path to case_paths_ci.yml for file-level whitelist/blacklist control", + ) + parser.add_argument( + "--output", + type=str, + help="Output file path for test file list (default: stdout)", + ) + parser.add_argument( + "--metadata-output", + type=str, + help="Output file path for metadata JSON (optional)", + ) + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Print verbose output including metadata", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + + test_dir = Path(args.test_dir).resolve() + if not test_dir.is_dir(): + raise FileNotFoundError(f"Test directory not found: {test_dir}") + + # Execute discovery + selected_files, metadata = discover_test_files( + test_dir=test_dir, + test_type=args.test_type, + case_paths_config=args.case_paths_config, + ) + + # Output test file list + output_content = "\n".join(selected_files) + ("\n" if selected_files else "") + + if args.output: + output_path = Path(args.output).resolve() + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(output_content, encoding="utf-8") + if args.verbose: + print(f"Written {len(selected_files)} test files to: {output_path}") + else: + sys.stdout.write(output_content) + + # 
Output metadata + if args.metadata_output: + metadata_path = Path(args.metadata_output).resolve() + metadata_path.parent.mkdir(parents=True, exist_ok=True) + metadata_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8") + if args.verbose: + print(f"Written metadata to: {metadata_path}") + + # Verbose summary + if args.verbose: + print(f"\nDiscovery Summary:") + print(f" Test directory: {test_dir}") + print(f" Test type: {args.test_type}") + print(f" Total files scanned: {metadata['total_files']}") + print(f" After type filter: {metadata['type_selected']} selected, {metadata['type_excluded']} excluded") + if args.case_paths_config: + print(f" Whitelist entries: {metadata['whitelist_entries']}") + print(f" Blacklist entries: {metadata['blacklist_entries']}") + print(f" After rules filter: {metadata['rules_selected']} selected, {metadata['rules_excluded']} excluded") + print(f" Final selected files: {len(selected_files)}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/.github/scripts/generate_npu_full_test_report.py b/.github/scripts/generate_npu_full_test_report.py new file mode 100644 index 0000000000..55822ec30a --- /dev/null +++ b/.github/scripts/generate_npu_full_test_report.py @@ -0,0 +1,674 @@ +#!/usr/bin/env python3 +""" +Generate a consolidated markdown/json report for the NPU full test workflow. 
+""" + +import argparse +import json +import re +from collections import Counter +from pathlib import Path +from typing import Dict, List, Tuple, Optional + +# Import aggregation function from parse_test_results.py +import parse_test_results + + +def parse_args(): + parser = argparse.ArgumentParser(description="Generate consolidated NPU full test report") + parser.add_argument("--reports-root", required=True, help="Root directory containing shard report files") + parser.add_argument("--output-markdown", required=True, help="Path to write markdown report") + parser.add_argument("--output-json", required=True, help="Path to write JSON report") + parser.add_argument("--pytorch-version", required=True, help="PyTorch version string") + parser.add_argument("--torch-npu-whl", required=True, help="torch_npu wheel URL") + parser.add_argument("--patch-count", default="N/A", help="Applied patch count") + parser.add_argument("--shard-matrix-json", required=True, help="JSON array of requested shard ids") + parser.add_argument("--docker-image", default="N/A", help="Docker image used for test execution") + parser.add_argument("--runner", default="N/A", help="Runner machine type") + parser.add_argument("--special-reports-root", help="Root directory containing special test report files") + parser.add_argument("--expected-special-tests-json", default="[]", help="JSON array of expected special test names") + parser.add_argument("--cases-summary", help="Path to cases_collection_summary.json for file discovery stats") + return parser.parse_args() + + +def load_json_file(path: Path) -> Dict: + """Load JSON file with error handling for malformed/truncated files.""" + try: + with open(path, "r", encoding="utf-8") as f: + return json.load(f) + except json.JSONDecodeError as e: + print(f"Warning: Invalid JSON in {path}: {e}") + # Read file content to diagnose truncation + try: + with open(path, "r", encoding="utf-8") as f: + content = f.read() + print(f" File size: {len(content)} bytes") + 
# Show context around error position + error_pos = e.pos if hasattr(e, 'pos') else 0 + start = max(0, error_pos - 100) + end = min(len(content), error_pos + 100) + print(f" Context around error (pos {error_pos}): ...{content[start:end]}...") + except Exception: + pass + return {} + except Exception as e: + print(f"Warning: Failed to load {path}: {e}") + return {} + + +def parse_requested_shards(raw: str) -> List[Tuple[str, int]]: + """ + Parse shard identifiers from JSON array. + + Supports formats: + - Integers: [1, 2, 3] -> [("regular", 1), ("regular", 2), ("regular", 3)] + - Type-prefixed: ["dist-1", "reg-2"] -> [("distributed", 1), ("regular", 2)] + + Returns list of (shard_type, shard_number) tuples. + """ + try: + value = json.loads(raw) + except json.JSONDecodeError: + return [] + + if not isinstance(value, list): + return [] + + result = [] + for item in value: + try: + if isinstance(item, str): + # Parse type-prefixed format: "dist-1", "reg-2" + if "-" in item: + type_prefix, num_str = item.split("-", 1) + if type_prefix == "dist": + shard_type = "distributed" + elif type_prefix == "reg": + shard_type = "regular" + else: + # Unknown prefix, skip + continue + shard_num = int(num_str) + result.append((shard_type, shard_num)) + else: + # String without prefix, try to parse as int + shard_num = int(item) + result.append(("regular", shard_num)) + elif isinstance(item, int): + # Plain integer, assume "regular" type + result.append(("regular", item)) + except (TypeError, ValueError): + continue + # Sort by type then number + return sorted(set(result), key=lambda x: (x[0], x[1])) + + +def parse_expected_special_tests(raw: str) -> List[str]: + try: + value = json.loads(raw) + except json.JSONDecodeError: + return [] + + if not isinstance(value, list): + return [] + + result = [] + for item in value: + if isinstance(item, str) and item: + result.append(item) + return sorted(set(result)) + + +def load_text_lines(path: Path) -> List[str]: + with open(path, "r", 
encoding="utf-8") as f: + return [line.strip() for line in f if line.strip()] + + +def get_int_value(payload: Dict, *keys: str) -> int: + for key in keys: + if key not in payload: + continue + try: + return int(payload.get(key, 0)) + except (TypeError, ValueError): + continue + return 0 + + +def discover_shard_files( + reports_root: Path, +) -> Tuple[ + Dict[Tuple[str, int], Path], # stats_files + Dict[Tuple[str, int], Path], # info_files + Dict[Tuple[str, int], Path], # cases_files +]: + """ + Discover all shard report files in the reports directory. + + Returns dicts keyed by (shard_type, shard_number) tuples. + + File name format: shard_{type}-{number}_{suffix} + Examples: + - shard_dist-1_stats.json + - shard_reg-1_info.json + - shard_dist-1_cases.json (case-level results) + """ + stats_files = {} + info_files = {} + cases_files = {} + + def parse_shard_filename(path: Path, suffix_pattern: str) -> Tuple[str, int]: + """ + Parse shard type and number from filename. + + Filename format: shard_{type}-{number}_{suffix} + e.g., shard_dist-1_stats.json -> ("distributed", 1) + """ + stem = path.stem # filename without extension + # Match pattern: shard_{type}-{number}_{suffix} + match = re.match(r"shard_(dist|reg)-(\d+)_" + suffix_pattern, stem) + if match: + type_prefix = match.group(1) + shard_num = int(match.group(2)) + if type_prefix == "dist": + return ("distributed", shard_num) + elif type_prefix == "reg": + return ("regular", shard_num) + return None + + for path in reports_root.rglob("shard_*_stats.json"): + key = parse_shard_filename(path, "stats") + if key: + stats_files[key] = path + + for path in reports_root.rglob("shard_*_info.json"): + key = parse_shard_filename(path, "info") + if key: + info_files[key] = path + + # Discover case-level results files + for path in reports_root.rglob("shard_*_cases.json"): + key = parse_shard_filename(path, "cases") + if key: + cases_files[key] = path + + return stats_files, info_files, cases_files + + +def 
build_file_to_shards_map(cases_shards_dir: Path) -> Dict[str, List[str]]: + """ + Build a mapping from test file path to shard IDs. + + Scans all shard JSON files in cases_shards_dir and extracts file->shard mapping. + + Args: + cases_shards_dir: Directory containing shard JSON files like + distributed_cases_shard_1.json, regular_cases_shard_2.json + + Returns: + Dict mapping file path (e.g., "test/test_ops.py") to list of shard IDs + (e.g., ["dist-1", "reg-2", "reg-3"]) + """ + file_to_shards = {} + + if not cases_shards_dir or not cases_shards_dir.exists(): + return file_to_shards + + # Pattern: {test_type}_cases_shard_{num}.json + for shard_file in cases_shards_dir.glob("*_cases_shard_*.json"): + try: + data = load_json_file(shard_file) + test_type = data.get("test_type", "regular") + shard_num = data.get("shard", 0) + + # Build shard ID: "dist-1" or "reg-2" + shard_prefix = "dist" if test_type == "distributed" else "reg" + shard_id = f"{shard_prefix}-{shard_num}" + + # Extract file paths from cases + cases = data.get("cases", []) + for case in cases: + file_path = case.get("file", "") + if file_path: + # Normalize file path (remove leading "test/" if present for consistency) + normalized_file = file_path + if normalized_file.startswith("test/"): + normalized_file = normalized_file[5:] + + if normalized_file not in file_to_shards: + file_to_shards[normalized_file] = [] + if shard_id not in file_to_shards[normalized_file]: + file_to_shards[normalized_file].append(shard_id) + except Exception as e: + print(f"Warning: Failed to parse shard file {shard_file}: {e}") + continue + + # Sort shard IDs for each file + for file_path in file_to_shards: + # Sort by type (dist first) then number + file_to_shards[file_path].sort(key=lambda x: (0 if x.startswith("dist") else 1, int(x.split("-")[1]))) + + return file_to_shards + + +def get_shard_status(stats: Dict, present: bool) -> str: + if not present: + return "MISSING" + if stats.get("timed_out"): + return "TIMEOUT" + if 
stats.get("incomplete"): + return "INCOMPLETE" + if stats.get("errors", 0) > 0: + return "ERROR" + if stats.get("failed", 0) > 0: + return "FAILED" + if stats.get("total", 0) == 0: + return "NO TESTS" + return "PASSED" + + +def get_overall_status(status_counts: Counter) -> str: + if status_counts["MISSING"] > 0: + return "FAILED" + if any(status_counts[key] > 0 for key in ("TIMEOUT", "INCOMPLETE", "ERROR", "FAILED")): + return "FAILED" + if status_counts["PASSED"] > 0: + return "PASSED" + return "NO TESTS" + + +def format_duration(seconds: float) -> str: + seconds = float(seconds) + hours = int(seconds // 3600) + minutes = int((seconds % 3600) // 60) + secs = seconds % 60 + if hours > 0: + return f"{hours}h {minutes}m {secs:.1f}s" + if minutes > 0: + return f"{minutes}m {secs:.1f}s" + return f"{secs:.1f}s" + + +def sanitize_markdown_cell(value: str) -> str: + return value.replace("|", "\\|").replace("\n", "
") + + +def render_table(headers: List[str], rows: List[List[str]]) -> List[str]: + lines = [ + "| " + " | ".join(headers) + " |", + "| " + " | ".join(["---"] * len(headers)) + " |", + ] + for row in rows: + lines.append("| " + " | ".join(row) + " |") + return lines + + +def discover_special_test_files(reports_root: Path | None) -> Dict[str, Path]: + if reports_root is None or not reports_root.exists(): + return {} + + special_files = {} + for path in reports_root.rglob("special_test_*.json"): + try: + payload = load_json_file(path) + except Exception: + continue + name = payload.get("name") + if isinstance(name, str) and name: + special_files[name] = path + return special_files + + +def main(): + args = parse_args() + reports_root = Path(args.reports_root) + output_markdown = Path(args.output_markdown) + output_json = Path(args.output_json) + requested_shards = parse_requested_shards(args.shard_matrix_json) + expected_special_tests = parse_expected_special_tests(args.expected_special_tests_json) + special_reports_root = Path(args.special_reports_root) if args.special_reports_root else None + + # Load cases collection summary for file discovery stats + cases_summary_data = None + file_discovery_stats = { + "total_files_scanned": 0, + "distributed_files": 0, + "regular_files": 0, + } + if args.cases_summary: + cases_summary_path = Path(args.cases_summary) + if cases_summary_path.exists(): + cases_summary_data = load_json_file(cases_summary_path) + # Extract file discovery stats (正交: total = distributed + regular) + if cases_summary_data: + file_discovery_stats["total_files_scanned"] = cases_summary_data.get("total_files_scanned", 0) + file_discovery_stats["distributed_files"] = cases_summary_data.get("distributed_files", 0) + file_discovery_stats["regular_files"] = cases_summary_data.get("regular_files", 0) + + stats_files, info_files, cases_files = discover_shard_files(reports_root) + special_test_files = discover_special_test_files(special_reports_root) + 
shard_ids = requested_shards or sorted(set(stats_files) | set(info_files) | set(cases_files)) + + # Build file to shards mapping from cases-shards directory + cases_shards_dir = Path(args.cases_summary).parent if args.cases_summary else None + file_to_shards_map = build_file_to_shards_map(cases_shards_dir) + + status_counts = Counter() + totals = { + "total": 0, + "passed": 0, + "failed": 0, + "errors": 0, + "skipped": 0, + "timeout": 0, + "duration": 0.0, + } + shard_rows = [] + selection_modes = set() + cases_results = {} # Store case-level results for each shard + + for shard_type, shard_num in shard_ids: + shard_key = (shard_type, shard_num) + stats_path = stats_files.get(shard_key) + info_path = info_files.get(shard_key) + cases_path = cases_files.get(shard_key) + stats = load_json_file(stats_path) if stats_path else {} + info = load_json_file(info_path) if info_path else {} + + # Load case-level results if available + cases_data = load_json_file(cases_path) if cases_path else {} + if cases_data: + cases_results[shard_key] = cases_data + # Override stats with case-level data + stats["total"] = cases_data.get("total_cases", 0) + stats["passed"] = cases_data.get("passed", 0) + stats["failed"] = cases_data.get("failed", 0) + stats["errors"] = cases_data.get("errors", 0) + stats["skipped"] = cases_data.get("skipped", 0) + stats["timeout"] = cases_data.get("timeout", 0) + stats["duration"] = cases_data.get("duration", 0.0) + # Update totals (正交累加: total = passed + failed + errors + skipped + timeout) + totals["total"] += cases_data.get("total_cases", 0) + totals["passed"] += cases_data.get("passed", 0) + totals["failed"] += cases_data.get("failed", 0) + totals["errors"] += cases_data.get("errors", 0) + totals["skipped"] += cases_data.get("skipped", 0) + totals["timeout"] += cases_data.get("timeout", 0) + totals["duration"] += cases_data.get("duration", 0.0) + + present = bool(stats_path or cases_path) + + if info.get("selection_mode"): + 
selection_modes.add(str(info.get("selection_mode"))) + + status = get_shard_status(stats, present) + status_counts[status] += 1 + + # Convert shard_type to display prefix ("distributed" -> "dist", "regular" -> "reg") + shard_prefix = "dist" if shard_type == "distributed" else "reg" + shard_rows.append( + { + "shard": f"{shard_prefix}-{shard_num}", # "dist-1" or "reg-1" + "shard_type": shard_type, + "shard_num": shard_num, + "status": status, + "total": int(stats.get("total", 0)), + "passed": int(stats.get("passed", 0)), + "failed": int(stats.get("failed", 0)), + "skipped": int(stats.get("skipped", 0)), + "errors": int(stats.get("errors", 0)), + "timeout": int(stats.get("timeout", 0)), + "duration": float(stats.get("duration", 0.0)), + } + ) + + overall_status = get_overall_status(status_counts) + whl_name = Path(args.torch_npu_whl).name + received_reports = len(stats_files) + expected_reports = len(shard_ids) + selection_mode_display = ", ".join(sorted(selection_modes)) if selection_modes else "-" + + # Show all shards in the detail table + sorted_shards = sorted(shard_rows, key=lambda row: (row["shard_type"], row["shard_num"])) + special_test_names = expected_special_tests or sorted(special_test_files) + special_test_rows = [] + special_status_counts = Counter() + + for test_name in special_test_names: + payload = load_json_file(special_test_files[test_name]) if test_name in special_test_files else {} + status = str(payload.get("status", "MISSING")) + special_status_counts[status] += 1 + special_test_rows.append( + { + "name": test_name, + "group": str(payload.get("group", "-")), + "status": status, + "duration": float(payload.get("duration", 0.0)), + "returncode": payload.get("returncode", "-"), + "note": str(payload.get("note", "") or "-"), + } + ) + + if any(row["status"] != "PASSED" for row in special_test_rows): + overall_status = "FAILED" + + include_special_tests = bool(special_test_names or special_test_rows) + + # Build Selection row content based on 
available data + if cases_summary_data: + # Use file discovery stats from cases_collection_summary.json + total_scanned = file_discovery_stats["total_files_scanned"] + dist_files = file_discovery_stats["distributed_files"] + reg_files = file_discovery_stats["regular_files"] + selection_content = ( + f"扫描发现 {total_scanned} 个测试文件 " + f"(distributed: {dist_files}, regular: {reg_files})" + ) + else: + # Fallback to original selection mode display + selection_content = selection_mode_display + + # Extract planned cases count from cases_collection_summary.json + planned_total_cases = 0 + planned_dist_cases = 0 + planned_reg_cases = 0 + if cases_summary_data: + planned_total_cases = cases_summary_data.get("total_cases", 0) + planned_dist_cases = cases_summary_data.get("distributed", {}).get("cases_summary", {}).get("total_cases", 0) + planned_reg_cases = cases_summary_data.get("regular", {}).get("cases_summary", {}).get("total_cases", 0) + + overview_rows = [ + ["Overall result", overall_status], + ["PyTorch", f"`v{args.pytorch_version}`"], + ["torch_npu", f"`{whl_name}`"], + ["Patches applied", str(args.patch_count)], + ["Docker image", f"`{args.docker_image}`"], + ["Runner", f"`{args.runner}`"], + ["Shards", f"{received_reports} / {expected_reports} reported"], + ["Selection", selection_content], + [ + "实际执行用例", + ( + f"{totals['total']} total; {totals['passed']} passed; {totals['failed']} failed; " + f"{totals['errors']} errors; {totals['skipped']} skipped; " + f"{totals['timeout']} timeout" + ), + ], + ] + # Add planned cases count row if available + if planned_total_cases > 0: + overview_rows.append([ + "规划用例总数", + f"{planned_total_cases} (distributed: {planned_dist_cases}, regular: {planned_reg_cases})", + ]) + overview_rows.append(["Duration", format_duration(totals["duration"])]) + if include_special_tests: + overview_rows.append(["Special tests expected", str(len(special_test_names))]) + + markdown_lines = [ + "# PyTorch NPU Full Test Summary", + "", + "## 
Overview", + ] + markdown_lines.extend( + render_table( + ["Item", "Value"], + overview_rows, + ) + ) + + # Add case-level statistics table if available + if cases_results: + markdown_lines.extend(["", "## 用例级执行统计"]) + markdown_lines.extend( + render_table( + ["Shard", "总用例", "通过", "失败", "错误", "跳过", "超时", "Duration"], + [ + [ + f"{row['shard']}", + str(row["total"]), + str(row["passed"]), + str(row["failed"]), + str(row["errors"]), + str(row.get("skipped", 0)), + str(row.get("timeout", 0)), + format_duration(row["duration"]), + ] + for row in sorted_shards + if (row["shard_type"], row["shard_num"]) in cases_results + ], + ) + ) + + # Add file-level statistics table + file_stats = parse_test_results.aggregate_all_cases_by_file(cases_results) + + if file_stats: + # Sort files by total cases descending + sorted_files = sorted( + file_stats.values(), + key=lambda x: (-x["total"], x["file"]) + ) + + markdown_lines.extend(["", "## 测试文件结果汇总"]) + + file_rows = [] + for fs in sorted_files: # Show all files + failed_total = fs["failed"] + fs["errors"] + fs["timeout"] + fail_rate = f"{(failed_total / fs['total'] * 100):.1f}%" if fs["total"] > 0 else "0%" + # Get shard info for this file + file_path = fs["file"] + # Normalize file path for lookup (remove leading "test/") + lookup_path = file_path + if lookup_path.startswith("test/"): + lookup_path = lookup_path[5:] + shards_for_file = file_to_shards_map.get(lookup_path, []) + shard_info = ", ".join(shards_for_file) if shards_for_file else "-" + file_rows.append([ + sanitize_markdown_cell(fs["file"]), + shard_info, + str(fs["total"]), + str(fs["passed"]), + str(fs["failed"]), + str(fs["errors"]), + str(fs["skipped"]), + str(fs["timeout"]), + fail_rate, + ]) + + markdown_lines.extend( + render_table( + ["测试文件", "分片", "总用例", "通过", "失败", "错误", "跳过", "超时", "失败率"], + file_rows, + ) + ) + + if include_special_tests: + markdown_lines.extend(["", "## Special Test Results"]) + markdown_lines.extend( + render_table( + ["Test", "Group", 
"Status", "Duration", "Return Code", "Note"], + [ + [ + row["name"], + row["group"], + row["status"], + format_duration(row["duration"]), + str(row["returncode"]), + sanitize_markdown_cell(row["note"]), + ] + for row in special_test_rows + ] or [["-", "-", "-", "0.0s", "-", "-"]], + ) + ) + + report_json = { + "overall_status": overall_status, + "requested_shards": shard_ids, + "reports_collected": received_reports, + "patch_count": args.patch_count, + "pytorch_version": args.pytorch_version, + "torch_npu_whl": whl_name, + "docker_image": args.docker_image, + "runner": args.runner, + "status_counts": dict(status_counts), + "totals": totals, + "file_discovery_stats": file_discovery_stats, + "planned_cases": { + "total": planned_total_cases, + "distributed": planned_dist_cases, + "regular": planned_reg_cases, + }, + "shards": shard_rows, + } + + # Add full cases summary if available + if cases_summary_data: + report_json["cases_collection_summary"] = cases_summary_data + + # Add case-level results if available + if cases_results: + report_json["cases_results"] = { + "shards": { + f"{shard_type}-{shard_num}": data + for (shard_type, shard_num), data in cases_results.items() + }, + } + + # Add file-level aggregation + file_stats = parse_test_results.aggregate_all_cases_by_file(cases_results) + # Add shard info to file stats + file_stats_with_shards = {} + for file_path, stats in file_stats.items(): + # Normalize file path for lookup + lookup_path = file_path + if lookup_path.startswith("test/"): + lookup_path = lookup_path[5:] + shards_for_file = file_to_shards_map.get(lookup_path, []) + stats["shards"] = shards_for_file + file_stats_with_shards[file_path] = stats + report_json["file_level_stats"] = dict(sorted( + file_stats_with_shards.items(), + key=lambda x: (-x[1]["total"], x[0]) + )) + + # Add list of files with failures + failed_files = parse_test_results.get_files_with_failures(file_stats) + report_json["files_with_failures"] = failed_files + + if 
def parse_junit_xml(xml_file: str) -> Dict:
    """
    Parse a single JUnit XML file and extract test statistics.

    Every ``<testsuite>`` element found anywhere in the document contributes
    its ``tests``, ``failures``, ``skipped``, ``errors`` and ``time``
    attributes to the totals. A suite whose attributes cannot be converted to
    numbers is skipped as a whole (with a warning), so a single malformed
    suite can no longer leave the totals half-updated.

    Args:
        xml_file: Path to JUnit XML file

    Returns:
        Dict with keys: total, passed, failed, skipped, errors, duration.
        All values are zero when the file does not exist or cannot be parsed.
        ``passed`` is derived as total - failed - skipped - errors and is
        never negative.
    """
    stats = {
        "total": 0,
        "passed": 0,
        "failed": 0,
        "skipped": 0,
        "errors": 0,
        "duration": 0.0,
    }

    if not os.path.exists(xml_file):
        return stats

    # Reading/parsing is best-effort: a truncated or unreadable report must
    # not abort result collection for the remaining reports.
    try:
        root = ET.parse(xml_file).getroot()
    except Exception as exc:
        print(f"Warning: Failed to parse XML report {xml_file}: {exc}")
        return stats

    for testsuite in root.iter("testsuite"):
        # Convert all attributes first so a bad value cannot leave the
        # counters partially incremented for this suite.
        try:
            tests = int(testsuite.get("tests", 0))
            failures = int(testsuite.get("failures", 0))
            skipped = int(testsuite.get("skipped", 0))
            errors = int(testsuite.get("errors", 0))
            elapsed = float(testsuite.get("time", 0))
        except (TypeError, ValueError) as exc:
            print(f"Warning: Skipping malformed testsuite in XML report {xml_file}: {exc}")
            continue

        stats["total"] += tests
        stats["failed"] += failures
        stats["skipped"] += skipped
        stats["errors"] += errors
        stats["duration"] += elapsed

    # Clamp at zero: some report producers count failures/errors that are not
    # included in ``tests``, which would otherwise yield a negative value.
    stats["passed"] = max(
        stats["total"] - stats["failed"] - stats["skipped"] - stats["errors"],
        0,
    )

    return stats
def analyze_pytest_log(log_file: Path, returncode: int) -> Dict:
    """
    Scan a pytest log for collection-time failure patterns.

    Args:
        log_file: Path to pytest log file
        returncode: pytest process return code

    Returns:
        Dict with: zero_item_test_files, startup_failures, import_failures,
        test_failures. ``test_failures`` is always 0 here. Every value is 0
        when the log file is missing or cannot be read.
    """

    def _metrics(zero_items: int = 0, startup: int = 0, imports: int = 0) -> Dict:
        # Single place that fixes the key set and key order of the result.
        return {
            "zero_item_test_files": zero_items,
            "startup_failures": startup,
            "import_failures": imports,
            "test_failures": 0,
        }

    if not log_file.exists():
        return _metrics()

    try:
        text = log_file.read_text(encoding="utf-8", errors="replace")
    except OSError:
        return _metrics()

    # "Nothing collected" is signalled either by return code 5 or by one of
    # these phrases appearing anywhere in the log.
    nothing_collected = returncode == 5 or any(
        marker in text for marker in ("collected 0 items", "no tests ran")
    )

    import_errors = sum(
        1
        for _ in re.finditer(
            r"^ImportError while importing test module", text, flags=re.MULTILINE
        )
    )
    collection_errors = sum(
        1 for _ in re.finditer(r"^ERROR collecting ", text, flags=re.MULTILINE)
    )

    # Collection errors not accounted for by import errors are reported as
    # startup failures (never negative).
    return _metrics(
        zero_items=1 if nothing_collected else 0,
        startup=max(collection_errors - import_errors, 0),
        imports=import_errors,
    )
def get_shard_status(stats: Dict, has_xml: bool) -> str:
    """
    Map a shard's statistics to a single status label.

    The checks are ordered by severity; the first one that applies wins.

    Args:
        stats: Statistics dict
        has_xml: Whether XML files were generated

    Returns:
        Status string: MISSING, CRASHED, TIMEOUT, INCOMPLETE, ERROR, FAILED,
        NO_TESTS or PASSED
    """
    if not has_xml:
        return "MISSING"

    # Truthy flags set on the stats dict, most severe first.
    for flag, label in (
        ("crashed", "CRASHED"),
        ("timed_out", "TIMEOUT"),
        ("incomplete", "INCOMPLETE"),
    ):
        if stats.get(flag):
            return label

    # Positive counters, errors taking precedence over plain failures.
    for counter, label in (("errors", "ERROR"), ("failed", "FAILED")):
        if stats.get(counter, 0) > 0:
            return label

    return "NO_TESTS" if stats.get("total", 0) == 0 else "PASSED"
def save_info_file(report_dir: str, shard: int, info: Dict, shard_type: str = "regular") -> str:
    """
    Write a shard's info dictionary as indented JSON.

    The report directory is created if it does not exist yet. The file is
    named ``shard_<prefix>-<shard>_info.json``, where ``<prefix>`` comes from
    ``get_shard_type_prefix(shard_type)``.

    Args:
        report_dir: Output directory
        shard: Shard number
        info: Info dict to serialize
        shard_type: Shard type passed to ``get_shard_type_prefix``

    Returns:
        Path of the written file
    """
    os.makedirs(report_dir, exist_ok=True)

    type_prefix = get_shard_type_prefix(shard_type)
    target_path = os.path.join(report_dir, f"shard_{type_prefix}-{shard}_info.json")

    with open(target_path, mode="w", encoding="utf-8") as handle:
        json.dump(info, handle, indent=2)

    return target_path
def _new_file_stats(test_file: str) -> Dict:
    """Return a zeroed per-file statistics record for ``test_file``."""
    return {
        "file": test_file,
        "total": 0,
        "passed": 0,
        "failed": 0,
        "errors": 0,
        "timeout": 0,
        "skipped": 0,
        "failed_cases": [],
        "duration": 0.0,
    }


def aggregate_cases_by_file(cases_list: List[Dict]) -> Dict[str, Dict]:
    """
    Aggregate case results by test file.

    Cases are grouped by their "file" value. When "file" is missing or empty,
    the part of "nodeid" before the first "::" is used instead; if that is not
    available either, the case is grouped under "unknown".

    Args:
        cases_list: List of case result dicts with "nodeid", "file", "status"
            keys (optionally "duration" and "message")

    Returns:
        Dict mapping test file path -> aggregated stats
        Each entry contains:
            - file: test file path
            - total: total cases in file
            - passed, failed, errors, timeout, skipped: counts
            - failed_cases: list of failed/error/timeout cases with details
            - duration: total execution time for file

        A missing status, and any status other than passed/failed/timeout/
        skipped, is counted under "errors" so the per-status counts always
        add up to "total". The original status string is kept in the
        corresponding "failed_cases" entry.
    """
    file_stats: Dict[str, Dict] = {}

    for case in cases_list:
        nodeid = case.get("nodeid")

        # Fall back to the nodeid whenever "file" is absent OR empty.
        test_file = case.get("file")
        if not test_file:
            if isinstance(nodeid, str) and "::" in nodeid:
                test_file = nodeid.split("::", 1)[0]
            test_file = test_file or "unknown"

        status = case.get("status") or "error"
        # A null/missing duration contributes nothing instead of raising.
        duration = case.get("duration") or 0.0

        stats = file_stats.setdefault(test_file, _new_file_stats(test_file))
        stats["total"] += 1
        stats["duration"] += duration

        if status == "passed":
            stats["passed"] += 1
            continue
        if status == "skipped":
            stats["skipped"] += 1
            continue

        if status == "failed":
            stats["failed"] += 1
            message = case.get("message", "")
        elif status == "timeout":
            stats["timeout"] += 1
            message = f"Timeout after {duration}s"
        else:
            # "error" and every unrecognised status.
            stats["errors"] += 1
            message = case.get("message", "")

        stats["failed_cases"].append({
            "nodeid": nodeid,
            "status": status,
            "message": message,
            "duration": duration,
        })

    return file_stats


def aggregate_all_cases_by_file(cases_results: Dict) -> Dict[str, Dict]:
    """
    Aggregate all cases from multiple shards by test file.

    Args:
        cases_results: Dict mapping shard_key -> cases_data (from shard_*_cases.json);
            only the "cases" list of each entry is read

    Returns:
        Dict mapping test file -> aggregated stats across all shards, with
        each file's "failed_cases" sorted by nodeid (entries without a nodeid
        sort first)
    """
    all_file_stats: Dict[str, Dict] = {}
    counter_keys = ("total", "passed", "failed", "errors", "timeout", "skipped", "duration")

    for cases_data in cases_results.values():
        per_shard = aggregate_cases_by_file(cases_data.get("cases", []))

        for test_file, stats in per_shard.items():
            merged = all_file_stats.setdefault(test_file, _new_file_stats(test_file))
            for key in counter_keys:
                merged[key] += stats[key]
            merged["failed_cases"].extend(stats["failed_cases"])

    # Sort failed_cases within each file. ``or ""`` guards against a stored
    # nodeid of None, which cannot be compared with strings.
    for merged in all_file_stats.values():
        merged["failed_cases"].sort(key=lambda entry: entry.get("nodeid") or "")

    return all_file_stats
+ + Args: + file_stats: Dict from aggregate_all_cases_by_file() + + Returns: + List of file stats dicts sorted by file name, only including files with failures + """ + failed_files = [] + for test_file, stats in file_stats.items(): + if stats["failed"] > 0 or stats["errors"] > 0 or stats["timeout"] > 0: + failed_files.append(stats) + + failed_files.sort(key=lambda x: x["file"]) + return failed_files + + +# ============================================================================== +# Summary Printing +# ============================================================================== + + +def print_stats_summary(shard: int, stats: Dict, shard_type: str = "regular") -> None: + """Print statistics summary to stdout.""" + prefix = get_shard_type_prefix(shard_type) + print(f"\n{'=' * 60}") + print(f"Test Results for Shard {prefix}-{shard}") + print(f"{'=' * 60}") + print(f"Total: {stats['total']}") + print(f"Passed: {stats['passed']}") + print(f"Failed: {stats['failed']}") + print(f"Skipped: {stats['skipped']}") + print(f"Errors: {stats['errors']}") + print(f"Duration: {stats['duration']:.2f}s") + if stats.get("missing_files_count"): + print(f"Missing files: {stats['missing_files_count']}") + if stats.get("crashed"): + print(f"Crash signal: {stats.get('crash_signal', 'unknown')}") + print(f"{'=' * 60}") + + +def create_result_summary(stats: Dict, shard: int, shard_type: str) -> str: + """Create a formatted result summary string.""" + prefix = get_shard_type_prefix(shard_type) + status = get_shard_status(stats, stats.get("junit_generated", False)) + + lines = [ + f"Shard {prefix}-{shard} Results:", + f" Status: {status}", + f" Total: {stats.get('total', 0)}", + f" Passed: {stats.get('passed', 0)}", + f" Failed: {stats.get('failed', 0)}", + f" Errors: {stats.get('errors', 0)}", + f" Duration: {stats.get('duration', 0.0):.2f}s", + ] + + if stats.get("missing_files_count"): + lines.append(f" Missing: {stats['missing_files_count']}") + + return "\n".join(lines) + + +# 
def parse_shard_results(
    report_dir: Path,
    shard: int,
    shard_type: str,
    returncode: int,
    duration: float,
    missing_files: Optional[List[str]] = None,
) -> Tuple[Dict, Dict]:
    """
    Parse all results for a shard and return (stats, log_metrics).

    This is the main entry point for result parsing.

    Args:
        report_dir: Directory containing test reports
        shard: Shard number
        shard_type: "distributed" or "regular"
        returncode: pytest process return code
        duration: Execution duration
        missing_files: List of files that crashed (no XML generated)

    Returns:
        Tuple of (stats_dict, log_metrics_dict). The log metrics are also
        merged into the stats dict.
    """
    missing_files = missing_files or []

    # Parse JUnit XML files
    stats = parse_shard_xml_files(report_dir, shard, shard_type)

    # Add per-file isolation metadata
    stats["per_file_isolation"] = True
    stats["missing_files_count"] = len(missing_files)

    # Analyze log file (uses the raw return code to detect "nothing collected")
    log_file = get_shard_log_file(report_dir, shard, shard_type)
    log_metrics = analyze_pytest_log(log_file, returncode)

    # Return code 5 with zero parsed tests means nothing was collected and is
    # treated as success. Decide this BEFORE finalizing, so the stats are not
    # first marked incomplete with a synthetic error and then handed a
    # contradictory returncode of 0.
    no_tests_collected = returncode == 5 and stats.get("total", 0) == 0
    effective_returncode = 0 if no_tests_collected else returncode

    # Finalize stats
    stats = finalize_stats(stats, effective_returncode, duration)

    # Merge log metrics
    log_metrics["test_failures"] = stats.get("failed", 0) + stats.get("errors", 0)
    log_metrics["missing_files_count"] = len(missing_files)
    stats.update(log_metrics)

    return stats, log_metrics
def parse_args():
    """
    Build the command line parser for the result-parsing CLI and parse
    ``sys.argv``.

    Returns:
        argparse.Namespace with report_dir, shard, shard_type, returncode,
        duration, output_stats and verbose.
    """
    # (flags, keyword options) in the order the options appear in --help.
    option_specs = [
        (("--report-dir",), dict(type=str, required=True, help="Directory containing test reports")),
        (("--shard",), dict(type=int, required=True, help="Shard number")),
        (
            ("--shard-type",),
            dict(
                type=str,
                choices=["distributed", "regular"],
                default="regular",
                help="Shard type",
            ),
        ),
        (("--returncode",), dict(type=int, default=0, help="pytest return code")),
        (("--duration",), dict(type=float, default=0.0, help="Execution duration in seconds")),
        (("--output-stats",), dict(type=str, help="Output file for stats JSON")),
        (("--verbose", "-v"), dict(action="store_true", help="Verbose output")),
    ]

    cli = argparse.ArgumentParser(description="Parse test results from JUnit XML files")
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    return cli.parse_args()
shard=args.shard, + shard_type=args.shard_type, + returncode=args.returncode, + duration=args.duration, + ) + + # Output + if args.output_stats: + output_path = Path(args.output_stats) + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(json.dumps(stats, indent=2), encoding="utf-8") + print(f"Stats saved to: {output_path}") + + if args.verbose: + print(json.dumps(stats, indent=2)) + + print_stats_summary(args.shard, stats, args.shard_type) + + # Exit with appropriate code + sys.exit(stats.get("returncode", 0)) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/.github/scripts/run_npu_test_shard.py b/.github/scripts/run_npu_test_shard.py new file mode 100644 index 0000000000..d85bed23c7 --- /dev/null +++ b/.github/scripts/run_npu_test_shard.py @@ -0,0 +1,1294 @@ +#!/usr/bin/env python3 +""" +Run PyTorch NPU tests via per-case isolation pytest execution. + +This script executes pre-collected test cases or specified test files +with per-case subprocess isolation for crash safety. 
+ +Execution modes: + - Pre-collected cases (--cases-json): Execute cases from JSON file + - Custom test files (--test-files): Execute specified test files + +Each case runs in its own pytest subprocess for isolation: + - NPU kernel crashes won't cascade to other cases + - Results recorded in cases.json file + +Test types: + - distributed: Serial execution (one case at a time) + - regular: Concurrent execution (multiple workers) + +Usage: + # Pre-collected cases mode (primary usage): + python run_npu_test_shard.py \ + --cases-json distributed_cases_shard_1.json \ + --test-dir /path/to/pytorch/test \ + --disabled-testcases /path/to/disabled_testcases.json \ + --report-dir test-reports \ + --timeout 1200 \ + --max-workers 64 \ + --verbose + + # Custom test files mode: + python run_npu_test_shard.py \ + --test-files test_meta.py,test_nn.py \ + --test-dir /path/to/pytorch/test \ + --disabled-testcases /path/to/disabled_testcases.json \ + --report-dir test-reports \ + --timeout 1200 \ + --max-workers 4 \ + --verbose + +Note: Shard discovery mode (--shard/--num-shards/--test-type) has been removed. + Use collect_all_cases.py for case discovery and sharding. 
def sanitize_nodeid_for_filename(nodeid: str) -> str:
    r"""
    Convert nodeid to a safe filename.

    Replacement rules:
      - "::" and each of ``/ \ ( ) [ ]``, space and ``.`` become "_"
      - ``< > " | * ? :`` become ``_lt_ _gt_ _quot_ _pipe_ _star_ _q_ _colon_``
        (these are rejected by NTFS and by GitHub Actions artifact upload)
      - ASCII control characters (including CR, LF and TAB) and DEL become "_"

    Afterwards leading underscores are removed, runs of underscores are
    collapsed to one, and the result is truncated to 200 characters.

    Args:
        nodeid: pytest node id

    Returns:
        The sanitized name, or "unknown_case" if nothing is left.
    """
    # Per-character replacements. None of the replacement texts contains a
    # character that is itself replaced, so a single translate pass gives the
    # same result as applying the replacements one after another.
    char_map = {ch: "_" for ch in "/\\()[] ."}
    char_map.update({
        "<": "_lt_",
        ">": "_gt_",
        '"': "_quot_",
        "|": "_pipe_",
        "*": "_star_",
        "?": "_q_",
        ":": "_colon_",
    })
    # Control characters (CR/LF/TAB etc.) are not valid in file names either.
    char_map.update({chr(code): "_" for code in range(32)})
    char_map[chr(127)] = "_"

    # "::" must be handled first so it maps to "_" rather than "_colon__colon_".
    safe_name = nodeid.replace("::", "_").translate(str.maketrans(char_map))

    # Remove leading underscores and collapse multiple underscores
    safe_name = safe_name.lstrip("_")
    while "__" in safe_name:
        safe_name = safe_name.replace("__", "_")

    # Truncate if too long (max 200 chars)
    safe_name = safe_name[:200]

    return safe_name or "unknown_case"
class ConcurrentResultAggregator:
    """Collects per-case results from worker threads behind a single lock."""

    def __init__(self):
        self._lock = threading.Lock()
        self._cases_list: List[Dict] = []
        self._worst_returncode: int = 0
        self._passed_count: int = 0
        self._failed_count: int = 0
        self._error_count: int = 0
        self._skipped_count: int = 0
        self._timeout_count: int = 0
        self._total_cases: int = 0

    def add_case_result(self, case_result: Dict) -> None:
        """Record one case result and update the per-status counters."""
        # Statuses with a dedicated counter; anything else counts as an error.
        counter_for_status = {
            "passed": "_passed_count",
            "failed": "_failed_count",
            "skipped": "_skipped_count",
            "timeout": "_timeout_count",
        }
        with self._lock:
            self._cases_list.append(case_result)
            self._total_cases += 1

            counter_name = counter_for_status.get(
                case_result.get("status", "error"), "_error_count"
            )
            setattr(self, counter_name, getattr(self, counter_name) + 1)

            # Keep the first non-zero return code that is reported; a missing
            # "returncode" key is treated as 1.
            exit_code = case_result.get("returncode", 1)
            if exit_code != 0 and self._worst_returncode == 0:
                self._worst_returncode = exit_code

    def get_sorted_cases(self) -> List[Dict]:
        """Return a new list of the recorded cases ordered by "case_idx" (missing = 0)."""
        with self._lock:
            ordered = list(self._cases_list)
        ordered.sort(key=lambda entry: entry.get("case_idx", 0))
        return ordered

    def get_summary(self) -> Dict:
        """Return the counters and the recorded return code as a plain dict."""
        with self._lock:
            summary = {
                "total_cases": self._total_cases,
                "passed_count": self._passed_count,
                "failed_count": self._failed_count,
                "error_count": self._error_count,
                "skipped_count": self._skipped_count,
                "timeout_count": self._timeout_count,
                "worst_returncode": self._worst_returncode,
            }
        return summary
def parse_junit_xml_status(xml_file: Path) -> Dict:
    """
    Parse a JUnit XML report and derive a single test status from it.

    Every <testcase> element in the report is inspected. Within one
    testcase the first match of <skipped>, <failure>, <error> decides its
    outcome (none of them means "passed"). Across testcases the most
    severe outcome wins, in the order error > failed > passed > skipped, so
    a failure or error recorded in a later <testcase> element is not hidden
    by an earlier passing one. A report with exactly one testcase yields the
    same result as inspecting that testcase alone.

    Args:
        xml_file: Path to the JUnit XML file.

    Returns:
        Dict: {"status": "passed" | "skipped" | "failed" | "error" | "no_xml",
               "message": str}
        "no_xml" is returned when the file is missing or cannot be parsed;
        "error" with "No testcase in XML" when the report has no testcases.
    """
    if not xml_file.exists():
        return {"status": "no_xml", "message": "XML file not generated"}

    # Higher number = more severe; "skipped" only wins if nothing else ran.
    severity = {"skipped": 0, "passed": 1, "failed": 2, "error": 3}
    # Per-testcase check order, identical to the original implementation.
    child_tag_to_status = (
        ("skipped", "skipped"),
        ("failure", "failed"),
        ("error", "error"),
    )

    try:
        root = ET.parse(str(xml_file)).getroot()

        worst = None
        for testcase in root.iter("testcase"):
            outcome = {"status": "passed", "message": ""}
            for tag, status in child_tag_to_status:
                elem = testcase.find(tag)
                if elem is not None:
                    outcome = {"status": status, "message": elem.get("message", "")}
                    break

            if worst is None or severity[outcome["status"]] > severity[worst["status"]]:
                worst = outcome

        if worst is None:
            return {"status": "error", "message": "No testcase in XML"}
        return worst

    except Exception:
        # Best effort: an unreadable report is treated like a missing one.
        return {"status": "no_xml", "message": "XML parse failed"}
def run_single_case_concurrent(
    task: CaseExecutionTask,
    test_dir: Path,
    merged_env: Dict[str, str],
    config: ConcurrentExecutionConfig,
    result_aggregator: ConcurrentResultAggregator,
    progress_tracker: ProgressTracker,
    log_queue: Queue,
    report_dir: Path,
    shard: int,
    shard_type: str,
) -> Dict:
    """
    Execute a single test case in its own pytest subprocess.

    Called from ThreadPoolExecutor worker threads. The case runs via
    ``python -m pytest <nodeid> --junitxml=...`` with ``test_dir`` as the
    working directory; its status is taken from the generated JUnit XML
    (a missing or unparsable XML is reported as "error"). The result is
    pushed to ``result_aggregator``, ``progress_tracker`` and ``log_queue``,
    and a per-case log file is written through ``save_case_log``.

    Subprocess, timeout and log-writing failures are converted into a
    result dict instead of being raised to the executor.

    Args:
        task: Case execution task with nodeid and metadata
        test_dir: PyTorch test directory (subprocess working directory)
        merged_env: Environment variables for the subprocess
        config: Execution configuration; ``per_case_timeout <= 0`` disables
            both the pytest ``--timeout`` option and the subprocess timeout
        result_aggregator: Thread-safe result collector
        progress_tracker: Thread-safe progress tracker
        log_queue: Queue for log messages
        report_dir: Directory receiving ``junit_xmls/`` and the case logs
        shard: Shard number, used in generated file names
        shard_type: "distributed" selects the "dist" file prefix, anything
            else "reg"

    Returns:
        Dict with keys nodeid, status, duration, returncode, message,
        command, file, case_idx.
    """

    def _as_text(value) -> str:
        """Normalize TimeoutExpired output, which may be None, bytes or str."""
        if value is None:
            return ""
        if isinstance(value, bytes):
            return value.decode("utf-8", errors="replace")
        return value

    start_time = monotonic()
    original_nodeid = task.nodeid
    case_nodeid = task.nodeid

    # Strip test/ prefix for pytest execution (cwd is the test directory)
    if case_nodeid.startswith("test/"):
        case_nodeid = case_nodeid[5:]

    # Generate XML file path with descriptive name
    prefix = "dist" if shard_type == "distributed" else "reg"
    safe_case_name = sanitize_nodeid_for_filename(original_nodeid)
    xml_filename = f"{prefix}-{shard}_{task.case_idx}_{safe_case_name}.xml"
    xml_file = report_dir / "junit_xmls" / xml_filename

    command = [
        sys.executable,
        "-m",
        "pytest",
        "--color=no",
        "-ra",
        "--tb=short",
        case_nodeid,
        f"--junitxml={xml_file}",
        "--junit-prefix=",
    ]

    timeout_enabled = config.per_case_timeout > 0
    if timeout_enabled:
        command.append(f"--timeout={config.per_case_timeout}")

    command.append("-vv" if config.verbose else "-v")
    command_str = " ".join(command)

    # Give pytest's own timeout a 30 s head start. When the per-case timeout
    # is disabled the subprocess must not be time-limited at all (previously
    # it was killed after 30 s in that configuration).
    subprocess_timeout = config.per_case_timeout + 30 if timeout_enabled else None

    # Build per-case environment with test file directory in PYTHONPATH
    # This enables imports of sibling modules (e.g., 'from model_registry import MLPModule')
    case_env = merged_env.copy()
    test_file = task.test_file
    test_file_rel = test_file[5:] if test_file.startswith("test/") else test_file
    test_file_dir = test_dir / Path(test_file_rel).parent

    existing_pythonpath = case_env.get("PYTHONPATH", "")
    pythonpath_parts = [str(test_file_dir)]
    if existing_pythonpath:
        pythonpath_parts.append(existing_pythonpath)
    # os.pathsep matches how build_execution_env assembles PYTHONPATH.
    case_env["PYTHONPATH"] = os.pathsep.join(pythonpath_parts)

    # Print start log to stdout (before execution), nodeid truncated for display
    display_nodeid = original_nodeid[:70] + "..." if len(original_nodeid) > 70 else original_nodeid
    print(f"[{task.case_idx}] Starting: {display_nodeid}", flush=True)

    log_queue.put({
        "type": "case_start",
        "case_idx": task.case_idx,
        "nodeid": original_nodeid,
        "file": task.test_file,
        "command": command_str,
    })

    try:
        result = subprocess.run(
            command,
            cwd=str(test_dir),
            env=case_env,
            capture_output=True,
            text=True,
            encoding="utf-8",
            errors="replace",
            timeout=subprocess_timeout,
        )
        duration = monotonic() - start_time
        returncode = result.returncode
        captured_stdout = result.stdout
        captured_stderr = result.stderr

        # Has XML: use XML status. No XML: error.
        xml_result = parse_junit_xml_status(xml_file)
        xml_status = xml_result.get("status")
        if xml_status == "no_xml":
            status = "error"
            message = xml_result.get("message")
        else:
            status = xml_status
            message = xml_result.get("message", "")

    except subprocess.TimeoutExpired as exc:
        # Timeout: no XML is consulted, status = timeout
        duration = monotonic() - start_time
        status = "timeout"
        returncode = -1
        message = f"Timeout after {config.per_case_timeout}s"
        # Keep whatever output was collected before the kill; it is usually
        # the only hint about where the case hung.
        captured_stdout = _as_text(exc.stdout) or "(process timed out, no output captured)"
        captured_stderr = _as_text(exc.stderr) or "(process timed out, no output captured)"

    except Exception as exc:
        # Any other exception - return result, don't raise
        duration = monotonic() - start_time
        status = "error"
        returncode = 1
        message = f"Unexpected error: {str(exc)[:200]}"
        captured_stdout = "(exception occurred before execution)"
        captured_stderr = str(exc)

    case_result = {
        "nodeid": original_nodeid,
        "status": status,
        "duration": duration,
        "returncode": returncode,
        "message": message,
        "command": command_str,
        "file": task.test_file,
        "case_idx": task.case_idx,
    }

    # Save logs for all cases. A failure to write the log must neither change
    # the recorded test status nor escape to the executor.
    try:
        save_case_log(
            report_dir=report_dir,
            shard=shard,
            shard_type=shard_type,
            nodeid=original_nodeid,
            case_idx=task.case_idx,
            status=status,
            stdout=captured_stdout,
            stderr=captured_stderr,
            duration=duration,
            returncode=returncode,
            command=command_str,
        )
    except Exception as log_exc:
        print(f"[{task.case_idx}] Warning: failed to save case log: {log_exc}", flush=True)

    log_queue.put({
        "type": "case_finish",
        "case_idx": task.case_idx,
        "nodeid": original_nodeid,
        "status": case_result["status"],
        "duration": case_result["duration"],
        "message": case_result["message"][:200] if case_result["message"] else "",
    })

    # Update aggregator and progress (both thread-safe)
    result_aggregator.add_case_result(case_result)
    progress_tracker.mark_completed(original_nodeid, case_result["status"], duration)

    return case_result
def run_tests_with_tasks_concurrent(
    tasks: List[CaseExecutionTask],
    shard: int,
    test_dir: Path,
    report_dir: Path,
    env_updates: Dict[str, str],
    timeout: int,
    verbose: bool,
    shard_type: str,
    max_workers: int,
    result_module,
    quick_test: int = None,
) -> Tuple[int, float, List[Dict]]:
    """
    Execute pre-collected test cases with concurrent per-case isolation.

    Each task is submitted to a ThreadPoolExecutor and executed by
    ``run_single_case_concurrent`` in its own subprocess. A background
    thread (``log_writer_thread``) serializes log entries into the shard log
    file obtained from ``result_module.get_shard_log_file``.

    Args:
        tasks: List of CaseExecutionTask objects (pre-collected cases)
        shard: Shard number
        test_dir: PyTorch test directory
        report_dir: Report output directory
        env_updates: Environment variable updates applied on top of os.environ
        timeout: Per-case timeout in seconds
        verbose: Verbose output
        shard_type: "distributed" or "regular"
        max_workers: Maximum concurrent subprocesses
        result_module: parse_test_results module
        quick_test: Maximum number of cases to execute; None, 0 or a negative
            value runs all cases

    Returns:
        Tuple of (worst_returncode, duration, cases_list_sorted)
    """
    start = monotonic()
    log_file = result_module.get_shard_log_file(report_dir, shard, shard_type)

    # Create junit_xmls directory for XML reports
    junit_xml_dir = report_dir / "junit_xmls"
    junit_xml_dir.mkdir(parents=True, exist_ok=True)

    merged_env = os.environ.copy()
    merged_env.update(env_updates)

    config = ConcurrentExecutionConfig(
        max_workers=max_workers,
        per_case_timeout=timeout,
        verbose=verbose,
    )

    # Quick test: limit the number of cases before anything is reported, so
    # the log header shows the number of cases that will really run. Only a
    # positive limit truncates (a negative one would slice from the end).
    quick_test_applied = bool(quick_test) and quick_test > 0 and len(tasks) > quick_test
    if quick_test_applied:
        tasks = tasks[:quick_test]

    result_aggregator = ConcurrentResultAggregator()

    log_queue = Queue()
    stop_event = threading.Event()
    log_thread = threading.Thread(
        target=log_writer_thread,
        args=(log_queue, log_file, stop_event),
        daemon=True,
    )

    # Build the header from an explicit separator. The previous form mixed
    # adjacent string literals with "* 80"; implicit literal concatenation
    # binds tighter than "*", so whole text blocks were repeated 80 times.
    separator = "=" * 80
    log_queue.put({
        "type": "header",
        "content": (
            f"{separator}\n"
            f"Pre-collected cases concurrent execution ({shard_type} shard)\n"
            f"{separator}\n"
            f"Total cases: {len(tasks)}\n"
            f"Max concurrent workers: {max_workers}\n"
            "Execution mode: concurrent subprocess, each case isolated\n"
            f"{separator}\n\n"
        ),
    })

    log_thread.start()

    try:
        if quick_test_applied:
            print(f"\nQuick test mode: executing only {quick_test} cases", flush=True)

        print(f"\n{'=' * 80}", flush=True)
        print(f"Pre-collected cases: {len(tasks)} cases", flush=True)
        print(f"Execution mode: {max_workers} workers concurrent, each case in subprocess", flush=True)
        print(f"{'=' * 80}\n", flush=True)

        total_cases = len(tasks)
        print(f"Phase 1: Executing {total_cases} pre-collected cases...", flush=True)

        progress_tracker = ProgressTracker(total_cases)

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_task = {
                executor.submit(
                    run_single_case_concurrent,
                    task,
                    test_dir,
                    merged_env,
                    config,
                    result_aggregator,
                    progress_tracker,
                    log_queue,
                    report_dir,
                    shard,
                    shard_type,
                ): task
                for task in tasks
            }

            # as_completed yields futures in completion order
            for future in as_completed(future_to_task):
                task = future_to_task[future]
                try:
                    # Result already collected in aggregator
                    _ = future.result()
                except Exception as e:
                    # Safety net: record an error result if the worker raised
                    case_result = {
                        "nodeid": task.nodeid,
                        "status": "error",
                        "duration": 0.0,
                        "returncode": 1,
                        "message": f"Future error: {str(e)[:200]}",
                        "file": task.test_file,
                        "case_idx": task.case_idx,
                    }
                    result_aggregator.add_case_result(case_result)
                    progress_tracker.mark_completed(task.nodeid, "error", 0.0)

        elapsed = monotonic() - start
        summary = result_aggregator.get_summary()

        log_queue.put({
            "type": "summary",
            "content": (
                f"\n{'=' * 80}\n"
                f"Summary: {summary['total_cases']} cases executed\n"
                f" Passed: {summary['passed_count']}\n"
                f" Failed: {summary['failed_count']}\n"
                f" Errors: {summary['error_count']}\n"
                f" Timeout: {summary['timeout_count']}\n"
                f" Skipped: {summary['skipped_count']}\n"
                f" Duration: {elapsed:.2f}s\n"
                f" Concurrent workers: {max_workers}\n"
                f"{'=' * 80}\n"
            ),
        })
    finally:
        # Always signal the writer thread, also when execution was interrupted.
        stop_event.set()
        log_thread.join(timeout=5)

    if log_thread.is_alive():
        print(f"Warning: log writer did not finish within 5s; {log_file} may be incomplete", flush=True)

    # Print final summary
    print(f"\n{'=' * 80}", flush=True)
    print(f"Summary: {summary['total_cases']} cases executed", flush=True)
    print(f" Passed: {summary['passed_count']}", flush=True)
    print(f" Failed: {summary['failed_count']}", flush=True)
    print(f" Errors: {summary['error_count']}", flush=True)
    print(f" Timeout: {summary['timeout_count']}", flush=True)
    print(f" Skipped: {summary['skipped_count']}", flush=True)
    print(f" Duration: {elapsed:.2f}s", flush=True)
    print(f"{'=' * 80}", flush=True)

    return summary["worst_returncode"], elapsed, result_aggregator.get_sorted_cases()
def parse_test_files_input(test_files_str: str, test_dir: Path) -> List[str]:
    """
    Parse comma-separated test file input and return standardized test file paths.

    Each entry is normalized to start with exactly one "test/" prefix and is
    resolved against ``test_dir.parent``. If the entry does not exist and
    has no ".py" suffix, the same path with ".py" appended is tried. Entries
    that normalize to the same path (e.g. "test_nn" and "test/test_nn.py")
    are returned once, in first-seen order, so a file is not scheduled twice.

    Args:
        test_files_str: Comma-separated test file paths (e.g., "test_meta.py,test_nn.py")
        test_dir: Path to PyTorch test directory

    Returns:
        List of standardized test file paths (e.g., ["test/test_meta.py", "test/test_nn.py"])

    Raises:
        FileNotFoundError: If any specified test file does not exist
    """
    repo_root = test_dir.parent
    resolved: List[str] = []
    seen = set()

    for raw_entry in test_files_str.split(","):
        candidate = raw_entry.strip()
        if not candidate:
            continue

        # Normalize path format: ensure starts with "test/"
        if not candidate.startswith("test/"):
            candidate = "test/" + candidate

        # Remove leading "test/" prefix if it's duplicated
        if candidate.startswith("test/test/"):
            candidate = candidate[5:]

        # Verify file exists, trying the ".py" extension if it was omitted
        if not (repo_root / candidate).exists():
            if candidate.endswith(".py"):
                raise FileNotFoundError(f"Test file not found: {candidate}")
            candidate_with_ext = candidate + ".py"
            if not (repo_root / candidate_with_ext).exists():
                raise FileNotFoundError(f"Test file not found: {candidate} or {candidate_with_ext}")
            candidate = candidate_with_ext

        if candidate not in seen:
            seen.add(candidate)
            resolved.append(candidate)

    return resolved
def _write_shard_reports(
    result_module,
    report_dir: Path,
    shard: int,
    shard_type: str,
    info: Dict,
    cases_list: List[Dict],
    duration: float,
    returncode: int,
    execution_mode: str,
    concurrent_workers: int,
) -> None:
    """Count case statuses, save the cases/info/stats files and print the stats summary."""
    counts = {"passed": 0, "failed": 0, "error": 0, "timeout": 0, "skipped": 0}
    for case in cases_list:
        if case["status"] in counts:
            counts[case["status"]] += 1

    cases_data = {
        "shard": shard,
        "shard_type": shard_type,
        "execution_mode": execution_mode,
        "concurrent_workers": concurrent_workers,
        "total_cases": len(cases_list),
        "passed": counts["passed"],
        "failed": counts["failed"],
        "errors": counts["error"],
        "timeout": counts["timeout"],
        "skipped": counts["skipped"],
        "duration": duration,
        "cases": cases_list,
    }
    result_module.save_cases_file(str(report_dir), shard, cases_data, shard_type)
    result_module.save_info_file(str(report_dir), shard, info, shard_type)

    stats = {
        "total": len(cases_list),
        "passed": counts["passed"],
        "failed": counts["failed"],
        "skipped": counts["skipped"],
        "errors": counts["error"],
        "timeout": counts["timeout"],
        "duration": duration,
        "returncode": returncode,
        "per_case_isolation": True,
    }
    result_module.save_stats_file(str(report_dir), shard, stats, shard_type)
    result_module.print_stats_summary(shard, stats, shard_type)


def main():
    """
    Main entry point.

    Runs in one of two modes, selected by the command line:
      * --test-files: collect the cases of the given test files, then
        execute them concurrently (shard 1 of 1, shard type "custom").
      * --cases-json: execute the cases listed in a pre-collected JSON file;
        "distributed" shards run with a single worker.

    In both modes the cases, info and stats files are written through the
    parse_test_results module and the process exits with status 0, so that
    later report-generation steps still run; the real outcome is recorded
    in the saved files.
    """
    args = parse_args()

    # Resolve paths
    test_dir = Path(args.test_dir).resolve()
    if not test_dir.is_dir():
        raise FileNotFoundError(f"Test directory not found: {test_dir}")

    script_dir = Path(__file__).resolve().parent
    report_dir = Path(args.report_dir).resolve()
    report_dir.mkdir(parents=True, exist_ok=True)

    # Load modules
    result_module = load_parse_test_results_module(script_dir)

    timestamp = datetime.now().isoformat()

    # Only a positive --quick-test limits the run; a negative value would
    # otherwise slice cases off the end of the list.
    quick_test = args.quick_test if args.quick_test and args.quick_test > 0 else None

    # ==========================================================================
    # Mode: Direct execution of specified test files
    # ==========================================================================
    if args.test_files:
        print("=" * 80)
        print("Custom Test Files Execution Mode")
        print("=" * 80)

        planned_tests = parse_test_files_input(args.test_files, test_dir)

        # Use fixed shard number for custom mode
        shard = 1
        num_shards = 1
        shard_type = "custom"

        print(f"Test files specified: {len(planned_tests)}")
        print(f"Test directory: {test_dir}")
        print(f"Execution mode: concurrent ({args.max_workers} workers, per-case subprocess isolation)")
        disabled_count = None
        if args.disabled_testcases:
            # Loaded once and reused for the info file below.
            disabled_count = result_module.load_disabled_testcases_count(args.disabled_testcases)
            print(f"Disabled testcase entries: {disabled_count}")
        print(f"\n{'=' * 80}\n")

        for index, target in enumerate(planned_tests, 1):
            display_name = strip_test_prefix_and_suffix(target)
            print(f" [{index:03d}] {display_name}")

        # Create info dict for custom mode
        info = result_module.create_shard_info(shard, num_shards, timestamp)
        info["selection_mode"] = "custom_files"
        info["shard_type"] = shard_type
        info["shard_files"] = len(planned_tests)
        info["total_files"] = len(planned_tests)
        info["selected_test_files"] = len(planned_tests)
        if args.disabled_testcases:
            info["disabled_count"] = disabled_count

        # Save test plan
        result_module.save_test_plan_file(str(report_dir), shard, planned_tests, shard_type)

        # Clean old files
        clean_existing_junit_xml(report_dir)
        remove_existing_file(result_module.get_shard_log_file(report_dir, shard, shard_type))

        env_updates = build_execution_env(
            test_dir, script_dir, args.disabled_testcases, shard, shard_type
        )

        cases_list = []
        returncode = 0
        duration = 0.0
        if planned_tests:
            # Phase 1: Collect all test cases using collect_all_cases module
            print("\nPhase 1: Collecting test cases...")
            error_log_dir = report_dir / "collection_errors"
            collected_cases = collect_all_cases.collect_all_cases(
                planned_tests,
                test_dir,
                error_log_dir,
                parallel=16,
            )

            # Apply quick_test limit if specified
            if quick_test and len(collected_cases) > quick_test:
                collected_cases = collected_cases[:quick_test]
                print(f" Quick test mode: using only {quick_test} cases")

            total_cases = len(collected_cases)
            print(f"\nPhase 2: Executing {total_cases} cases with {args.max_workers} workers")

            tasks = [
                CaseExecutionTask(
                    case_idx=position,
                    nodeid=case["nodeid"],
                    test_file=case["file"],
                    file_idx=0,  # Not needed for pre-collected cases
                )
                for position, case in enumerate(collected_cases, 1)
            ]

            # Phase 2: Execute cases using run_tests_with_tasks_concurrent
            returncode, duration, cases_list = run_tests_with_tasks_concurrent(
                tasks,
                shard,
                test_dir,
                report_dir,
                env_updates,
                args.timeout,
                args.verbose,
                shard_type,
                args.max_workers,
                result_module,
                quick_test,
            )
            info["per_case_isolation"] = True
            info["concurrent_workers"] = args.max_workers
            info["returncode"] = returncode
            info["duration"] = duration

        _write_shard_reports(
            result_module,
            report_dir,
            shard,
            shard_type,
            info,
            cases_list,
            duration,
            returncode,
            execution_mode="concurrent",
            concurrent_workers=args.max_workers,
        )

        # Exit with 0 to allow step to succeed and report generation to proceed
        # The actual test results are recorded in cases.json
        sys.exit(0)

    # ==========================================================================
    # Mode: Pre-collected cases JSON execution
    # ==========================================================================
    if args.cases_json:
        print("=" * 80)
        print("Pre-collected Cases Execution Mode")
        print("=" * 80)

        cases_file = Path(args.cases_json).resolve()
        if not cases_file.exists():
            raise FileNotFoundError(f"Cases JSON file not found: {cases_file}")

        cases_data = json.loads(cases_file.read_text(encoding="utf-8"))

        shard = cases_data["shard"]
        num_shards = cases_data["num_shards"]
        shard_type = cases_data.get("test_type", "regular")
        planned_cases = cases_data["cases"]
        total_cases = len(planned_cases)

        print(f"Cases JSON: {cases_file}")
        print(f"Shard: {shard}/{num_shards}")
        print(f"Test type: {shard_type}")
        print(f"Total cases: {total_cases}")
        print(f"Test directory: {test_dir}")

        # Execution mode based on test_type
        if shard_type == "distributed":
            print("Execution mode: SERIAL (per-case subprocess isolation)")
        else:
            print(f"Execution mode: CONCURRENT ({args.max_workers} workers, per-case subprocess isolation)")

        disabled_count = None
        if args.disabled_testcases:
            # Loaded once and reused for the info file below.
            disabled_count = result_module.load_disabled_testcases_count(args.disabled_testcases)
            print(f"Disabled testcase entries: {disabled_count}")

        print(f"\n{'=' * 80}\n")

        # Create info dict for cases-json mode
        info = result_module.create_shard_info(shard, num_shards, timestamp)
        info["selection_mode"] = "cases_json"
        info["shard_type"] = shard_type
        info["cases_json_file"] = str(cases_file)
        info["total_cases"] = total_cases
        info["per_case_isolation"] = True
        if args.disabled_testcases:
            info["disabled_count"] = disabled_count

        # Clean old files
        clean_existing_junit_xml(report_dir)
        remove_existing_file(result_module.get_shard_log_file(report_dir, shard, shard_type))

        env_updates = build_execution_env(
            test_dir, script_dir, args.disabled_testcases, shard, shard_type
        )

        # Convert cases to CaseExecutionTask format
        tasks = [
            CaseExecutionTask(
                case_idx=position,
                nodeid=case["nodeid"],
                test_file=case.get("file", ""),
                file_idx=0,
            )
            for position, case in enumerate(planned_cases, 1)
        ]

        cases_list = []
        returncode = 0
        duration = 0.0
        if tasks:
            if shard_type == "distributed":
                # Distributed: serial execution (1 worker)
                effective_workers = 1
                print("\nExecution mode: SERIAL (distributed tests require sequential execution)")
            else:
                # Regular: concurrent execution
                effective_workers = args.max_workers
                print(f"\nExecution mode: CONCURRENT ({effective_workers} workers)")

            returncode, duration, cases_list = run_tests_with_tasks_concurrent(
                tasks,
                shard,
                test_dir,
                report_dir,
                env_updates,
                args.timeout,
                args.verbose,
                shard_type,
                effective_workers,
                result_module,
                quick_test,
            )
            info["execution_mode"] = "serial" if effective_workers == 1 else "concurrent"
            info["concurrent_workers"] = effective_workers

            info["returncode"] = returncode
            info["duration"] = duration
        else:
            print("No cases to execute.")

        _write_shard_reports(
            result_module,
            report_dir,
            shard,
            shard_type,
            info,
            cases_list,
            duration,
            returncode,
            execution_mode=info.get("execution_mode", "unknown"),
            concurrent_workers=info.get("concurrent_workers", 1),
        )

        # Exit with 0 to allow step to succeed and report generation to proceed
        # The actual test results are recorded in cases.json
        sys.exit(0)

    # No valid mode specified (should not reach here due to argument validation)
    print("ERROR: Either --test-files or --cases-json must be specified")
    sys.exit(1)
string + description: 'PyTorch branch, tag, or commit SHA to build' + torch_npu_ref: + required: true + type: string + description: 'torch_npu branch, tag, or commit SHA to build' + python_version: + required: true + type: string + default: '3.11' + docker_image: + required: true + type: string + default: 'quay.io/kerer/pytorch:manylinux-cann9.0.0-beta.2-20260428' + description: 'Docker image URL to use for build' + outputs: + docker-image: + description: 'Full Docker image URL' + value: ${{ inputs.docker_image }} + torch-wheel: + description: 'PyTorch wheel artifact name' + value: 'torch-wheel-main' + torch-npu-wheel: + description: 'torch_npu wheel artifact name' + value: 'torch-npu-wheel-main' + pytorch-src: + description: 'PyTorch source and test code artifact name' + value: 'pytorch-src-main' + pytorch-version: + description: 'PyTorch version string' + value: ${{ jobs.build.outputs.pytorch-version }} + +env: + # 缓存版本号,当需要强制刷新缓存时修改此值 + CACHE_VERSION: 'v2' + # GitHub 代理 URL(用于加速 git clone,留空则不使用代理) + GH_PROXY_URL: 'https://gh-proxy.test.osinfra.cn' + # PyPI 缓存 URL(用于加速 pip 下载) + PYPI_CACHE_URL: 'http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple' + +jobs: + build: + runs-on: linux-aarch64-a3-16 + timeout-minutes: 240 + outputs: + pytorch-version: ${{ steps.get_version.outputs.pytorch_version }} + + container: + image: ${{ inputs.docker_image }} + options: --user root + + steps: + - name: Display Docker image + run: | + echo "Using Docker image: ${{ inputs.docker_image }}" + + - name: Setup CANN environment + run: | + source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true + source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true + + - name: Configure git proxy for faster clone + run: | + # 配置 git URL rewrite 来使用代理(加速 clone 和 submodules) + if [ -n "${{ env.GH_PROXY_URL }}" ]; then + git config --global url."${{ env.GH_PROXY_URL }}/https://github.com/".insteadOf "https://github.com/" + git config --global url."${{ env.GH_PROXY_URL 
}}/https://gitlab.com/".insteadOf "https://gitlab.com/" + echo "Git proxy configured:" + git config --global --list | grep url + else + echo "No proxy configured, using direct connection" + fi + + - name: Clone upstream PyTorch with submodules + id: clone_pytorch + run: | + # Use the proxy to speed up git clone (when GH_PROXY_URL is set) + PYTORCH_REPO="https://github.com/pytorch/pytorch.git" + if [ -n "${{ env.GH_PROXY_URL }}" ]; then + PYTORCH_REPO="${{ env.GH_PROXY_URL }}/${PYTORCH_REPO}" + echo "Using proxy: ${PYTORCH_REPO}" + fi + + # Clone the requested ref (branch, tag, or commit) + PYTORCH_REF="${{ inputs.pytorch_ref }}" + echo "Cloning PyTorch with ref: ${PYTORCH_REF}" + + # Shallow-clone, fetch the requested ref, then check out FETCH_HEAD: in a shallow single-branch clone a fetched branch or tag does not necessarily get a local ref, so checking it out by name can fail + git clone --depth=1 "${PYTORCH_REPO}" pytorch-src + cd pytorch-src + git fetch --depth=1 origin "${PYTORCH_REF}" + git checkout --detach FETCH_HEAD + + # Initialize submodules + git submodule update --init --recursive + + PYTORCH_SHA=$(git rev-parse HEAD) + PYTORCH_SHA_SHORT=$(git rev-parse --short HEAD) + echo "pytorch_sha=${PYTORCH_SHA}" >> $GITHUB_OUTPUT + echo "pytorch_sha_short=${PYTORCH_SHA_SHORT}" >> $GITHUB_OUTPUT + echo "Cloned PyTorch commit: ${PYTORCH_SHA}" + echo "Submodules downloaded:" + ls -la third_party/ | head -20 + + - name: Checkout torch_npu + uses: actions/checkout@v4 + with: + repository: Ascend/pytorch + ref: ${{ inputs.torch_npu_ref }} + path: torch_npu-src + submodules: recursive + + # ==================== pip cache configuration ==================== + # The pip cache speeds up dependency downloads and does not affect the build output + # Cache key is based on the requirements-build.txt hash (dependencies change rarely) + - name: Get pip cache key + id: pip_key + run: | + REQUIREMENTS_HASH=$(cd pytorch-src && sha256sum requirements-build.txt | cut -d' ' -f1) + echo "cache_key=${{ env.CACHE_VERSION }}-pip-${{ inputs.python_version }}-${REQUIREMENTS_HASH}" >> $GITHUB_OUTPUT + + - name: Restore pip cache + uses: actions/cache/restore@v4 + with: + path: /root/.cache/pip + key: ${{ steps.pip_key.outputs.cache_key }} + restore-keys: | + ${{ env.CACHE_VERSION }}-pip-${{ inputs.python_version }}- +
${{ env.CACHE_VERSION }}-pip- + + - name: Setup pip cache directory + run: | + mkdir -p /root/.cache/pip + + - name: Configure pip index URL + run: | + # 配置 pip 使用 PyPI 缓存加速下载 + if [ -n "${{ env.PYPI_CACHE_URL }}" ]; then + pip${{ inputs.python_version }} config set global.index-url ${{ env.PYPI_CACHE_URL }} + pip${{ inputs.python_version }} config set global.trusted-host "cache-service.nginx-pypi-cache.svc.cluster.local" + echo "pip index-url configured: ${{ env.PYPI_CACHE_URL }}" + else + echo "No PyPI cache URL configured, using default" + fi + + - name: Upgrade pip and setuptools + run: | + export PIP_CACHE_DIR=/root/.cache/pip + # 先升级 pip 和 setuptools,避免旧版包兼容性问题 + pip${{ inputs.python_version }} install --upgrade pip setuptools wheel + + # ==================== ccache 缓存配置 ==================== + # ccache 是真正加速编译的关键(可节省 30-60 分钟) + # 注意:PyTorch 每次 clone 都是新 commit,所以缓存键不包含 PyTorch SHA + # 我们依赖 torch_npu SHA 和 requirements-build.txt hash 作为缓存键 + - name: Get ccache key + id: ccache_key + run: | + # ccache 缓存键:torch_npu SHA + requirements hash + # PyTorch SHA 每次都变化(--depth=1 clone 最新),所以不包含在缓存键中 + TORCH_NPU_SHA=$(cd torch_npu-src && git rev-parse HEAD) + REQUIREMENTS_HASH=$(cd pytorch-src && sha256sum requirements-build.txt | cut -d' ' -f1) + echo "cache_key=${{ env.CACHE_VERSION }}-ccache-${REQUIREMENTS_HASH}-${TORCH_NPU_SHA}" >> $GITHUB_OUTPUT + # partial_key 用于恢复同版本 requirements 的缓存(不同 torch_npu 版本) + echo "partial_key=${{ env.CACHE_VERSION }}-ccache-${REQUIREMENTS_HASH}-" >> $GITHUB_OUTPUT + # base_key 用于恢复同 CACHE_VERSION 的所有缓存 + echo "base_key=${{ env.CACHE_VERSION }}-ccache-" >> $GITHUB_OUTPUT + + - name: Restore ccache + uses: actions/cache/restore@v4 + with: + path: /root/.cache/ccache + key: ${{ steps.ccache_key.outputs.cache_key }} + restore-keys: | + ${{ steps.ccache_key.outputs.partial_key }} + ${{ steps.ccache_key.outputs.base_key }} + + - name: Setup ccache + run: | + # 安装 ccache(manylinux 镜像没有预装) + yum install -y ccache + + # 创建 ccache 配置目录(使用绝对路径) + 
CCACHE_DIR_PATH="/root/.cache/ccache" + mkdir -p "$CCACHE_DIR_PATH" + + # 直接写入配置文件(使用绝对路径) + cat > "$CCACHE_DIR_PATH/ccache.conf" << 'EOF' + max_size = 20G + cache_dir = /root/.cache/ccache + compression = true + compression_level = 6 + EOF + + # 使用符号链接方式让 ccache 模拟 gcc/g++ + mkdir -p /usr/local/bin + ln -sf /usr/bin/ccache /usr/local/bin/gcc + ln -sf /usr/bin/ccache /usr/local/bin/g++ + ln -sf /usr/bin/ccache /usr/local/bin/cc + ln -sf /usr/bin/ccache /usr/local/bin/c++ + + # 设置 PATH 优先使用符号链接 + echo "PATH=/usr/local/bin:$PATH" >> $GITHUB_ENV + + # 设置 CCACHE_DIR(使用绝对路径,不使用 ~) + echo "CCACHE_DIR=$CCACHE_DIR_PATH" >> $GITHUB_ENV + + # 设置编译器环境变量,确保 CMake/Ninja 使用 ccache + echo "CC=/usr/local/bin/gcc" >> $GITHUB_ENV + echo "CXX=/usr/local/bin/g++" >> $GITHUB_ENV + + echo "=== ccache Configuration ===" + CCACHE_DIR="$CCACHE_DIR_PATH" ccache --show-config + + echo "" + echo "=== Config File Contents ===" + cat "$CCACHE_DIR_PATH/ccache.conf" + + echo "" + echo "=== Cache Directory ===" + ls -la "$CCACHE_DIR_PATH/" + + echo "" + echo "=== Symbolic Links ===" + ls -la /usr/local/bin/gcc /usr/local/bin/g++ + + echo "" + echo "=== ccache Statistics (before build) ===" + CCACHE_DIR="$CCACHE_DIR_PATH" ccache --show-stats + + # ==================== 构建 PyTorch ==================== + - name: Build PyTorch wheel + run: | + source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true + + export PIP_CACHE_DIR=/root/.cache/pip + cd pytorch-src + + # 安装构建依赖(pip 缓存已恢复,加速下载) + pip${{ inputs.python_version }} install --upgrade pip setuptools wheel + pip${{ inputs.python_version }} install -r requirements-build.txt + + # 设置构建环境变量 + export MAX_JOBS=128 + export USE_CUDA=0 + export USE_CUDNN=0 + export USE_DISTRIBUTED=1 + export CMAKE_BUILD_TYPE=Release + export USE_OPENMP=1 + export USE_MKLDNN=0 + + # 确保使用 ccache(CMake 会检测 CC/CXX 环境变量) + export CC=/usr/local/bin/gcc + export CXX=/usr/local/bin/g++ + export CCACHE_DIR=/root/.cache/ccache + + # 清除 ccache 统计(开始新的构建) + ccache --zero-stats + + 
python${{ inputs.python_version }} setup.py build bdist_wheel + + echo "PyTorch wheel built:" + ls -la dist/ + + echo "" + echo "=== ccache Statistics (after PyTorch build) ===" + ccache --show-stats + + # ==================== 构建 torch_npu ==================== + - name: Install PyTorch wheel and build dependencies + run: | + source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true + source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true + + export PIP_CACHE_DIR=/root/.cache/pip + + echo "=== Installing built PyTorch wheel ===" + pip${{ inputs.python_version }} install pytorch-src/dist/*.whl + + echo "" + echo "=== Verifying PyTorch installation ===" + python${{ inputs.python_version }} -c "import torch; print(f'torch version: {torch.__version__}')" + + - name: Get PyTorch version + id: get_version + run: | + source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true + PYTORCH_VERSION=$(python${{ inputs.python_version }} -c "import torch; print(torch.__version__)") + echo "pytorch_version=${PYTORCH_VERSION}" >> $GITHUB_OUTPUT + echo "PyTorch version: ${PYTORCH_VERSION}" + + - name: Install torch_npu build dependencies + run: | + source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true + source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true + + export PIP_CACHE_DIR=/root/.cache/pip + pip${{ inputs.python_version }} install --upgrade pip setuptools wheel + pip${{ inputs.python_version }} install cmake ninja numpy packaging pyyaml requests six typing-extensions + + cd torch_npu-src + + # 显示 ccache 统计(依赖安装阶段) + echo "" + echo "=== ccache Statistics (before torch_npu build) ===" + CCACHE_DIR=/root/.cache/ccache ccache --show-stats + + - name: Build torch_npu wheel + run: | + source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true + source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true + + cd torch_npu-src + + export MAX_JOBS=128 + + # 确保使用 ccache + export CC=/usr/local/bin/gcc + export CXX=/usr/local/bin/g++ + export 
CCACHE_DIR=/root/.cache/ccache + + # 禁用 torchair 构建(上游 PyTorch main API 变化导致兼容性问题) + bash ci/build.sh --python=${{ inputs.python_version }} --disable_torchair + + echo "torch_npu wheel built:" + ls -la dist/ + + echo "" + echo "=== ccache Statistics (after torch_npu build) ===" + ccache --show-stats + + # ==================== 保存缓存 ==================== + - name: Save pip cache + if: always() + uses: actions/cache/save@v4 + with: + path: /root/.cache/pip + key: ${{ steps.pip_key.outputs.cache_key }} + + - name: Save ccache + if: always() + uses: actions/cache/save@v4 + with: + path: /root/.cache/ccache + key: ${{ steps.ccache_key.outputs.cache_key }} + + - name: Display cache save status + if: always() + run: | + echo "=== Cache Saved ===" + echo "pip cache key: ${{ steps.pip_key.outputs.cache_key }}" + PIP_CACHE_SIZE=$(du -sh /root/.cache/pip 2>/dev/null | cut -f1) + echo "pip cache size: ${PIP_CACHE_SIZE}" + echo "" + echo "ccache key: ${{ steps.ccache_key.outputs.cache_key }}" + CCACHE_SIZE=$(du -sh /root/.cache/ccache 2>/dev/null | cut -f1) + echo "ccache size: ${CCACHE_SIZE}" + + # ==================== 打包和上传 ==================== + - name: Package PyTorch source and build artifacts + run: | + # 打包整个 pytorch-src 目录(包含测试源码和编译产物) + # 排除不必要的文件以减小体积: + # - .git 目录(最占空间) + # - build/ 目录中的编译中间产物(CMakeFiles, .o 文件等) + # - dist/*.whl(已单独上传为 artifact) + + echo "=== PyTorch source directory size ===" + du -sh pytorch-src/ + + echo "" + echo "=== Build artifacts location ===" + ls -la pytorch-src/build/lib.*/torch/*.so 2>/dev/null | head -5 || echo "No .so files found in build/lib" + ls -la pytorch-src/torch/_C.so 2>/dev/null || echo "No _C.so in torch/" + + echo "" + echo "=== Creating archive (excluding large unnecessary files) ===" + tar -czf pytorch-src.tar.gz \ + --exclude='pytorch-src/.git' \ + --exclude='pytorch-src/build/CMakeFiles' \ + --exclude='pytorch-src/build/*.o' \ + --exclude='pytorch-src/build/**/*.o' \ + --exclude='pytorch-src/dist/*.whl' \ + pytorch-src + 
+ echo "" + echo "=== Archive size ===" + ls -lh pytorch-src.tar.gz + + - name: Upload PyTorch wheel + uses: actions/upload-artifact@v4 + with: + name: torch-wheel-main + path: pytorch-src/dist/*.whl + retention-days: 7 + + - name: Upload torch_npu wheel + uses: actions/upload-artifact@v4 + with: + name: torch-npu-wheel-main + path: torch_npu-src/dist/*.whl + retention-days: 7 + + - name: Upload PyTorch source and build artifacts + uses: actions/upload-artifact@v4 + with: + name: pytorch-src-main + path: pytorch-src.tar.gz + retention-days: 7 \ No newline at end of file diff --git a/.github/workflows/_torch-npu-upstream-collect.yml b/.github/workflows/_torch-npu-upstream-collect.yml new file mode 100644 index 0000000000..bf57ab38c7 --- /dev/null +++ b/.github/workflows/_torch-npu-upstream-collect.yml @@ -0,0 +1,163 @@ +name: Torch NPU Upstream Collect + +on: + workflow_call: + inputs: + python_version: + required: true + type: string + description: Python version to use + torch_wheel_artifact: + required: true + type: string + description: Name of the torch wheel artifact from build + torch_npu_wheel_artifact: + required: true + type: string + description: Name of the torch_npu wheel artifact from build + pytorch_src_artifact: + required: true + type: string + description: Name of the pytorch source artifact from build + docker_image: + required: true + type: string + description: Docker image to use + distributed_shards: + required: false + type: string + default: '2' + description: Number of shards for distributed tests + regular_shards: + required: false + type: string + default: '5' + description: Number of shards for regular tests + outputs: + distributed_matrix: + description: Distributed shard matrix JSON + value: ${{ jobs.collect.outputs.distributed_matrix }} + regular_matrix: + description: Regular shard matrix JSON + value: ${{ jobs.collect.outputs.regular_matrix }} + distributed_shards: + description: Number of distributed shards + value: ${{ 
jobs.collect.outputs.distributed_shards }} + regular_shards: + description: Number of regular shards + value: ${{ jobs.collect.outputs.regular_shards }} + total_cases: + description: Total number of test cases + value: ${{ jobs.collect.outputs.total_cases }} + +jobs: + collect: + runs-on: linux-aarch64-a3-16 + timeout-minutes: 120 + container: + image: ${{ inputs.docker_image }} + options: --user root + outputs: + distributed_matrix: ${{ steps.collect_and_shard.outputs.distributed_matrix }} + regular_matrix: ${{ steps.collect_and_shard.outputs.regular_matrix }} + distributed_shards: ${{ steps.collect_and_shard.outputs.distributed_shards }} + regular_shards: ${{ steps.collect_and_shard.outputs.regular_shards }} + total_cases: ${{ steps.collect_and_shard.outputs.total_cases }} + + steps: + - name: Setup NPU test environment + uses: kerer-ai/pytorch/.github/actions/setup-npu-test-env@dev_master + with: + python_version: ${{ inputs.python_version }} + torch_wheel_artifact: ${{ inputs.torch_wheel_artifact }} + torch_npu_wheel_artifact: ${{ inputs.torch_npu_wheel_artifact }} + pytorch_src_artifact: ${{ inputs.pytorch_src_artifact }} + + - name: Collect all test cases and shard + id: collect_and_shard + run: | + source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true + source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true + + PYTHON=python${{ inputs.python_version }} + cd pytorch-src + + # 设置 BACKEND 环境变量,避免分布式测试收集阶段 KeyError + # 值设为 hccl(NPU 分布式后端),不是 gloo/nccl 时测试类不会被定义 + # 结果:pytest 收集到 0 个用例,不会报错 + export BACKEND="hccl" + + # Case-level sharding + DISTRIBUTED_SHARDS='${{ inputs.distributed_shards }}' + REGULAR_SHARDS='${{ inputs.regular_shards }}' + + echo "=== Collecting all test cases ===" + echo "Distributed shards: ${DISTRIBUTED_SHARDS}" + echo "Regular shards: ${REGULAR_SHARDS}" + + $PYTHON ../ascend_pytorch/.github/scripts/collect_all_cases.py \ + --test-dir test \ + --distributed-shards ${DISTRIBUTED_SHARDS} \ + --regular-shards 
${REGULAR_SHARDS} \ + --output-dir cases_shards \ + --error-log-dir collection_errors \ + --parallel 16 \ + 2>&1 | tee /tmp/collect_cases.log + + # Verify output + echo "=== Generated shard files ===" + ls -la cases_shards/ + + echo "=== Collection summary ===" + cat cases_shards/cases_collection_summary.json + + # Extract total cases from summary + TOTAL_CASES=$(python3 -c "import json; d=json.load(open('cases_shards/cases_collection_summary.json')); print(d['total_cases'])") + + # Build shard matrices + DIST_SHARDS=$(seq 1 ${DISTRIBUTED_SHARDS} | tr '\n' ',' | sed 's/,$//') + REG_SHARDS=$(seq 1 ${REGULAR_SHARDS} | tr '\n' ',' | sed 's/,$//') + + echo "distributed_matrix=[${DIST_SHARDS}]" >> $GITHUB_OUTPUT + echo "distributed_shards=${DISTRIBUTED_SHARDS}" >> $GITHUB_OUTPUT + echo "regular_matrix=[${REG_SHARDS}]" >> $GITHUB_OUTPUT + echo "regular_shards=${REGULAR_SHARDS}" >> $GITHUB_OUTPUT + echo "total_cases=${TOTAL_CASES}" >> $GITHUB_OUTPUT + + echo "=== Shard configuration ===" + echo "Distributed tests: ${DISTRIBUTED_SHARDS} shards (case-level, serial execution, linux-aarch64-a3-16)" + echo "Regular tests: ${REGULAR_SHARDS} shards (case-level, 64 workers, linux-aarch64-a3-16)" + echo "Total cases: ${TOTAL_CASES}" + + # Package error logs if any + if [ -d "collection_errors" ] && [ "$(ls -A collection_errors 2>/dev/null)" ]; then + echo "=== Packaging collection error logs ===" + tar -czf collection_errors.tar.gz collection_errors/ + echo "Error logs packaged: collection_errors.tar.gz" + ls -la collection_errors.tar.gz + fi + + - name: Upload cases shard JSONs + uses: actions/upload-artifact@v4 + with: + name: cases-shards + path: pytorch-src/cases_shards/ + retention-days: 7 + + - name: Upload collection error logs + if: always() + uses: actions/upload-artifact@v4 + with: + name: collection-error-logs + path: pytorch-src/collection_errors.tar.gz + if-no-files-found: ignore + retention-days: 30 + + - name: Upload collect logs + if: always() + uses: 
actions/upload-artifact@v4 + with: + name: collect-cases-logs + path: /tmp/collect_cases.log + if-no-files-found: warn + retention-days: 30 \ No newline at end of file diff --git a/.github/workflows/_torch-npu-upstream-report.yml b/.github/workflows/_torch-npu-upstream-report.yml new file mode 100644 index 0000000000..d75248e5dc --- /dev/null +++ b/.github/workflows/_torch-npu-upstream-report.yml @@ -0,0 +1,111 @@ +name: Torch NPU Upstream Report + +on: + workflow_call: + inputs: + python_version: + required: true + type: string + description: Python version to use + pytorch_version: + required: true + type: string + description: PyTorch version string + torch_npu_wheel_name: + required: false + type: string + default: 'source-build.whl' + description: Name of the torch_npu wheel file + docker_image: + required: true + type: string + description: Docker image used for tests + distributed_matrix: + required: false + type: string + default: '[]' + description: Distributed shard matrix JSON + regular_matrix: + required: false + type: string + default: '[]' + description: Regular shard matrix JSON + +jobs: + generate_report: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + repository: ${{ github.repository }} + ref: ${{ github.ref }} + fetch-depth: 1 + + - name: Setup Python ${{ inputs.python_version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python_version }} + + - name: Download distributed shard reports + uses: actions/download-artifact@v4 + with: + pattern: test-reports-dist-* + path: all-test-reports + merge-multiple: true + + - name: Download regular shard reports + uses: actions/download-artifact@v4 + with: + pattern: test-reports-reg-* + path: all-test-reports + merge-multiple: true + + - name: Download custom test reports + uses: actions/download-artifact@v4 + with: + name: test-reports-custom + path: all-test-reports + merge-multiple: true + continue-on-error: true + + - name: 
Download cases collection summary + uses: actions/download-artifact@v4 + with: + name: cases-shards + path: cases-shards + continue-on-error: true + + - name: Generate consolidated summary + run: | + PYTHON=python + REPORT_MD=npu-full-test-summary.md + REPORT_JSON=npu-full-test-summary.json + + # Combine both shard matrices for reporting (JSON is passed via argv so quotes in it cannot break the Python source) + DIST_MATRIX='${{ inputs.distributed_matrix }}' + REG_MATRIX='${{ inputs.regular_matrix }}' + COMBINED_MATRIX=$(python3 -c "import sys,json; dist=json.loads(sys.argv[1]); reg=json.loads(sys.argv[2]); print(json.dumps(['dist-'+str(s) for s in dist]+['reg-'+str(s) for s in reg]))" "${DIST_MATRIX}" "${REG_MATRIX}") + + $PYTHON .github/scripts/generate_npu_full_test_report.py \ + --reports-root all-test-reports \ + --output-markdown ${REPORT_MD} \ + --output-json ${REPORT_JSON} \ + --pytorch-version "${{ inputs.pytorch_version }}" \ + --torch-npu-whl "${{ inputs.torch_npu_wheel_name }}" \ + --shard-matrix-json "${COMBINED_MATRIX}" \ + --docker-image "${{ inputs.docker_image }}" \ + --runner "linux-aarch64-a3-16 (distributed, serial), linux-aarch64-a3-16 (regular, 16 workers), linux-aarch64-a3-16 (custom)" \ + --cases-summary cases-shards/cases_collection_summary.json + + cat ${REPORT_MD} >> $GITHUB_STEP_SUMMARY + + - name: Upload consolidated summary + if: always() + uses: actions/upload-artifact@v4 + with: + name: npu-full-test-summary + path: | + npu-full-test-summary.md + npu-full-test-summary.json + retention-days: 30 \ No newline at end of file diff --git a/.github/workflows/_torch-npu-upstream-test-custom.yml b/.github/workflows/_torch-npu-upstream-test-custom.yml new file mode 100644 index 0000000000..027eea0fb9 --- /dev/null +++ b/.github/workflows/_torch-npu-upstream-test-custom.yml @@ -0,0 +1,114 @@ +name: Torch NPU Upstream Test Custom + +on: + workflow_call: + inputs: + python_version: + required: true + type: string + description: Python version to use + torch_wheel_artifact: + required: true + type: string + description: Name of the torch wheel
artifact from build + torch_npu_wheel_artifact: + required: true + type: string + description: Name of the torch_npu wheel artifact from build + pytorch_src_artifact: + required: true + type: string + description: Name of the pytorch source artifact from build + docker_image: + required: true + type: string + description: Docker image to use + test_files: + required: true + type: string + description: Test files to run (comma-separated) + +jobs: + run_tests: + name: test_custom + runs-on: linux-aarch64-a3-16 + timeout-minutes: 1200 + container: + image: ${{ inputs.docker_image }} + options: --user root + + steps: + - name: Setup NPU test environment + uses: kerer-ai/pytorch/.github/actions/setup-npu-test-env@dev_master + with: + python_version: ${{ inputs.python_version }} + torch_wheel_artifact: ${{ inputs.torch_wheel_artifact }} + torch_npu_wheel_artifact: ${{ inputs.torch_npu_wheel_artifact }} + pytorch_src_artifact: ${{ inputs.pytorch_src_artifact }} + + - name: Run custom test files + id: run_tests + run: | + source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true + source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true + + REPORT_DIR=test-reports + mkdir -p ${REPORT_DIR} + + # Custom test files: per-case isolation execution + python${{ inputs.python_version }} ascend_pytorch/.github/scripts/run_npu_test_shard.py \ + --test-files "${{ inputs.test_files }}" \ + --test-dir pytorch-src/test \ + --report-dir ${REPORT_DIR} \ + --timeout 600 \ + --verbose \ + 2>&1 | tee /tmp/test_custom.log + + TEST_STATUS=${PIPESTATUS[0]} + echo "status=${TEST_STATUS}" >> $GITHUB_OUTPUT + # Don't exit with test status - let step succeed to allow report generation + + - name: Package and upload test reports + if: always() + run: | + # Package junit XMLs into compressed archive + if [ -d "test-reports/junit_xmls" ]; then + echo "=== Compressing junit XMLs ===" + XML_COUNT=$(find test-reports/junit_xmls -type f -name "*.xml" | wc -l) + echo "Found ${XML_COUNT} XML 
files" + tar -czf test-reports/junit_xmls.tar.gz -C test-reports junit_xmls + rm -rf test-reports/junit_xmls + echo "JUnit XMLs compressed" + fi + + # Package failed cases logs into compressed archive + if [ -d "test-reports/failed_cases_logs" ]; then + echo "=== Compressing failed cases logs ===" + tar -czf test-reports/failed_cases_logs.tar.gz -C test-reports failed_cases_logs + rm -rf test-reports/failed_cases_logs + echo "Failed cases logs compressed" + fi + + - name: Upload test reports + if: always() + uses: actions/upload-artifact@v4 + with: + name: test-reports-custom + path: test-reports/ + retention-days: 30 + + - name: Compress and upload error logs + if: failure() + run: | + mkdir -p error-logs + cp /tmp/test_custom.log error-logs/ 2>/dev/null || true + tar -czf error-logs-custom.tar.gz error-logs/ + echo "Error logs compressed" + + - name: Upload error logs + if: failure() + uses: actions/upload-artifact@v4 + with: + name: error-logs-custom + path: error-logs-custom.tar.gz + retention-days: 30 \ No newline at end of file diff --git a/.github/workflows/_torch-npu-upstream-test-dist.yml b/.github/workflows/_torch-npu-upstream-test-dist.yml new file mode 100644 index 0000000000..77eb07d0ec --- /dev/null +++ b/.github/workflows/_torch-npu-upstream-test-dist.yml @@ -0,0 +1,132 @@ +name: Torch NPU Upstream Test Distributed + +on: + workflow_call: + inputs: + python_version: + required: true + type: string + description: Python version to use + torch_wheel_artifact: + required: true + type: string + description: Name of the torch wheel artifact from build + torch_npu_wheel_artifact: + required: true + type: string + description: Name of the torch_npu wheel artifact from build + pytorch_src_artifact: + required: true + type: string + description: Name of the pytorch source artifact from build + docker_image: + required: true + type: string + description: Docker image to use + distributed_matrix: + required: true + type: string + description: Distributed shard 
matrix JSON + distributed_shards: + required: true + type: string + description: Number of distributed shards + +jobs: + run_tests: + name: test_distributed (${{ matrix.shard }}/${{ inputs.distributed_shards }}) + runs-on: linux-aarch64-a3-16 + timeout-minutes: 1200 + container: + image: ${{ inputs.docker_image }} + options: --user root + strategy: + matrix: + shard: ${{ fromJson(inputs.distributed_matrix) }} + fail-fast: false + max-parallel: 2 + + steps: + - name: Setup NPU test environment + uses: kerer-ai/pytorch/.github/actions/setup-npu-test-env@dev_master + with: + python_version: ${{ inputs.python_version }} + torch_wheel_artifact: ${{ inputs.torch_wheel_artifact }} + torch_npu_wheel_artifact: ${{ inputs.torch_npu_wheel_artifact }} + pytorch_src_artifact: ${{ inputs.pytorch_src_artifact }} + + - name: Download cases shard JSONs + uses: actions/download-artifact@v4 + with: + name: cases-shards + path: cases-shards + + - name: Run distributed shard ${{ matrix.shard }}/${{ inputs.distributed_shards }} + id: run_test + run: | + source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true + source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true + + PYTHON=python${{ inputs.python_version }} + REPORT_DIR=test-reports + CASES_JSON="cases-shards/distributed_cases_shard_${{ matrix.shard }}.json" + + mkdir -p ${REPORT_DIR} + + # Get case count from JSON + TOTAL_CASES=$(python3 -c "import json; d=json.load(open('${CASES_JSON}')); print(d['total_cases'])") + + echo "=== Distributed Shard ${{ matrix.shard }} (Case-level) ===" + echo "Total cases: ${TOTAL_CASES}" + echo "Runner: linux-aarch64-a3-16 (16-card NPU)" + echo "Execution mode: SERIAL" + + # Distributed tests: pre-collected cases, serial execution + set +e + $PYTHON ascend_pytorch/.github/scripts/run_npu_test_shard.py \ + --cases-json "${CASES_JSON}" \ + --test-dir pytorch-src/test \ + --report-dir ${REPORT_DIR} \ + --quick-test 100 \ + --timeout 600 \ + --verbose \ + 2>&1 | tee /tmp/test_shard_dist_${{ 
matrix.shard }}.log + + TEST_STATUS=${PIPESTATUS[0]} + set -e + echo "status=${TEST_STATUS}" >> $GITHUB_OUTPUT + # Don't exit with test status - let step succeed to allow report generation + + - name: Package and upload test reports + if: always() + run: | + # Package junit XMLs into compressed archive + if [ -d "test-reports/junit_xmls" ]; then + echo "=== Compressing junit XMLs ===" + XML_COUNT=$(find test-reports/junit_xmls -type f -name "*.xml" | wc -l) + echo "Found ${XML_COUNT} XML files" + tar -czf test-reports/junit_xmls.tar.gz -C test-reports junit_xmls + rm -rf test-reports/junit_xmls + echo "JUnit XMLs compressed: $(ls -lh test-reports/junit_xmls.tar.gz)" + fi + + # Package failed cases logs into compressed archive + if [ -d "test-reports/failed_cases_logs" ]; then + echo "=== Compressing failed cases logs ===" + tar -czf test-reports/failed_cases_logs.tar.gz -C test-reports failed_cases_logs + rm -rf test-reports/failed_cases_logs + echo "Failed cases logs compressed: $(ls -lh test-reports/failed_cases_logs.tar.gz)" + fi + + # Package shard_cases.json + if [ -f "test-reports/shard_dist-${{ matrix.shard }}_cases.json" ]; then + echo "Cases JSON exists: $(ls -lh test-reports/shard_dist-${{ matrix.shard }}_cases.json)" + fi + + - name: Upload test reports + if: always() + uses: actions/upload-artifact@v4 + with: + name: test-reports-dist-${{ matrix.shard }} + path: test-reports/ + retention-days: 30 \ No newline at end of file diff --git a/.github/workflows/_torch-npu-upstream-test-regular.yml b/.github/workflows/_torch-npu-upstream-test-regular.yml new file mode 100644 index 0000000000..c9c61128c4 --- /dev/null +++ b/.github/workflows/_torch-npu-upstream-test-regular.yml @@ -0,0 +1,135 @@ +name: Torch NPU Upstream Test Regular + +on: + workflow_call: + inputs: + python_version: + required: true + type: string + description: Python version to use + torch_wheel_artifact: + required: true + type: string + description: Name of the torch wheel artifact from 
build + torch_npu_wheel_artifact: + required: true + type: string + description: Name of the torch_npu wheel artifact from build + pytorch_src_artifact: + required: true + type: string + description: Name of the pytorch source artifact from build + docker_image: + required: true + type: string + description: Docker image to use + regular_matrix: + required: true + type: string + description: Regular shard matrix JSON + regular_shards: + required: true + type: string + description: Number of regular shards + +jobs: + run_tests: + name: test_regular (${{ matrix.shard }}/${{ inputs.regular_shards }}) + runs-on: linux-aarch64-a3-16 + timeout-minutes: 1200 + container: + image: ${{ inputs.docker_image }} + options: --user root + strategy: + matrix: + shard: ${{ fromJson(inputs.regular_matrix) }} + fail-fast: false + max-parallel: 5 + + steps: + - name: Setup NPU test environment + uses: kerer-ai/pytorch/.github/actions/setup-npu-test-env@dev_master + with: + python_version: ${{ inputs.python_version }} + torch_wheel_artifact: ${{ inputs.torch_wheel_artifact }} + torch_npu_wheel_artifact: ${{ inputs.torch_npu_wheel_artifact }} + pytorch_src_artifact: ${{ inputs.pytorch_src_artifact }} + + - name: Download cases shard JSONs + uses: actions/download-artifact@v4 + with: + name: cases-shards + path: cases-shards + + - name: Run regular shard ${{ matrix.shard }}/${{ inputs.regular_shards }} + id: run_test + run: | + source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true + source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true + + PYTHON=python${{ inputs.python_version }} + REPORT_DIR=test-reports + CASES_JSON="cases-shards/regular_cases_shard_${{ matrix.shard }}.json" + + mkdir -p ${REPORT_DIR} + + # Get case count from JSON + TOTAL_CASES=$(python3 -c "import json; d=json.load(open('${CASES_JSON}')); print(d['total_cases'])") + + echo "=== Regular Shard ${{ matrix.shard }} (Case-level) ===" + echo "Total cases: ${TOTAL_CASES}" + echo "Runner: 
linux-aarch64-a3-16 (16-card NPU)" + echo "Execution mode: CONCURRENT (64 workers)" + + # Regular tests: pre-collected cases, 64 concurrent workers + set +e + $PYTHON ascend_pytorch/.github/scripts/run_npu_test_shard.py \ + --cases-json "${CASES_JSON}" \ + --test-dir pytorch-src/test \ + --report-dir ${REPORT_DIR} \ + --timeout 600 \ + --quick-test 100 \ + --max-workers 16 \ + --verbose \ + 2>&1 | tee /tmp/test_shard_reg_${{ matrix.shard }}.log + + TEST_STATUS=${PIPESTATUS[0]} + set -e + echo "status=${TEST_STATUS}" >> $GITHUB_OUTPUT + # Don't exit with test status - let step succeed to allow report generation + + - name: Package and upload test reports + if: always() + run: | + # Package junit XMLs into compressed archive + if [ -d "test-reports/junit_xmls" ]; then + echo "=== Compressing junit XMLs ===" + XML_COUNT=$(find test-reports/junit_xmls -type f -name "*.xml" | wc -l) + echo "Found ${XML_COUNT} XML files" + tar -czf test-reports/junit_xmls.tar.gz -C test-reports junit_xmls + rm -rf test-reports/junit_xmls + echo "JUnit XMLs compressed: $(ls -lh test-reports/junit_xmls.tar.gz)" + fi + + # Package failed cases logs into compressed archive + if [ -d "test-reports/failed_cases_logs" ]; then + echo "=== Compressing failed cases logs ===" + FAILED_LOGS_COUNT=$(find test-reports/failed_cases_logs -type f | wc -l) + echo "Found ${FAILED_LOGS_COUNT} failed case log files" + tar -czf test-reports/failed_cases_logs.tar.gz -C test-reports failed_cases_logs + rm -rf test-reports/failed_cases_logs + echo "Failed cases logs compressed: $(ls -lh test-reports/failed_cases_logs.tar.gz)" + fi + + # Package shard_cases.json + if [ -f "test-reports/shard_reg-${{ matrix.shard }}_cases.json" ]; then + echo "Cases JSON exists: $(ls -lh test-reports/shard_reg-${{ matrix.shard }}_cases.json)" + fi + + - name: Upload test reports + if: always() + uses: actions/upload-artifact@v4 + with: + name: test-reports-reg-${{ matrix.shard }} + path: test-reports/ + retention-days: 30 \ No 
newline at end of file diff --git a/.github/workflows/_torch-npu-upstream-test.yml b/.github/workflows/_torch-npu-upstream-test.yml new file mode 100644 index 0000000000..3b7ac687ce --- /dev/null +++ b/.github/workflows/_torch-npu-upstream-test.yml @@ -0,0 +1,144 @@ +name: Torch NPU Upstream Test + +on: + workflow_call: + inputs: + python_version: + required: true + type: string + description: Python version to use + pytorch_ref: + required: false + type: string + default: 'fccc94ae83f61fe26559abc999797297196bac29' + description: PyTorch branch, tag, or commit SHA to build + torch_npu_ref: + required: false + type: string + default: 'master' + description: torch_npu branch, tag, or commit SHA to build + docker_image: + required: false + type: string + default: 'quay.io/kerer/pytorch:manylinux-cann9.0.0-beta.2-20260428' + description: Docker image to use for all jobs + distributed_shards: + required: false + type: string + default: '2' + description: Number of shards for distributed tests + regular_shards: + required: false + type: string + default: '5' + description: Number of shards for regular tests + test_files: + required: false + type: string + default: '' + description: Test files to run directly (comma-separated) + +defaults: + run: + shell: bash + +jobs: + # ============================================================================ + # 1. Build PyTorch and torch_npu Wheels + # ============================================================================ + build: + uses: ./.github/workflows/_torch-npu-upstream-build.yml + with: + pytorch_ref: ${{ inputs.pytorch_ref }} + torch_npu_ref: ${{ inputs.torch_npu_ref }} + python_version: ${{ inputs.python_version }} + docker_image: ${{ inputs.docker_image }} + + # ============================================================================ + # 2. 
Collect Test Cases (only when test_files is empty) + # ============================================================================ + collect_cases: + needs: + - build + if: ${{ inputs.test_files == '' }} + uses: ./.github/workflows/_torch-npu-upstream-collect.yml + with: + python_version: ${{ inputs.python_version }} + torch_wheel_artifact: ${{ needs.build.outputs.torch-wheel }} + torch_npu_wheel_artifact: ${{ needs.build.outputs.torch-npu-wheel }} + pytorch_src_artifact: ${{ needs.build.outputs.pytorch-src }} + docker_image: ${{ inputs.docker_image }} + distributed_shards: ${{ inputs.distributed_shards }} + regular_shards: ${{ inputs.regular_shards }} + + # ============================================================================ + # 3. Run Distributed Tests (only when test_files is empty) + # ============================================================================ + test_distributed: + needs: + - build + - collect_cases + if: ${{ inputs.test_files == '' }} + uses: ./.github/workflows/_torch-npu-upstream-test-dist.yml + with: + python_version: ${{ inputs.python_version }} + torch_wheel_artifact: ${{ needs.build.outputs.torch-wheel }} + torch_npu_wheel_artifact: ${{ needs.build.outputs.torch-npu-wheel }} + pytorch_src_artifact: ${{ needs.build.outputs.pytorch-src }} + docker_image: ${{ inputs.docker_image }} + distributed_matrix: ${{ needs.collect_cases.outputs.distributed_matrix }} + distributed_shards: ${{ needs.collect_cases.outputs.distributed_shards }} + + # ============================================================================ + # 4. 
Run Regular Tests (only when test_files is empty) + # ============================================================================ + test_regular: + needs: + - build + - collect_cases + if: ${{ inputs.test_files == '' }} + uses: ./.github/workflows/_torch-npu-upstream-test-regular.yml + with: + python_version: ${{ inputs.python_version }} + torch_wheel_artifact: ${{ needs.build.outputs.torch-wheel }} + torch_npu_wheel_artifact: ${{ needs.build.outputs.torch-npu-wheel }} + pytorch_src_artifact: ${{ needs.build.outputs.pytorch-src }} + docker_image: ${{ inputs.docker_image }} + regular_matrix: ${{ needs.collect_cases.outputs.regular_matrix }} + regular_shards: ${{ needs.collect_cases.outputs.regular_shards }} + + # ============================================================================ + # 5. Run Custom Tests (only when test_files is provided) + # ============================================================================ + test_custom: + needs: + - build + if: ${{ inputs.test_files != '' }} + uses: ./.github/workflows/_torch-npu-upstream-test-custom.yml + with: + python_version: ${{ inputs.python_version }} + torch_wheel_artifact: ${{ needs.build.outputs.torch-wheel }} + torch_npu_wheel_artifact: ${{ needs.build.outputs.torch-npu-wheel }} + pytorch_src_artifact: ${{ needs.build.outputs.pytorch-src }} + docker_image: ${{ inputs.docker_image }} + test_files: ${{ inputs.test_files }} + + # ============================================================================ + # 6. 
Generate Test Report + # ============================================================================ + report: + needs: + - build + - collect_cases + - test_distributed + - test_regular + - test_custom + if: always() && needs.build.result == 'success' + uses: ./.github/workflows/_torch-npu-upstream-report.yml + with: + python_version: ${{ inputs.python_version }} + pytorch_version: ${{ needs.build.outputs.pytorch-version }} + torch_npu_wheel_name: ${{ needs.build.outputs.torch-npu-wheel }} + docker_image: ${{ inputs.docker_image }} + distributed_matrix: ${{ needs.collect_cases.outputs.distributed_matrix || '[]' }} + regular_matrix: ${{ needs.collect_cases.outputs.regular_matrix || '[]' }} \ No newline at end of file diff --git a/.github/workflows/build-docker-image.yml b/.github/workflows/build-docker-image.yml new file mode 100644 index 0000000000..7df92cd072 --- /dev/null +++ b/.github/workflows/build-docker-image.yml @@ -0,0 +1,169 @@ +name: Build Docker Image + +on: + push: + branches: [dev_master] + paths: + - '.github/docker/pytorch-npu-builder.Dockerfile' + - '.github/workflows/build-docker-image.yml' + - '.github/scripts/build_image.sh' + schedule: + - cron: '0 2 * * 0' # UTC 02:00, Beijing 10:00, every Sunday + workflow_dispatch: + inputs: + cann_version: + description: 'CANN version (e.g., 9.0, 9.0.0-beta.2, 8.0)' + required: true + default: '9.0' + type: string + push_image: + description: 'Push image to registry' + required: true + default: true + type: boolean + force_build: + description: 'Force rebuild even if image exists' + required: false + default: false + type: boolean + +env: + REGISTRY: quay.io + QUAY_ORG: kerer + IMAGE_NAME: pytorch + CANN_STABLE: '9.0' + +jobs: + build: + runs-on: ubuntu-22.04-arm + environment: QUAY_USERNAME + permissions: + contents: read + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Make script executable + run: chmod +x .github/scripts/build_image.sh + + - name: Determine build 
parameters + id: params + run: | + # 确定是否推送镜像 + # 规则:手动触发根据 inputs.push_image 决定 + # push/schedule 触发默认推送(除非是 PR) + if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then + PUSH_IMAGE="${{ inputs.push_image }}" + elif [[ "${{ github.event_name }}" == "push" || "${{ github.event_name }}" == "schedule" ]]; then + PUSH_IMAGE="true" + else + PUSH_IMAGE="false" + fi + + echo "push_image=${PUSH_IMAGE}" >> $GITHUB_OUTPUT + + # 确定是否强制构建 + FORCE_BUILD="${{ inputs.force_build || 'false' }}" + echo "force_build=${FORCE_BUILD}" >> $GITHUB_OUTPUT + + # 确定 CANN 版本 + CANN_VERSION="${{ inputs.cann_version || env.CANN_STABLE }}" + echo "cann_version=${CANN_VERSION}" >> $GITHUB_OUTPUT + + echo "Build parameters:" + echo " Event: ${{ github.event_name }}" + echo " CANN version: ${CANN_VERSION}" + echo " Push image: ${PUSH_IMAGE}" + echo " Force build: ${FORCE_BUILD}" + + - name: Setup Docker Buildx + uses: docker/setup-buildx-action@v3 + with: + driver: docker-container + driver-opts: image=moby/buildkit:latest + + - name: Login to Quay.io + if: ${{ steps.params.outputs.push_image == 'true' }} + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ secrets.QUAY_USERNAME }} + password: ${{ secrets.QUAY_PASSWORD }} + + - name: Build and push image + env: + QUAY_USERNAME: ${{ secrets.QUAY_USERNAME }} + QUAY_PASSWORD: ${{ secrets.QUAY_PASSWORD }} + SKIP_DOCKER_LOGIN: true # 已通过 login-action 登录,跳过脚本中的登录 + run: | + # 构建参数 + ARGS="" + if [[ "${{ steps.params.outputs.push_image }}" == "true" ]]; then + ARGS+=" --push" + fi + if [[ "${{ steps.params.outputs.force_build }}" == "true" ]]; then + ARGS+=" --force" + fi + + .github/scripts/build_image.sh \ + --cann-version "${{ steps.params.outputs.cann_version }}" \ + --registry "${{ env.REGISTRY }}" \ + --quay-org "${{ env.QUAY_ORG }}" \ + --image-name "${{ env.IMAGE_NAME }}" \ + ${ARGS} \ + --verbose + + - name: Summary + if: always() + run: | + echo "## Docker Image Build Summary" >> 
$GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "### Build Details" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "| Property | Value |" >> $GITHUB_STEP_SUMMARY + echo "|----------|-------|" >> $GITHUB_STEP_SUMMARY + echo "| **Trigger** | ${{ github.event_name }} |" >> $GITHUB_STEP_SUMMARY + echo "| **CANN Version** | ${{ steps.params.outputs.cann_version }} |" >> $GITHUB_STEP_SUMMARY + echo "| **Python Versions** | 3.10, 3.11, 3.12, 3.13 (pre-installed) |" >> $GITHUB_STEP_SUMMARY + echo "| **Registry** | ${{ env.REGISTRY }}/${{ env.QUAY_ORG }}/${{ env.IMAGE_NAME }} |" >> $GITHUB_STEP_SUMMARY + echo "| **Push Enabled** | ${{ steps.params.outputs.push_image }} |" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + + # 提取 CANN 大版本号(正确处理简化版本) + CANN_INPUT="${{ steps.params.outputs.cann_version }}" + + # 如果是简化版本(如 9.0),直接使用 + # 如果是完整版本(如 9.0.0-beta.2),提取前两位数字 + if [[ "$CANN_INPUT" =~ ^[0-9]+\.[0-9]+$ ]]; then + CANN_MAJOR="$CANN_INPUT" + else + CANN_MAJOR=$(echo "$CANN_INPUT" | grep -oP '^[0-9]+\.[0-9]+') + fi + + echo "### Image Tags" >> $GITHUB_STEP_SUMMARY + echo "\`\`\`" >> $GITHUB_STEP_SUMMARY + echo "quay.io/kerer/pytorch:cann${{ steps.params.outputs.cann_version }}" >> $GITHUB_STEP_SUMMARY + echo "quay.io/kerer/pytorch:cann${CANN_MAJOR}" >> $GITHUB_STEP_SUMMARY + if [[ "${CANN_MAJOR}" == "${{ env.CANN_STABLE }}" ]]; then + echo "quay.io/kerer/pytorch:latest" >> $GITHUB_STEP_SUMMARY + echo "quay.io/kerer/pytorch:cann-latest" >> $GITHUB_STEP_SUMMARY + fi + echo "\`\`\`" >> $GITHUB_STEP_SUMMARY + + echo "" >> $GITHUB_STEP_SUMMARY + echo "### Python Version Switch" >> $GITHUB_STEP_SUMMARY + echo "\`\`\`bash" >> $GITHUB_STEP_SUMMARY + echo "# Inside container:" >> $GITHUB_STEP_SUMMARY + echo "source /usr/local/bin/switch_python.sh 3.11" >> $GITHUB_STEP_SUMMARY + echo "source /usr/local/bin/switch_python.sh 3.12" >> $GITHUB_STEP_SUMMARY + echo "\`\`\`" >> $GITHUB_STEP_SUMMARY + + echo "" >> $GITHUB_STEP_SUMMARY + echo "### 
Build time: $(date -u +'%Y-%m-%d %H:%M:%S UTC')" >> $GITHUB_STEP_SUMMARY + + - name: Cleanup on failure + if: failure() + run: | + echo "::error::Build failed for CANN version ${{ steps.params.outputs.cann_version }}" + echo "Check the build logs for details" \ No newline at end of file diff --git a/.github/workflows/torch-npu-upstream-test-trigger.yml b/.github/workflows/torch-npu-upstream-test-trigger.yml new file mode 100644 index 0000000000..203864cb3e --- /dev/null +++ b/.github/workflows/torch-npu-upstream-test-trigger.yml @@ -0,0 +1,70 @@ +name: Torch NPU Upstream Main Test Trigger + +on: + push: + branches: + - main + - master + - 'release/**' + paths: + - '.github/workflows/torch-npu-upstream-test-trigger.yml' + - '.github/workflows/_torch-npu-upstream*.yml' + - '.github/scripts/*.py' + - 'test_upstream/**' + pull_request: + paths: + - '.github/workflows/torch-npu-upstream-test-trigger.yml' + - '.github/workflows/_torch-npu-upstream*.yml' + - '.github/scripts/*.py' + - 'test_upstream/**' + schedule: + - cron: '0 22 * * 1' # UTC 22:00 (Beijing time next day 06:00), every Monday + workflow_dispatch: + inputs: + python_version: + description: 'Python version (default 3.11)' + required: false + default: '3.11' + type: string + pytorch_ref: + description: 'PyTorch branch, tag, or commit SHA to build (default fccc94ae83f61fe26559abc999797297196bac29)' + required: false + default: 'fccc94ae83f61fe26559abc999797297196bac29' + type: string + torch_npu_ref: + description: 'torch_npu branch, tag, or commit SHA to build (default master)' + required: false + default: 'master' + type: string + docker_image: + description: 'Docker image to use for all jobs (default quay.io/kerer/pytorch:manylinux-cann9.0.0-beta.2-20260428)' + required: false + default: 'quay.io/kerer/pytorch:manylinux-cann9.0.0-beta.2-20260428' + type: string + distributed_shards: + description: 'Number of shards for distributed tests (default 3)' + required: false + default: '3' + type: string + 
regular_shards: + description: 'Number of shards for regular tests (default 5)' + required: false + default: '5' + type: string + test_files: + description: 'Test files to run directly (comma-separated, e.g., "test_meta.py,test_nn.py"). Skip shard assignment if set.' + required: false + default: '' + type: string + +jobs: + trigger_test: + uses: ./.github/workflows/_torch-npu-upstream-test.yml + with: + python_version: ${{ github.event.inputs.python_version || '3.11' }} + pytorch_ref: ${{ github.event.inputs.pytorch_ref || 'fccc94ae83f61fe26559abc999797297196bac29' }} + torch_npu_ref: ${{ github.event.inputs.torch_npu_ref || 'master' }} + docker_image: ${{ github.event.inputs.docker_image || 'quay.io/kerer/pytorch:manylinux-cann9.0.0-beta.2-20260428' }} + distributed_shards: ${{ github.event.inputs.distributed_shards || '3' }} + regular_shards: ${{ github.event.inputs.regular_shards || '5' }} + test_files: ${{ github.event.inputs.test_files || '' }} \ No newline at end of file diff --git a/test_upstream/case_paths_ci.yml b/test_upstream/case_paths_ci.yml new file mode 100644 index 0000000000..f7e50e31d6 --- /dev/null +++ b/test_upstream/case_paths_ci.yml @@ -0,0 +1,161 @@ +# Test file blacklist configuration for NPU CI +# +# This file defines test files/directories to exclude from collection. +# Rules support: +# - Exact path match: "distributed/test_c10d_nccl" +# - Directory prefix match: "dynamo/cpython/3_13" (matches all files in that directory) +# - Glob patterns: "distributed/*nccl*" +# +# Rule paths are relative to the test directory (e.g., "test/distributed/...") +# The "test/" prefix is automatically added if missing. 
+ +blacklist: + # ============================================================================== + # Python version specific tests (Python 3.13 syntax, incompatible with 3.11) + # ============================================================================== + - dynamo/cpython/3_13 + + # ============================================================================== + # Platform-specific tests (CUDA/XPU/MPS - not supported on NPU) + # ============================================================================== + # CUDA specific tests + - test_cuda_multigpu + - test_cuda_nvml_based_avail + - test_cuda_primary_ctx + - test_cuda_sanitizer + - test_cuda_trace + - test_jiterator + - test_varlen_attention + + # CUDA inductor tests + - inductor/test_cuda_repro + - inductor/test_cudagraph_trees + - inductor/test_cudagraph_trees_expandable_segments + - inductor/test_gpu_select_algorithm + - inductor/test_triton_cpu_backend + - inductor/test_triton_heuristics + - inductor/test_pallas + - inductor/test_layout_optim + - inductor/test_op_dtype_prop + - inductor/test_autoheuristic + - inductor/test_b2b_gemm + - inductor/test_best_config + - inductor/test_coordinate_descent_tuner + + # JIT CUDA tests + - jit/test_cuda + + # Flash Attention tests (CUDA only) + - nn/attention/test_fa3 + - nn/attention/test_fa4 + + # XPU (Intel) tests + - test_xpu + - test_xpu_expandable_segments + - xpu/test_conv + - xpu/test_fusion + - xpu/test_gemm + + # MPS (Apple Metal) tests + - test_mps + + # ============================================================================== + # ONNX tests with heavy/optional dependencies + # ============================================================================== + # HuggingFace transformers tests + # Issue: Test requires transformers package (HuggingFace) with no graceful import handling + # Problem: Test file directly imports `import transformers` without try-import or skipif + # - transformers package: ~10 MB + dependencies (tokenizers, safetensors, 
huggingface-hub) + # - Test downloads model: hf-internal-testing/tiny-random-gptj (small, but needs network) + # - PyTorch CI excludes all ONNX tests by default (only runs with --onnx flag) + # Impact: ModuleNotFoundError: No module named 'transformers' + # Note: PyTorch run_test.py: options.exclude.extend(onnx_tests) in default behavior + - onnx/exporter/test_hf_models_e2e + + # ============================================================================== + # Tests requiring specific compiled libraries not available in NPU build + # ============================================================================== + - custom_operator/test_custom_ops # requires libcustom_ops.so + + # ============================================================================== + # Custom device extension tests (require separate compilation) + # ============================================================================== + # torch_openreg tests + # Issue: torch_openreg is a separate extension that must be compiled AFTER PyTorch is installed + # Problem: PyTorch run_test.py calls install_cpp_extensions() before running these tests + # - Requires CMake build linking against installed PyTorch + # - setup.py shows: "-DPYTORCH_INSTALL_DIR=" + get_pytorch_dir() + # - By default, PyTorch CI EXCLUDES test_openreg (only runs with --openreg flag) + # Impact: ModuleNotFoundError: No module named 'torch_openreg' + # See: run_test.py - options.exclude.append("test_openreg") in default behavior + - cpp_extensions/open_registration_extension/torch_openreg/tests/test_memory + + # ============================================================================== + # Architecture-specific tests (x86_64 only, no ARM64/aarch64 support) + # ============================================================================== + # TorchRec tests + # Issue: torchrec depends on fbgemm-gpu, which only has x86_64 wheels on PyPI + # Problem: fbgemm-gpu provides pre-built wheels for x86_64 only: + # - 
fbgemm_gpu-1.6.0-cp311-cp311-manylinux_2_28_x86_64.whl + # - No aarch64/ARM64 wheels available + # Impact: Cannot install fbgemm-gpu on ARM64 NPU runner (linux-aarch64-a3-16) + # Test uses graceful handling (NoTest fallback), but we blacklist for clarity + - dynamo/test_torchrec + + # ============================================================================== + # PyTorch upstream design issues (missing __init__.py or internal API changes) + # ============================================================================== + # Quantization experimental tests + # Issue: torch/ao/quantization/experimental/ directory exists but has NO __init__.py + # Python cannot recognize it as a package, causing ModuleNotFoundError on import + # See: https://github.com/pytorch/pytorch/tree/main/torch/ao/quantization/experimental + # Files exist: adaround_optimization.py, linear.py, observer.py, etc. + # Missing: __init__.py (required for package import) + # Impact: 6 test files fail with "No module named 'torch.ao.quantization.experimental'" + - quantization/core/experimental/test_adaround_eager + - quantization/core/experimental/test_fake_quantize + - quantization/core/experimental/test_linear + - quantization/core/experimental/test_nonuniform_observer + - quantization/core/experimental/test_quantized_tensor + - quantization/core/experimental/test_quantizer + + # numpy internal private module dependency + # Issue: Test imports private function from numpy's internal module structure + # Code: `from numpy.linalg.linalg import _multi_dot_matrix_chain_order` + # Problem: numpy 2.0+ restructured internal modules, `numpy.linalg.linalg` no longer exists + # - Private APIs (_ prefix) have no stability guarantee, can be moved/removed anytime + # - numpy 2.0 moved _multi_dot_matrix_chain_order to numpy.linalg._linalg or removed it + # - PyTorch test code uses private API instead of stable public API + # Impact: ModuleNotFoundError: No module named 'numpy.linalg.linalg' + - 
torch_np/numpy_tests/linalg/test_linalg + + # PyTorch test directory structure issue (test/ has no __init__.py) + # Issue: Test imports using package path `from test.jit.fixtures_srcs.generate_models import ...` + # Problem: PyTorch test/ directory is NOT a Python package (missing __init__.py files) + # - test/__init__.py does not exist + # - test/jit/__init__.py does not exist + # - Without __init__.py, Python cannot recognize `test.jit` as a valid package path + # - PyTorch CI works via special PYTHONPATH setup that our collection script doesn't replicate + # Impact: ModuleNotFoundError: No module named 'test.jit' + - jit/fixtures_srcs/test_upgrader_models_generation + + # ============================================================================== + # Tests with torch_npu compatibility issues + # ============================================================================== + # RPC tests fail due to torch_npu serialization.py 'object' type annotation + # (TorchScript incompatible) - these are handled by BACKEND=hccl skip + # Uncomment if needed: + # - distributed/rpc/test_tensorpipe_agent + # - distributed/rpc/test_faulty_agent + # - distributed/rpc/cuda/test_tensorpipe_agent + + # ============================================================================== + # Tests with environment variable requirements (handled by BACKEND=hccl) + # ============================================================================== + # These tests require BACKEND=gloo/nccl environment variable + # Setting BACKEND=hccl causes them to define no test classes + # - distributed/algorithms/quantization/test_quantization + # - distributed/test_distributed_spawn + +# Whitelist is empty - collect all tests except blacklist +whitelist: [] \ No newline at end of file