diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt
index 8602d4d0fa..e0968f422c 100644
--- a/.ci/docker/requirements-ci.txt
+++ b/.ci/docker/requirements-ci.txt
@@ -1,7 +1,51 @@
-# Python dependencies required for unit tests
+# Python dependencies required for NPU tests
+# Based on upstream PyTorch .ci/docker/requirements-ci.txt
-mypy==1.9.0
-# Pin MyPy version because new errors are likely to appear with each release
-#Description: linter
-#Pinned versions: 1.9.0
-#test that import: test_typing.py, test_type_hints.py
+# pytest and plugins
+pytest==7.3.2
+pytest-xdist==3.3.1
+pytest-flakefinder==1.1.0
+pytest-rerunfailures>=10.3
+pytest-subtests==0.13.1
+pytest-timeout>=2.3.1
+xdoctest==1.3.0
+
+# test utilities
+hypothesis==6.56.4
+expecttest==0.3.0
+parameterized==0.8.1
+
+# numpy (version per Python version)
+# NOTE(review): this marker only matches 3.11 <= Python < 3.14; this file
+# gives no numpy pin for any other interpreter version.
+numpy==1.26.2; python_version >= "3.11" and python_version < "3.14"
+
+# scientific packages
+# NOTE(review): the scipy marker uses a strict ">" and therefore does not
+# match Python 3.11 itself (numpy above uses ">="); pywavelets is likewise
+# only pinned for Python >= 3.12. Confirm the unpinned versions are intended.
+scipy==1.14.1; python_version > "3.11" and python_version < "3.14"
+scikit-image==0.22.0
+pillow==12.1.1
+pywavelets==1.7.0; python_version >= "3.12"
+
+# core utilities
+networkx==2.8.8
+optree==0.13.0; python_version < "3.14"
+opt-einsum==3.3
+filelock==3.20.3
+sympy==1.13.3
+
+# build/serialization
+pyyaml==6.0.3
+packaging==24.0
+typing-extensions==4.12.2; python_version < "3.14"
+pyzstd
+setuptools>=70.1.0,<82
+zstandard
+
+# ONNX support
+onnx==1.20.0
+onnxscript==0.6.2
+protobuf==6.33.5
+
+# misc
+psutil
+jinja2==3.1.6
+tqdm>=4.66.0
+click
\ No newline at end of file
diff --git a/.github/actions/setup-npu-test-env/action.yml b/.github/actions/setup-npu-test-env/action.yml
new file mode 100644
index 0000000000..1c45976642
--- /dev/null
+++ b/.github/actions/setup-npu-test-env/action.yml
@@ -0,0 +1,146 @@
+name: 'Setup NPU Test Environment'
+description: 'Common environment setup for NPU upstream tests - checkout, cache, install PyTorch/torch_npu/triton-ascend, test dependencies'
+
+# Action inputs are always strings; action metadata defines no `type:` key
+# for inputs and no top-level `env:` block, so neither is declared here.
+inputs:
+  python_version:
+    required: true
+    description: Python version to use
+  torch_wheel_artifact:
+    required: true
+    description: Name of the torch wheel artifact
+  torch_npu_wheel_artifact:
+    required: true
+    description: Name of the torch_npu wheel artifact
+  pytorch_src_artifact:
+    required: true
+    description: Name of the PyTorch source artifact
+
+runs:
+  using: 'composite'
+  steps:
+    - name: Checkout repository
+      uses: actions/checkout@v4
+      with:
+        repository: ${{ github.repository }}
+        ref: ${{ github.ref }}
+        fetch-depth: 1
+        path: ascend_pytorch
+
+    # Every `run` step of a composite action must name its shell explicitly.
+    - name: Setup cache directories
+      shell: bash
+      run: |
+        mkdir -p /github/home/.cache/pip
+        chmod -R 777 /github/home/.cache
+
+    - name: Cache pip
+      uses: actions/cache@v4
+      with:
+        path: /github/home/.cache/pip
+        key: pip-arm-collect-py${{ inputs.python_version }}
+        restore-keys: |
+          pip-arm-collect-py${{ inputs.python_version }}-
+          pip-arm-collect-
+
+    - name: Download built torch wheel
+      uses: actions/download-artifact@v4
+      with:
+        name: ${{ inputs.torch_wheel_artifact }}
+        path: torch-wheel-artifact
+
+    - name: Download built torch_npu wheel
+      uses: actions/download-artifact@v4
+      with:
+        name: ${{ inputs.torch_npu_wheel_artifact }}
+        path: torch-npu-wheel-artifact
+
+    - name: Download PyTorch source and test code
+      uses: actions/download-artifact@v4
+      with:
+        name: ${{ inputs.pytorch_src_artifact }}
+        path: pytorch-src-artifact
+
+    - name: Extract PyTorch source
+      shell: bash
+      run: |
+        tar -xzf pytorch-src-artifact/pytorch-src.tar.gz
+
+    - name: Install built PyTorch and torch_npu
+      shell: bash
+      env:
+        # PyPI cache URL used to speed up pip downloads. Set at step level
+        # (a top-level `env:` is not read for composite actions); a value
+        # provided by the calling workflow's env takes precedence.
+        PYPI_CACHE_URL: ${{ env.PYPI_CACHE_URL || 'http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple' }}
+      run: |
+        source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true
+        source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true
+
+        PIP=pip${{ inputs.python_version }}
+        PYTHON=python${{ inputs.python_version }}
+        export PIP_CACHE_DIR=/github/home/.cache/pip
+
+        # Configure pip to use PyPI cache for faster downloads
+        if [ -n "${PYPI_CACHE_URL:-}" ]; then
+          $PIP config set global.index-url "${PYPI_CACHE_URL}"
+          $PIP config set global.trusted-host "cache-service.nginx-pypi-cache.svc.cluster.local"
+          echo "pip index-url configured: ${PYPI_CACHE_URL}"
+        fi
+
+        $PIP install --upgrade pip
+
+        # Install built torch wheel
+        TORCH_WHL=$(ls torch-wheel-artifact/*.whl | head -1)
+        $PIP install "${TORCH_WHL}"
+
+        # Install built torch_npu wheel
+        TORCH_NPU_WHL=$(ls torch-npu-wheel-artifact/*.whl | head -1)
+        $PIP install "${TORCH_NPU_WHL}"
+
+        echo "Installed PyTorch and torch_npu from built wheels"
+        echo "torch: ${TORCH_WHL}"
+        echo "torch_npu: ${TORCH_NPU_WHL}"
+
+    - name: Install test dependencies
+      shell: bash
+      run: |
+        PIP=pip${{ inputs.python_version }}
+        export PIP_CACHE_DIR=/github/home/.cache/pip
+        cd pytorch-src
+
+        # Core test dependencies
+        $PIP install pytest pytest-timeout pytest-xdist hypothesis zstandard pyyaml
+        $PIP install pytest-rerunfailures pytest-flakefinder
+        $PIP install 'pytest-subtests==0.13.1' 'xdoctest==1.1.0' 'pulp>=2.9'
+
+        # Optional dependencies for ONNX tests
+        # These are not in PyTorch requirements.txt but needed by specific tests
+        $PIP install onnxruntime onnxscript onnx-ir ml-dtypes || true
+
+        # torchvision for ONNX model tests (install without deps to bypass torch version check)
+        # PyPI torchvision requires exact torch version (torch==2.11.0), but we have dev build
+        # Use --no-deps to skip torch dependency, we already have our compiled torch installed
+        $PIP install numpy pillow || true
+        $PIP install torchvision --no-deps || true
+
+        # Other optional dependencies
+        $PIP install parameterized pandas || true
+        $PIP install opencv-python || true
+
+        # PyTorch requirements (if exists)
+        if [ -f requirements.txt ]; then
+          $PIP install -r requirements.txt || true
+        fi
+
+    - name: Verify NPU availability
+      shell: bash
+      run: |
+        source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true
+        source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true
+
+        PYTHON=python${{ inputs.python_version }}
+        $PYTHON -c "
+        import torch
+        print(f'torch: {torch.__version__}')
+        import torch_npu
+        print(f'torch_npu: {torch_npu.__version__}')
+        print(f'NPU available: {torch.npu.is_available()}')
+        print(f'NPU count: {torch.npu.device_count()}')
+        "
\ No newline at end of file
diff --git a/.github/docker/pytorch-npu-builder.Dockerfile b/.github/docker/pytorch-npu-builder.Dockerfile
new file mode 100644
index 0000000000..f8a443b402
--- /dev/null
+++ b/.github/docker/pytorch-npu-builder.Dockerfile
@@ -0,0 +1,152 @@
+# 基于 PyPA manylinux 2_28 aarch64 镜像 (与 PyTorch 主干一致)
+FROM quay.io/pypa/manylinux_2_28_aarch64
+
+ARG GCCTOOLSET_VERSION=13
+
+# CANN 包下载 URL(通过 build-arg 传入)
+ARG CANN_TOOLKIT_URL
+ARG CANN_A3OPS_URL
+ARG CANN_NNAL_URL
+ARG CANN_VERSION
+
+# Language variables
+ENV LC_ALL=en_US.UTF-8
+ENV LANG=en_US.UTF-8
+ENV LANGUAGE=en_US.UTF-8
+
+# 安装必要的 OS 包 (与 PyTorch 官方 Dockerfile 一致)
+RUN yum -y install epel-release && \
+ yum -y update && \
+ yum install -y \
+ autoconf \
+ automake \
+ bison \
+ bzip2 \
+ curl \
+ diffutils \
+ file \
+ git \
+ less \
+ libffi-devel \
+ libgomp \
+ make \
+ openssl-devel \
+ patch \
+ perl \
+ unzip \
+ util-linux \
+ wget \
+ which \
+ xz \
+ yasm \
+ zstd \
+ sudo \
+ gcc-toolset-${GCCTOOLSET_VERSION}-gcc \
+ gcc-toolset-${GCCTOOLSET_VERSION}-gcc-c++ \
+ gcc-toolset-${GCCTOOLSET_VERSION}-gcc-gfortran \
+ gcc-toolset-${GCCTOOLSET_VERSION}-gdb && \
+ yum install -y --enablerepo=powertools ninja-build && \
+ rm -rf /var/cache/yum
+
+# 确保使用正确的 devtoolset
+ENV PATH=/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/bin:$PATH
+ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/lib64:/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/lib:$LD_LIBRARY_PATH
+
+# git 2.36+ 需要配置 safe.directory
+RUN git config --global --add safe.directory "*"
+
+# ============================================================
+# 预装所有 Python 版本(镜像支持多 Python 版本)
+# ============================================================
+# manylinux 镜像已包含 cp310-cp310, cp311-cp311, cp312-cp312, cp313-cp313
+# 默认使用 Python 3.11(可通过环境变量切换)
+
+ENV DEFAULT_PYTHON_VERSION=3.11
+ENV PATH=/opt/python/cp311-cp311/bin:$PATH
+
+# 创建 Python 版本切换脚本
+RUN printf '#!/bin/bash\n\
+# Python 版本切换辅助脚本\n\
+# 使用方法: source /usr/local/bin/switch_python.sh 3.11\n\
+\n\
+PYTHON_VERSION="${1:-3.11}"\n\
+\n\
+case "$PYTHON_VERSION" in\n\
+ 3.10) PYTHON_DIR="cp310-cp310" ;;\n\
+ 3.11) PYTHON_DIR="cp311-cp311" ;;\n\
+ 3.12) PYTHON_DIR="cp312-cp312" ;;\n\
+ 3.13) PYTHON_DIR="cp313-cp313" ;;\n\
+ *) echo "Unsupported Python version: $PYTHON_VERSION"; return 1 ;;\n\
+esac\n\
+\n\
+export PATH=/opt/python/$PYTHON_DIR/bin:$PATH\n\
+echo "Switched to Python $PYTHON_VERSION ($(python --version))"\n\
+' > /usr/local/bin/switch_python.sh && \
+ chmod +x /usr/local/bin/switch_python.sh
+
+# 为每个 Python 版本安装常用包
+RUN for py_dir in cp310-cp310 cp311-cp311 cp312-cp312 cp313-cp313; do \
+ /opt/python/$py_dir/bin/pip install --upgrade pip setuptools wheel; \
+ done
+
+# ============================================================
+# 安装 CANN(使用传入的 URL)
+# ============================================================
+
+WORKDIR /root
+
+# Download the three CANN installers into a scratch directory, run them, and
+# remove the scratch directory in the same layer so the installers are not
+# kept in the image. `curl -f` makes an HTTP error fail the build instead of
+# saving the error page as a .run file; `-L` follows redirects.
+RUN mkdir -p cann && cd cann && \
+    curl -fL -O "${CANN_TOOLKIT_URL}" && \
+    curl -fL -O "${CANN_A3OPS_URL}" && \
+    curl -fL -O "${CANN_NNAL_URL}" && \
+    chmod +x Ascend-cann*.run && \
+    ./Ascend-cann-toolkit*.run --full --quiet --install-path=/usr/local/Ascend && \
+    ./Ascend-cann-A3*.run --install --quiet --install-path=/usr/local/Ascend && \
+    source /usr/local/Ascend/cann/set_env.sh && \
+    ./Ascend-cann-nnal*.run --install --quiet --install-path=/usr/local/Ascend && \
+    cd .. && rm -rf cann
+
+# 设置环境变量
+ENV CANN_PATH=/usr/local/Ascend/cann
+ENV NNAL_PATH=/usr/local/Ascend/nnal
+ENV ASCEND_HOME=/usr/local/Ascend
+ENV CANN_VERSION=${CANN_VERSION}
+
+# 添加 CANN 环境初始化脚本
+RUN printf '#!/bin/bash\n\
+source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true\n\
+source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true\n\
+' > /etc/profile.d/cann_env.sh && \
+ chmod +x /etc/profile.d/cann_env.sh
+
+# ============================================================
+# 预安装 pytest 等测试依赖(为所有 Python 版本)
+# ============================================================
+
+RUN for py_dir in cp310-cp310 cp311-cp311 cp312-cp312 cp313-cp313; do \
+ /opt/python/$py_dir/bin/pip install pytest pytest-timeout pytest-xdist hypothesis pyyaml zstandard cmake ninja; \
+ done
+
+# ============================================================
+# 设置工作目录和默认命令
+# ============================================================
+
+WORKDIR /workspace
+
+# 创建 welcome 消息
+RUN printf '\n\
+========================================\n\
+PyTorch NPU Builder Image\n\
+========================================\n\
+CANN Version: %s\n\
+Python Versions: 3.10, 3.11, 3.12, 3.13 (default: 3.11)\n\
+\n\
+To switch Python version:\n\
+ source /usr/local/bin/switch_python.sh 3.12\n\
+\n\
+To setup CANN environment:\n\
+ source /etc/profile.d/cann_env.sh\n\
+========================================\n\
+\n' "${CANN_VERSION}" > /etc/motd
+
+CMD ["bash"]
\ No newline at end of file
diff --git a/.github/scripts/build_image.sh b/.github/scripts/build_image.sh
new file mode 100755
index 0000000000..25763c688e
--- /dev/null
+++ b/.github/scripts/build_image.sh
@@ -0,0 +1,486 @@
+#!/bin/bash
+#
+# build_image.sh - 构建 PyTorch NPU Docker 镜像
+#
+# 功能:按 CANN 版本构建镜像,镜像预装多 Python 版本,通过环境变量切换
+#
+# 使用方式:
+# ./build_image.sh --cann-version 9.0
+# ./build_image.sh --cann-version 9.0.0-beta.2 --push
+# ./build_image.sh --list-versions # 查看支持的 CANN 版本
+#
+
+set -euo pipefail
+
+# ============================================================
+# CANN 版本映射表
+# 每个版本对应三个包的下载 URL
+# ============================================================
+
+declare -A CANN_VERSIONS=(
+ # 版本号 -> toolkit|a3_ops|nnal 的 URL
+ # 注意:OBS 上当前只有 9.0.0-beta.2 版本的包
+ ["9.0"]="https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/cann-package/20260330/Ascend-cann-toolkit_9.0.0-beta.2_linux-aarch64.run|https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/cann-package/20260330/Ascend-cann-A3-ops_9.0.0-beta.2_linux-aarch64.run|https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/cann-package/20260330/Ascend-cann-nnal_9.0.0-beta.2_linux-aarch64.run"
+
+ ["9.0.0-beta.2"]="https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/cann-package/20260330/Ascend-cann-toolkit_9.0.0-beta.2_linux-aarch64.run|https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/cann-package/20260330/Ascend-cann-A3-ops_9.0.0-beta.2_linux-aarch64.run|https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/cann-package/20260330/Ascend-cann-nnal_9.0.0-beta.2_linux-aarch64.run"
+
+ ["8.0"]="https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/cann-package/20250101/Ascend-cann-toolkit_8.0.RC3_linux-aarch64.run|https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/cann-package/20250101/Ascend-cann-A3-ops_8.0.RC3_linux-aarch64.run|https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/cann-package/20250101/Ascend-cann-nnal_8.0.RC3_linux-aarch64.run"
+)
+
+# Stable 版本标记(用于 latest 标签)
+CANN_STABLE="9.0"
+
+# 预装的 Python 版本列表
+PYTHON_VERSIONS=("3.10" "3.11" "3.12" "3.13")
+
+# manylinux 对应的 Python 目录名映射
+declare -A PYTHON_DIR_MAP=(
+ ["3.10"]="cp310-cp310"
+ ["3.11"]="cp311-cp311"
+ ["3.12"]="cp312-cp312"
+ ["3.13"]="cp313-cp313"
+)
+
+# ============================================================
+# 默认配置
+# ============================================================
+
+DEFAULT_REGISTRY="quay.io"
+DEFAULT_QUAY_ORG="kerer"
+DEFAULT_IMAGE_NAME="pytorch"
+
+# 参数变量
+CANN_VERSION_INPUT=""
+REGISTRY=""
+QUAY_ORG=""
+IMAGE_NAME=""
+PUSH_IMAGE=false
+FORCE_BUILD=false
+VERBOSE=false
+LIST_VERSIONS=false
+
+# ============================================================
+# 日志函数
+# ============================================================
+
+log_info() {
+ echo "[INFO] $1"
+}
+
+log_error() {
+ echo "[ERROR] $1" >&2
+}
+
+log_verbose() {
+ if [[ "$VERBOSE" == "true" ]]; then
+ echo "[VERBOSE] $1"
+ fi
+}
+
+# ============================================================
+# 显示帮助信息
+# ============================================================
+
+# Print the usage text to stdout.
+# The heredoc delimiter is unquoted, so $0 and the output of
+# show_supported_versions are expanded into the text, while the escaped
+# \$PATH occurrences are printed literally.
+show_help() {
+ cat << EOF
+用法: $0 [OPTIONS]
+
+构建支持不同 CANN 版本的 PyTorch NPU Docker 镜像。
+
+镜像特性:
+ - 预装多个 Python 版本 (3.10/3.11/3.12/3.13)
+ - 通过环境变量切换 Python 版本
+ - 按 CANN 版本构建镜像
+
+CANN 参数:
+ --cann-version VERSION CANN 版本号(支持简化版或完整版)
+ 简化版: 9.0, 8.0
+ 完整版: 9.0.0-beta.2
+ --list-versions 显示支持的 CANN 版本列表
+
+镜像参数:
+ --registry REGISTRY Docker registry 地址 (默认: quay.io)
+ --quay-org ORG Quay.io 组织名 (默认: kerer)
+ --image-name NAME 镜像名称 (默认: pytorch)
+
+构建选项:
+ --push 构建后推送镜像到 registry
+ --force 强制构建,即使镜像已存在
+ --verbose 显示详细日志
+
+Python 版本切换(运行时):
+ 镜像预装多个 Python 版本,使用时通过环境变量切换:
+ export PATH=/opt/python/cp311-cp311/bin:\$PATH # 使用 Python 3.11
+ export PATH=/opt/python/cp312-cp312/bin:\$PATH # 使用 Python 3.12
+
+示例:
+ $0 --cann-version 9.0
+ $0 --cann-version 9.0.0-beta.2 --push
+ $0 --list-versions
+
+支持的 CANN 版本:
+$(show_supported_versions)
+
+EOF
+}
+
+# Print the CANN_VERSIONS keys that carry no "-beta" / "-rc" marker, followed
+# by a fixed list of example full version strings.
+show_supported_versions() {
+    local key
+
+    echo "简化版本 完整版本"
+    echo "----------- ----------------"
+    for key in "${!CANN_VERSIONS[@]}"; do
+        # Keys with a pre-release marker are full versions; leave them out.
+        if [[ "$key" =~ -beta || "$key" =~ -rc ]]; then
+            continue
+        fi
+        echo "$key (完整版见映射表)"
+    done
+    echo ""
+    echo "完整版本示例:"
+    echo " 9.0.0-beta.2"
+    echo " 8.0.RC3"
+}
+
+# ============================================================
+# 解析命令行参数
+# ============================================================
+
+parse_args() {
+ while [[ $# -gt 0 ]]; do
+ case "$1" in
+ --cann-version)
+ CANN_VERSION_INPUT="$2"
+ shift 2
+ ;;
+ --registry)
+ REGISTRY="$2"
+ shift 2
+ ;;
+ --quay-org)
+ QUAY_ORG="$2"
+ shift 2
+ ;;
+ --image-name)
+ IMAGE_NAME="$2"
+ shift 2
+ ;;
+ --push)
+ PUSH_IMAGE=true
+ shift
+ ;;
+ --force)
+ FORCE_BUILD=true
+ shift
+ ;;
+ --verbose)
+ VERBOSE=true
+ shift
+ ;;
+ --list-versions)
+ LIST_VERSIONS=true
+ shift
+ ;;
+ -h|--help)
+ show_help
+ exit 0
+ ;;
+ *)
+ log_error "未知参数: $1"
+ show_help
+ exit 1
+ ;;
+ esac
+ done
+
+ # 设置默认值
+ REGISTRY="${REGISTRY:-$DEFAULT_REGISTRY}"
+ QUAY_ORG="${QUAY_ORG:-$DEFAULT_QUAY_ORG}"
+ IMAGE_NAME="${IMAGE_NAME:-$DEFAULT_IMAGE_NAME}"
+
+ # 显示版本列表
+ if [[ "$LIST_VERSIONS" == "true" ]]; then
+ echo "支持的 CANN 版本:"
+ echo ""
+ for version in "${!CANN_VERSIONS[@]}"; do
+ echo " - $version"
+ done
+ echo ""
+ echo "Stable 版本(用于 latest 标签): $CANN_STABLE"
+ exit 0
+ fi
+
+ # 验证参数
+ if [[ -z "$CANN_VERSION_INPUT" ]]; then
+ log_error "必须指定 --cann-version 或使用 --list-versions"
+ show_help
+ exit 1
+ fi
+}
+
+# ============================================================
+# 解析 CANN 版本
+# ============================================================
+
+parse_cann_version() {
+ local input="$CANN_VERSION_INPUT"
+
+ log_verbose "解析 CANN 版本: $input"
+
+ # 检查版本是否在映射表中
+ if [[ ! -v CANN_VERSIONS[$input] ]]; then
+ log_error "不支持的 CANN 版本: $input"
+ log_info "支持的版本: ${!CANN_VERSIONS[*]}"
+ log_info "使用 --list-versions 查看完整列表"
+ exit 1
+ fi
+
+ # 解析 URL
+ local urls="${CANN_VERSIONS[$input]}"
+ CANN_TOOLKIT_URL=$(echo "$urls" | cut -d'|' -f1)
+ CANN_A3OPS_URL=$(echo "$urls" | cut -d'|' -f2)
+ CANN_NNAL_URL=$(echo "$urls" | cut -d'|' -f3)
+
+ # 提取版本号(去掉 beta/rc 后缀)
+ CANN_VERSION_FULL="$input"
+ CANN_VERSION_MAJOR=$(echo "$input" | sed 's/-beta.*//' | sed 's/-rc.*//')
+
+ # 判断是否为 stable 版本
+ if [[ "$CANN_VERSION_MAJOR" == "$CANN_STABLE" ]]; then
+ IS_STABLE="true"
+ else
+ IS_STABLE="false"
+ fi
+
+ log_verbose "Toolkit URL: $CANN_TOOLKIT_URL"
+ log_verbose "A3-ops URL: $CANN_A3OPS_URL"
+ log_verbose "NNAL URL: $CANN_NNAL_URL"
+ log_verbose "Full version: $CANN_VERSION_FULL"
+ log_verbose "Major version: $CANN_VERSION_MAJOR"
+ log_verbose "Is stable: $IS_STABLE"
+}
+
+# ============================================================
+# 生成镜像标签
+# ============================================================
+
+generate_tags() {
+ local timestamp=$(date +%Y%m%d)
+ local tags=()
+
+ # 提取大版本号(去掉 patch 号,但保留 beta/rc)
+ # 例如:9.0.0-beta.2 → 9.0,9.0 → 9.0,8.0.RC3 → 8.0
+ local cann_major
+
+ # 如果版本号已经是简化格式(没有第二个点),则保持原样
+ if [[ "$CANN_VERSION_FULL" =~ ^[0-9]+\.[0-9]+$ ]]; then
+ cann_major="$CANN_VERSION_FULL"
+ else
+ # 提取前两位数字(去掉 patch 号和 beta/rc 后缀)
+ cann_major=$(echo "$CANN_VERSION_FULL" | grep -oP '^[0-9]+\.[0-9]+')
+ fi
+
+ # 1. 完整版本标签(带时间戳)- 用于追溯
+ tags+=("cann${CANN_VERSION_FULL}-${timestamp}")
+
+ # 2. 标准版本标签(无时间戳)- 用于日常使用
+ # 如果输入已经是简化版本,则跳过完整版本标签,避免重复
+ if [[ "$CANN_VERSION_FULL" != "$cann_major" ]]; then
+ tags+=("cann${CANN_VERSION_FULL}")
+ fi
+
+ # 3. 大版本简化标签 - 用于快速识别
+ tags+=("cann${cann_major}")
+
+ # 4. latest 标签(仅 stable 版本)
+ if [[ "$IS_STABLE" == "true" ]]; then
+ tags+=("latest")
+ tags+=("cann-latest")
+ tags+=("cann${cann_major}-latest")
+ fi
+
+ # 输出所有标签
+ for tag in "${tags[@]}"; do
+ echo "$tag"
+ done
+}
+
+# ============================================================
+# 构建镜像
+# ============================================================
+
+build_image() {
+ log_info "=========================================="
+ log_info "构建镜像: CANN $CANN_VERSION_FULL"
+ log_info "=========================================="
+
+ log_info "预装 Python 版本: ${PYTHON_VERSIONS[*]}"
+
+ # 生成镜像标签
+ local tags=$(generate_tags)
+ local tag_args=""
+ while IFS= read -r tag; do
+ tag_args+=" --tag ${REGISTRY}/${QUAY_ORG}/${IMAGE_NAME}:${tag}"
+ done <<< "$tags"
+
+ log_info "镜像标签:"
+ while IFS= read -r tag; do
+ log_info " - ${REGISTRY}/${QUAY_ORG}/${IMAGE_NAME}:${tag}"
+ done <<< "$tags"
+
+ # 检查镜像是否已存在(除非强制构建)
+ if [[ "$FORCE_BUILD" == "false" && "$PUSH_IMAGE" == "true" ]]; then
+ local first_tag=$(echo "$tags" | head -n1)
+ if docker pull "${REGISTRY}/${QUAY_ORG}/${IMAGE_NAME}:${first_tag}" &>/dev/null; then
+ log_info "镜像已存在,跳过构建: ${REGISTRY}/${QUAY_ORG}/${IMAGE_NAME}:${first_tag}"
+ return 0
+ fi
+ fi
+
+ # 确认需要构建,执行登录(如果需要推送)
+ # 如果环境变量 SKIP_DOCKER_LOGIN=true,则跳过(用于 CI,已通过 login-action 登录)
+ if [[ "$PUSH_IMAGE" == "true" ]]; then
+ if [[ "${SKIP_DOCKER_LOGIN:-false}" != "true" ]]; then
+ login_registry
+ else
+ log_verbose "跳过登录(SKIP_DOCKER_LOGIN=true)"
+ fi
+ fi
+
+ # Dockerfile 路径
+ # 使用 git 获取项目根目录(更可靠)
+ local script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+ local project_root
+
+ # 尝试使用 git 获取项目根目录
+ if git rev-parse --show-toplevel &>/dev/null; then
+ project_root="$(git rev-parse --show-toplevel)"
+ else
+ # 如果不在 git 仓库中,从脚本目录向上推导
+ project_root="$(cd "${script_dir}/.." && pwd)"
+ fi
+
+ local dockerfile_dir="${project_root}/.github/docker"
+ local dockerfile="${dockerfile_dir}/pytorch-npu-builder.Dockerfile"
+
+ log_verbose "Script dir: ${script_dir}"
+ log_verbose "Project root: ${project_root}"
+ log_verbose "Dockerfile dir: ${dockerfile_dir}"
+
+ if [[ ! -f "$dockerfile" ]]; then
+ log_error "Dockerfile 不存在: $dockerfile"
+ exit 1
+ fi
+
+ log_verbose "Dockerfile: $dockerfile"
+
+ # 构建参数(单行格式,避免换行符问题)
+ local build_args="--build-arg CANN_TOOLKIT_URL=${CANN_TOOLKIT_URL} --build-arg CANN_A3OPS_URL=${CANN_A3OPS_URL} --build-arg CANN_NNAL_URL=${CANN_NNAL_URL} --build-arg CANN_VERSION=${CANN_VERSION_FULL}"
+
+ # 构建命令(单行格式)
+ local build_cmd="docker buildx build ${build_args} ${tag_args} --file ${dockerfile} --platform linux/arm64 ${dockerfile_dir}"
+
+ if [[ "$PUSH_IMAGE" == "true" ]]; then
+ build_cmd+=" --push"
+ else
+ build_cmd+=" --load"
+ fi
+
+ log_verbose "构建命令: $build_cmd"
+
+ # 执行构建
+ log_info "开始构建..."
+ if ! eval "$build_cmd"; then
+ log_error "构建失败"
+ return 1
+ fi
+
+ log_info "构建成功"
+
+ # 输出构建信息
+ echo ""
+ log_info "构建信息:"
+ log_info " CANN 版本: $CANN_VERSION_FULL"
+ log_info " CANN 大版本: $CANN_VERSION_MAJOR"
+ log_info " Stable: $IS_STABLE"
+ log_info " 预装 Python: ${PYTHON_VERSIONS[*]}"
+ log_info " 镜像地址:"
+ while IFS= read -r tag; do
+ log_info " ${REGISTRY}/${QUAY_ORG}/${IMAGE_NAME}:${tag}"
+ done <<< "$tags"
+
+ echo ""
+ log_info "使用方法:"
+ log_info " docker run -it ${REGISTRY}/${QUAY_ORG}/${IMAGE_NAME}:cann${CANN_VERSION_MAJOR} bash"
+ log_info " # 切换 Python 版本:"
+ log_info " export PATH=/opt/python/cp311-cp311/bin:\$PATH # Python 3.11"
+ log_info " export PATH=/opt/python/cp312-cp312/bin:\$PATH # Python 3.12"
+ echo ""
+
+ return 0
+}
+
+# ============================================================
+# 检查依赖
+# ============================================================
+
+check_dependencies() {
+ log_verbose "检查依赖..."
+
+ # 检查 docker
+ if ! command -v docker &>/dev/null; then
+ log_error "未安装 docker"
+ exit 1
+ fi
+
+ # 检查 docker buildx
+ if ! docker buildx version &>/dev/null; then
+ log_error "docker buildx 不可用"
+ exit 1
+ fi
+
+ log_verbose "依赖检查通过"
+}
+
+# ============================================================
+# 登录 registry
+# ============================================================
+
+login_registry() {
+ if [[ "$PUSH_IMAGE" == "true" ]]; then
+ log_info "登录 Registry: $REGISTRY"
+
+ case "$REGISTRY" in
+ quay.io)
+ if [[ -z "${QUAY_USERNAME:-}" || -z "${QUAY_PASSWORD:-}" ]]; then
+ log_error "需要设置环境变量 QUAY_USERNAME 和 QUAY_PASSWORD"
+ exit 1
+ fi
+ docker login quay.io -u "$QUAY_USERNAME" --password-stdin <<< "$QUAY_PASSWORD"
+ ;;
+ ghcr.io)
+ if [[ -z "${GITHUB_TOKEN:-}" ]]; then
+ log_error "需要设置环境变量 GITHUB_TOKEN"
+ exit 1
+ fi
+ echo "$GITHUB_TOKEN" | docker login ghcr.io -u "${GITHUB_ACTOR:-}" --password-stdin
+ ;;
+ *)
+ log_error "不支持的 registry: $REGISTRY"
+ exit 1
+ ;;
+ esac
+
+ log_info "登录成功"
+ fi
+}
+
+# ============================================================
+# 主函数
+# ============================================================
+
+# Entry point: parse the command line, check that docker and buildx are
+# usable, resolve the requested CANN version, then build the image.
+# parse_args itself exits for --help and --list-versions, so those paths
+# never reach the docker checks.
+main() {
+ parse_args "$@"
+ check_dependencies
+ parse_cann_version
+ build_image
+}
+
+# 执行主函数
+main "$@"
\ No newline at end of file
diff --git a/.github/scripts/collect_all_cases.py b/.github/scripts/collect_all_cases.py
new file mode 100644
index 0000000000..6773817dd9
--- /dev/null
+++ b/.github/scripts/collect_all_cases.py
@@ -0,0 +1,452 @@
+#!/usr/bin/env python3
+"""
+Collect all test cases and split into shards.
+
+This script runs in prepare job (once) to:
+1. Discover test files by type (distributed/regular)
+2. Collect all test cases via pytest --collect-only
+3. Split cases evenly into N shards
+4. Output shard JSON files for each type
+5. Save collection error logs for failed files
+
+Usage:
+ python collect_all_cases.py \
+ --test-dir /path/to/pytorch/test \
+ --case-paths-config /path/to/case_paths_ci.yml \
+ --distributed-shards 2 \
+ --regular-shards 5 \
+ --output-dir /path/to/output \
+ --error-log-dir /path/to/error_logs \
+ --parallel 16
+"""
+
+import argparse
+import json
+import os
+import subprocess
+import sys
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+# Import discover_test_files module
+import discover_test_files
+
+
+def _normalize_test_file_path(test_file: str) -> str:
+ """
+ Remove 'test/' prefix from test file path if present.
+
+ Args:
+ test_file: Test file path (e.g., "test/distributed/pipelining/test_backward.py")
+
+ Returns:
+ Relative path without 'test/' prefix
+ """
+ if test_file.startswith("test/"):
+ return test_file[5:]
+ return test_file
+
+
+def get_test_file_parent_dir(test_file: str, test_dir: Path) -> Path:
+ """
+ Get the parent directory of a test file.
+
+ This directory should be added to PYTHONPATH to enable
+ imports of sibling modules (e.g., model_registry.py).
+
+ Args:
+ test_file: Test file path (e.g., "test/distributed/pipelining/test_backward.py")
+ test_dir: Path to PyTorch test directory
+
+ Returns:
+ Path to the test file's parent directory
+ """
+ test_file_rel = _normalize_test_file_path(test_file)
+ test_file_path = Path(test_file_rel)
+ return test_dir / test_file_path.parent
+
+
+def collect_cases_for_file(test_file: str, test_dir: Path) -> Tuple[str, str, List[str], bool, str]:
+ """
+ Collect test cases from a single file.
+
+ Adds test file's parent directory to PYTHONPATH to enable
+ imports of sibling modules (e.g., 'from model_registry import MLPModule').
+
+ Returns:
+ Tuple of (test_file, display_name, nodeids, success, error_message)
+ - test_file: Original test file path
+ - display_name: Short name for logging (remove test/ prefix and .py suffix)
+ - nodeids: List of collected test case nodeids
+ - success: True if collection succeeded without errors
+ - error_message: Error details if collection failed, empty string otherwise
+ """
+ test_file_rel = _normalize_test_file_path(test_file)
+
+ # Extract display name (remove .py suffix)
+ display_name = test_file_rel
+ if display_name.endswith(".py"):
+ display_name = display_name[:-3]
+
+ # Get test file's parent directory for PYTHONPATH
+ test_file_dir = get_test_file_parent_dir(test_file, test_dir)
+
+ # Build environment with test file directory in PYTHONPATH
+ env = os.environ.copy()
+ existing_pythonpath = env.get("PYTHONPATH", "")
+ env["PYTHONPATH"] = str(test_file_dir) + (":" + existing_pythonpath if existing_pythonpath else "")
+
+ command = [
+ sys.executable,
+ "-m",
+ "pytest",
+ "--collect-only",
+ "--quiet",
+ test_file_rel,
+ ]
+
+ try:
+ result = subprocess.run(
+ command,
+ cwd=str(test_dir),
+ env=env,
+ capture_output=True,
+ text=True,
+ encoding="utf-8",
+ errors="replace",
+ timeout=120,
+ )
+
+ nodeids = []
+ for line in result.stdout.splitlines():
+ stripped = line.strip()
+ # pytest --collect-only -q outputs clean nodeids, one per line
+ # Filter rules:
+ # 1. Skip empty lines
+ # 2. Skip summary lines (contain "collected" or "selected")
+ # 3. Skip separator lines (start with "=")
+ # 4. Must contain ".py::" to ensure it's a Python test file nodeid
+ if not stripped:
+ continue
+ if "collected" in stripped or "selected" in stripped:
+ continue
+ if stripped.startswith("="):
+ continue
+ if ".py::" in stripped:
+ nodeids.append(stripped)
+
+ # Check for collection errors based on pytest exit codes:
+ # 0: all passed (success)
+ # 2: pytest error (includes collection errors like ImportError)
+ # 3: all skipped (success)
+ # 4: command line error (error)
+ # 5: no tests collected (ERROR - test file should have cases)
+ # Key insight: if a test file is selected for execution, it should have cases.
+ # returncode 5 means 0 cases collected, which indicates a problem.
+ if result.returncode in (0, 3):
+ # Normal: passed or skipped
+ return (test_file, display_name, nodeids, True, "")
+ else:
+ # returncode 2, 4, 5: real collection error
+ # returncode 5 specifically means no tests collected - a problem for selected files
+ error_msg = result.stdout.strip()
+ if result.stderr.strip():
+ error_msg += "\n--- stderr ---\n" + result.stderr.strip()
+ return (test_file, display_name, nodeids, False, error_msg)
+
+ except subprocess.TimeoutExpired:
+ error_msg = f"TIMEOUT: Collection took >120s for {display_name}"
+ return (test_file, display_name, [], False, error_msg)
+ except Exception as e:
+ error_msg = f"ERROR: {e}"
+ return (test_file, display_name, [], False, error_msg)
+
+
+def collect_all_cases(
+ test_files: List[str],
+ test_dir: Path,
+ error_log_dir: Path,
+ parallel: int = 16,
+) -> List[Dict]:
+ """
+ Collect all cases from all files.
+
+ Args:
+ test_files: List of test file paths
+ test_dir: Path to PyTorch test directory
+ error_log_dir: Directory to save error logs for failed collections
+ parallel: Number of parallel workers
+
+ Returns:
+ List of dicts with nodeid and file for each collected case
+ """
+ all_cases = []
+ failed_files = [] # Track files with collection errors for logging
+
+ print(f"Collecting cases from {len(test_files)} files with {parallel} workers...")
+ print("=" * 60)
+
+ # Create error log directory
+ error_log_dir.mkdir(parents=True, exist_ok=True)
+
+ with ThreadPoolExecutor(max_workers=parallel) as executor:
+ futures = {
+ executor.submit(collect_cases_for_file, f, test_dir): f
+ for f in test_files
+ }
+
+ completed = 0
+ successful_count = 0
+ failed_count = 0
+ total_cases = 0
+
+ for future in as_completed(futures):
+ test_file, display_name, nodeids, success, error_msg = future.result()
+ completed += 1
+
+ if success:
+ successful_count += 1
+ # Print concise log for successful files
+ print(f" {display_name}: {len(nodeids)} cases")
+ for nodeid in nodeids:
+ all_cases.append({
+ "nodeid": nodeid,
+ "file": test_file,
+ })
+ else:
+ failed_count += 1
+ # Print concise log for failed files
+ print(f" [FAILED] {display_name}: {len(nodeids)} cases")
+ # Save error details to log file
+ failed_files.append({
+ "file": display_name,
+ "error": error_msg,
+ "cases": len(nodeids),
+ "test_file": test_file,
+ })
+ # Still add any cases that were collected despite errors
+ for nodeid in nodeids:
+ all_cases.append({
+ "nodeid": nodeid,
+ "file": test_file,
+ })
+
+ # Update total cases count for progress display
+ total_cases += len(nodeids)
+
+ # Print progress summary every 100 files
+ if completed % 100 == 0:
+ print(f" [Progress: {completed}/{len(test_files)} files, {successful_count} ok, {failed_count} failed, {total_cases} cases]")
+
+ print("=" * 60)
+
+ # Save error logs to files
+ if failed_files:
+ save_error_logs(failed_files, error_log_dir)
+
+ # Final summary
+ print(f"Collection complete: {len(all_cases)} cases from {successful_count}/{len(test_files)} files")
+ if failed_count > 0:
+ print(f" WARNING: {failed_count} files had collection errors (logs saved to {error_log_dir})")
+
+ return all_cases
+
+
+def save_error_logs(failed_files: List[Dict], error_log_dir: Path) -> None:
+    """
+    Save collection error logs to individual files and create a summary.
+
+    Writes one ``<name>.log`` file per entry plus a
+    ``collection_errors_summary.json`` index that lists every log file.
+
+    Args:
+        failed_files: List of dicts with the keys ``file`` (display name),
+            ``error`` (error text), ``cases`` (number of collected cases)
+            and ``test_file`` (test file path).
+        error_log_dir: Directory to save error logs; created when missing.
+    """
+    print(f"Saving error logs for {len(failed_files)} failed files...")
+
+    # Callers may pass a sub-directory that has not been created yet.
+    error_log_dir.mkdir(parents=True, exist_ok=True)
+
+    summary_entries = []
+    for failed in failed_files:
+        # Create safe filename from display name (replace / with _). It is
+        # computed once so the log file and its summary entry always agree.
+        log_name = f"{failed['file'].replace('/', '_')}.log"
+
+        # A missing error message must not abort writing the remaining logs.
+        error_text = "" if failed['error'] is None else str(failed['error'])
+
+        with open(error_log_dir / log_name, 'w', encoding='utf-8') as f:
+            f.write(f"File: {failed['file']}\n")
+            f.write(f"Cases collected: {failed['cases']}\n")
+            f.write(f"Test file path: {failed['test_file']}\n")
+            f.write("=" * 80 + "\n")
+            f.write("Collection Error:\n")
+            f.write("=" * 80 + "\n")
+            f.write(error_text)
+            f.write("\n")
+
+        summary_entries.append({
+            "file": failed['file'],
+            "cases": failed['cases'],
+            "test_file": failed['test_file'],
+            "log_file": log_name,
+        })
+
+    # Save summary JSON
+    summary_file = error_log_dir / "collection_errors_summary.json"
+    summary_data = {
+        "total_failed": len(failed_files),
+        "failed_files": summary_entries,
+    }
+    summary_file.write_text(json.dumps(summary_data, indent=2), encoding='utf-8')
+
+    print(f" Error logs saved to {error_log_dir}")
+    print(f" Summary: {summary_file}")
+
+
+def split_cases_into_shards(cases: List[Dict], num_shards: int) -> List[List[Dict]]:
+    """
+    Split cases evenly into contiguous shards.
+
+    The first ``len(cases) % num_shards`` shards receive one extra case, so
+    shard sizes differ by at most one and the input order is preserved.
+
+    Args:
+        cases: Items to distribute.
+        num_shards: Number of shards to produce; must be at least 1.
+
+    Returns:
+        A list with exactly ``num_shards`` lists (some may be empty).
+
+    Raises:
+        ValueError: If ``num_shards`` is smaller than 1.
+    """
+    # Zero would divide by zero; a negative count would silently drop all cases.
+    if num_shards < 1:
+        raise ValueError(f"num_shards must be a positive integer, got {num_shards}")
+
+    base_size, remainder = divmod(len(cases), num_shards)
+
+    shards = []
+    start = 0
+    for i in range(num_shards):
+        size = base_size + (1 if i < remainder else 0)
+        shards.append(cases[start:start + size])
+        start += size
+
+    return shards
+
+
+def save_shards(
+    cases: List[Dict],
+    num_shards: int,
+    test_type: str,
+    output_dir: Path,
+) -> Dict:
+    """
+    Write one JSON file per shard and return a summary of the split.
+
+    Each file is named ``{test_type}_cases_shard_{index}.json`` (1-based
+    index) and is written into ``output_dir``.
+    """
+    shards = split_cases_into_shards(cases, num_shards)
+
+    print(f"\nSaving {test_type} shards...")
+    shard_sizes = []
+    for index, members in enumerate(shards, start=1):
+        size = len(members)
+        payload = {
+            "shard": index,
+            "num_shards": num_shards,
+            "test_type": test_type,
+            "total_cases": size,
+            "cases": members,
+        }
+        target = output_dir / f"{test_type}_cases_shard_{index}.json"
+        target.write_text(json.dumps(payload, indent=2), encoding="utf-8")
+        print(f" Shard {index}: {size} cases -> {target}")
+        shard_sizes.append(size)
+
+    return {
+        "test_type": test_type,
+        "num_shards": num_shards,
+        "total_cases": len(cases),
+        "shard_sizes": shard_sizes,
+    }
+
+
+def main():
+ """
+ Collect distributed and regular test cases, shard them and write the results.
+
+ For each test type the files are discovered, their cases collected and the
+ cases written as shard JSON files into the output directory. Finally an
+ overall cases_collection_summary.json is written next to the shard files.
+ """
+ args = parse_args()
+
+ test_dir = Path(args.test_dir).resolve()
+ output_dir = Path(args.output_dir).resolve()
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Error log directory for failed collections
+ error_log_dir = Path(args.error_log_dir).resolve() if args.error_log_dir else output_dir / "collection_errors"
+ error_log_dir.mkdir(parents=True, exist_ok=True)
+
+ # Per-test-type summaries returned by save_shards, in processing order.
+ summaries = []
+
+ # ========================================
+ # Step 1: Collect distributed test cases
+ # ========================================
+ print("=" * 80)
+ print("Collecting distributed test cases")
+ print("=" * 80)
+
+ dist_files, dist_meta = discover_test_files.discover_test_files(
+ test_dir=test_dir,
+ test_type="distributed",
+ case_paths_config=args.case_paths_config,
+ )
+ print(f"Found {len(dist_files)} distributed test files")
+
+ # Collection errors of this test type go to their own sub-directory.
+ dist_cases = collect_all_cases(dist_files, test_dir, error_log_dir / "distributed", args.parallel)
+ print(f"Total distributed cases: {len(dist_cases)}")
+
+ dist_summary = save_shards(dist_cases, args.distributed_shards, "distributed", output_dir)
+ summaries.append(dist_summary)
+
+ # ========================================
+ # Step 2: Collect regular test cases
+ # ========================================
+ print("\n" + "=" * 80)
+ print("Collecting regular test cases")
+ print("=" * 80)
+
+ reg_files, reg_meta = discover_test_files.discover_test_files(
+ test_dir=test_dir,
+ test_type="regular",
+ case_paths_config=args.case_paths_config,
+ )
+ print(f"Found {len(reg_files)} regular test files")
+
+ # Collection errors of this test type go to their own sub-directory.
+ reg_cases = collect_all_cases(reg_files, test_dir, error_log_dir / "regular", args.parallel)
+ print(f"Total regular cases: {len(reg_cases)}")
+
+ reg_summary = save_shards(reg_cases, args.regular_shards, "regular", output_dir)
+ summaries.append(reg_summary)
+
+ # ========================================
+ # Step 3: Save overall summary
+ # ========================================
+ # Calculate file counts (distributed + regular = total_files, no overlap)
+ dist_total = dist_meta.get("total_files", 0)
+ dist_selected = dist_meta.get("type_selected", 0)
+ # NOTE(review): reg_total is assigned but not read anywhere in this function.
+ reg_total = reg_meta.get("total_files", 0)
+ reg_selected = reg_meta.get("type_selected", 0)
+ # total_files is same for both (all test_*.py files), use one value
+ total_files = dist_total
+
+ overall_summary = {
+ "distributed": {
+ "cases_summary": dist_summary,
+ "discovery_metadata": dist_meta,
+ },
+ "regular": {
+ "cases_summary": reg_summary,
+ "discovery_metadata": reg_meta,
+ },
+ "total_cases": len(dist_cases) + len(reg_cases),
+ "total_files_scanned": total_files,
+ "distributed_files": dist_selected,
+ "regular_files": reg_selected,
+ }
+ summary_file = output_dir / "cases_collection_summary.json"
+ summary_file.write_text(json.dumps(overall_summary, indent=2), encoding="utf-8")
+ print(f"\nOverall summary saved to {summary_file}")
+
+ # Final human-readable recap of what was collected.
+ print("\n" + "=" * 80)
+ print("Collection Complete")
+ print("=" * 80)
+ print(f"Distributed: {len(dist_cases)} cases -> {args.distributed_shards} shards (serial execution)")
+ print(f"Regular: {len(reg_cases)} cases -> {args.regular_shards} shards (parallel execution)")
+ print(f"Total: {len(dist_cases) + len(reg_cases)} cases")
+
+
+def parse_args():
+    """Parse the command-line options of the case collection script."""
+    option_specs = [
+        ("--test-dir", {"required": True, "help": "PyTorch test directory"}),
+        ("--case-paths-config", {"help": "case_paths_ci.yml path"}),
+        ("--distributed-shards", {"type": int, "default": 2, "help": "Distributed test shards"}),
+        ("--regular-shards", {"type": int, "default": 5, "help": "Regular test shards"}),
+        ("--output-dir", {"required": True, "help": "Output directory for shard JSONs"}),
+        (
+            "--error-log-dir",
+            {"help": "Output directory for collection error logs (default: output-dir/collection_errors)"},
+        ),
+        ("--parallel", {"type": int, "default": 16, "help": "Parallel collection workers"}),
+    ]
+
+    parser = argparse.ArgumentParser(description="Collect and shard test cases")
+    # Options are registered in table order so the --help listing is unchanged.
+    for flag, options in option_specs:
+        parser.add_argument(flag, **options)
+    return parser.parse_args()
+
+
+# Run the collection only when executed as a script, not when imported.
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/.github/scripts/discover_test_files.py b/.github/scripts/discover_test_files.py
new file mode 100644
index 0000000000..0b6e6167f3
--- /dev/null
+++ b/.github/scripts/discover_test_files.py
@@ -0,0 +1,341 @@
+#!/usr/bin/env python3
+"""
+Discover test files for PyTorch NPU testing.
+
+This script integrates 3 steps:
+ Step 1: Test file discovery (scan all test_*.py)
+ Step 2: Shard type filtering (distributed/regular)
+ Step 3: Whitelist/blacklist filtering (case_paths_ci.yml)
+
+Output: Sorted list of test file paths (with 'test/' prefix)
+
+Usage:
+ python discover_test_files.py \
+ --test-dir /path/to/pytorch/test \
+ --test-type distributed \
+ --case-paths-config /path/to/case_paths_ci.yml \
+ --output /path/to/output_file.txt
+
+ # Or output to stdout:
+ python discover_test_files.py \
+ --test-dir /path/to/pytorch/test \
+ --test-type regular \
+ --case-paths-config /path/to/case_paths_ci.yml
+"""
+
+import argparse
+import json
+import sys
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+try:
+ import yaml # type: ignore
+except ImportError:
+ yaml = None # type: ignore
+
+
+# ==============================================================================
+# Path Normalization Functions
+# ==============================================================================
+
+
+def normalize_path(value: str) -> str:
+    """Return ``value`` with forward slashes, no ``./`` prefixes and no outer slashes."""
+    cleaned = value.replace("\\", "/").strip()
+    # Peel off every leading "./" segment, not just the first one.
+    while cleaned[:2] == "./":
+        cleaned = cleaned[2:]
+    return cleaned.strip("/")
+
+
+def normalize_rule_path(rule: str) -> str:
+    """Normalize a rule and make sure it is rooted at ``test/``; empty rules stay empty."""
+    cleaned = normalize_path(rule)
+    if not cleaned:
+        return ""
+
+    already_rooted = cleaned == "test" or cleaned.startswith("test/")
+    rooted = cleaned if already_rooted else f"test/{cleaned}"
+    return rooted.rstrip("/")
+
+
+# ==============================================================================
+# YAML Parsing Functions
+# ==============================================================================
+
+
+def parse_simple_yaml_lists(raw_text: str) -> Dict[str, List[str]]:
+    """
+    Extract the ``whitelist`` and ``blacklist`` lists without a YAML library.
+
+    Only unindented ``key:`` headers and ``- item`` entries are understood;
+    everything after a ``#`` on a line is dropped as a comment.
+    """
+    result = {"whitelist": [], "blacklist": []}
+    active = None
+
+    for line in raw_text.splitlines():
+        content = line.split("#", 1)[0].rstrip().lstrip()
+        if not content:
+            continue
+
+        is_top_level = not line.startswith((" ", "\t"))
+        if is_top_level and content.endswith(":"):
+            # A new section starts; unknown sections switch collection off.
+            name = content[:-1].strip()
+            active = name if name in result else None
+        elif active and content.startswith("- "):
+            entry = content[2:].strip().strip("\"'")
+            if entry:
+                result[active].append(entry)
+
+    return result
+
+
+def coerce_rule_list(value, key: str) -> List[str]:
+    """
+    Check that ``value`` is a list of strings and return the normalized rules.
+
+    ``None`` yields an empty list and rules that normalize to an empty string
+    are dropped. ``key`` is only used in the error messages.
+
+    Raises:
+        ValueError: If ``value`` is not a list or holds a non-string entry.
+    """
+    if value is None:
+        return []
+    if not isinstance(value, list):
+        raise ValueError(f"Expected '{key}' to be a list, got {type(value).__name__}")
+
+    for entry in value:
+        if not isinstance(entry, str):
+            raise ValueError(f"Expected every '{key}' entry to be a string, got {type(entry).__name__}")
+
+    candidates = (normalize_rule_path(entry) for entry in value)
+    return [rule for rule in candidates if rule]
+
+
+def load_case_path_rules(config_file: Optional[str]) -> Tuple[str, List[str], List[str]]:
+    """
+    Read the whitelist/blacklist rules from a case_paths_ci.yml file.
+
+    Returns:
+        ``(resolved_config_path, whitelist, blacklist)``; all three are empty
+        when no config file is given.
+
+    Raises:
+        FileNotFoundError: If the config file does not exist.
+        ValueError: If the file content is not a mapping.
+    """
+    if not config_file:
+        return "", [], []
+
+    config_path = Path(config_file).resolve()
+    if not config_path.exists():
+        raise FileNotFoundError(f"case_paths_ci config not found: {config_path}")
+
+    raw_text = config_path.read_text(encoding="utf-8")
+
+    # Use the minimal built-in parser when PyYAML could not be imported.
+    if yaml is None:
+        payload = parse_simple_yaml_lists(raw_text)
+    else:
+        payload = yaml.safe_load(raw_text) or {}
+
+    if not isinstance(payload, dict):
+        raise ValueError(f"Expected a YAML object in {config_path}, got {type(payload).__name__}")
+
+    whitelist, blacklist = (
+        coerce_rule_list(payload.get(section), section) for section in ("whitelist", "blacklist")
+    )
+    return str(config_path), whitelist, blacklist
+
+
+# ==============================================================================
+# Test File Discovery (Step 1)
+# ==============================================================================
+
+
+def discover_raw_test_files(test_dir: Path) -> List[str]:
+    """Recursively find ``test_*.py`` under ``test_dir`` and return sorted ``test/``-prefixed POSIX paths."""
+    return sorted(
+        f"test/{found.relative_to(test_dir).as_posix()}"
+        for found in test_dir.rglob("test_*.py")
+    )
+
+
+# ==============================================================================
+# Type Filtering (Step 2)
+# ==============================================================================
+
+
+def filter_tests_by_type(test_files: List[str], test_type: str) -> Tuple[List[str], List[str]]:
+    """
+    Split test files into ``(selected, excluded)`` for the given test type.
+
+    Files under ``test/distributed/`` are the distributed ones; any test type
+    other than ``"distributed"`` selects the remaining files.
+    """
+    distributed, others = [], []
+    for path in test_files:
+        bucket = distributed if path.startswith("test/distributed/") else others
+        bucket.append(path)
+
+    if test_type == "distributed":
+        return distributed, others
+    return others, distributed
+
+
+# ==============================================================================
+# Path Rules Filtering (Step 3)
+# ==============================================================================
+
+
+def path_matches_rule(test_path: str, rule: str) -> bool:
+    """
+    Tell whether ``test_path`` is covered by ``rule``.
+
+    Rules containing ``*``, ``?``, ``[`` or ``]`` are matched as glob
+    patterns; other rules match the exact path or any path below it.
+    """
+    import fnmatch
+
+    candidate = normalize_path(test_path)
+    pattern = normalize_rule_path(rule)
+    if not pattern:
+        return False
+
+    is_glob = not set("*?[]").isdisjoint(pattern)
+    if is_glob:
+        return fnmatch.fnmatch(candidate, pattern)
+
+    if candidate == pattern:
+        return True
+    # Treat a plain rule as a directory prefix.
+    return candidate.startswith(f"{pattern}/")
+
+
+def apply_case_path_rules(
+    test_files: List[str], whitelist: List[str], blacklist: List[str]
+) -> Tuple[List[str], List[str]]:
+    """
+    Filter test files by whitelist and blacklist rules.
+
+    An empty whitelist admits every file; the blacklist is applied afterwards.
+
+    Returns:
+        ``(selected, excluded)``, both in the order of ``test_files``.
+    """
+
+    def matches_any(path: str, rules: List[str]) -> bool:
+        return any(path_matches_rule(path, rule) for rule in rules)
+
+    selected = [
+        path
+        for path in test_files
+        if (not whitelist or matches_any(path, whitelist)) and not matches_any(path, blacklist)
+    ]
+
+    kept = set(selected)
+    excluded = [path for path in test_files if path not in kept]
+    return selected, excluded
+
+
+# ==============================================================================
+# Main Discovery Function
+# ==============================================================================
+
+
+def discover_test_files(
+    test_dir: Path,
+    test_type: str,
+    case_paths_config: Optional[str],
+) -> Tuple[List[str], Dict]:
+    """
+    Run discovery, type filtering and whitelist/blacklist filtering.
+
+    Returns:
+        Tuple of (selected_files, metadata_dict); the metadata holds the
+        counts of every stage.
+    """
+    # Step 1: Discover all test files
+    discovered = discover_raw_test_files(test_dir)
+
+    # Step 2: Filter by test type
+    of_type, not_of_type = filter_tests_by_type(discovered, test_type)
+
+    # Step 3: Apply whitelist/blacklist rules
+    config_path, whitelist, blacklist = load_case_path_rules(case_paths_config)
+    kept, dropped = apply_case_path_rules(of_type, whitelist, blacklist)
+
+    # Metadata for reporting
+    metadata = dict(
+        test_dir=str(test_dir),
+        test_type=test_type,
+        total_files=len(discovered),
+        type_selected=len(of_type),
+        type_excluded=len(not_of_type),
+        whitelist_entries=len(whitelist),
+        blacklist_entries=len(blacklist),
+        rules_selected=len(kept),
+        rules_excluded=len(dropped),
+        case_paths_config=config_path,
+    )
+
+    return kept, metadata
+
+
+# ==============================================================================
+# CLI Interface
+# ==============================================================================
+
+
+def parse_args():
+    """Parse the command-line options of the test file discovery script."""
+    option_specs = [
+        (
+            ("--test-dir",),
+            {"type": str, "required": True, "help": "Path to the PyTorch test directory"},
+        ),
+        (
+            ("--test-type",),
+            {
+                "type": str,
+                "choices": ["distributed", "regular"],
+                "default": "regular",
+                "help": "Test type: 'distributed' for distributed tests, 'regular' for other tests",
+            },
+        ),
+        (
+            ("--case-paths-config",),
+            {"type": str, "help": "Path to case_paths_ci.yml for file-level whitelist/blacklist control"},
+        ),
+        (
+            ("--output",),
+            {"type": str, "help": "Output file path for test file list (default: stdout)"},
+        ),
+        (
+            ("--metadata-output",),
+            {"type": str, "help": "Output file path for metadata JSON (optional)"},
+        ),
+        (
+            ("--verbose", "-v"),
+            {"action": "store_true", "help": "Print verbose output including metadata"},
+        ),
+    ]
+
+    parser = argparse.ArgumentParser(
+        description="Discover test files for PyTorch NPU testing",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog=__doc__,
+    )
+    # Options are registered in table order so the --help listing is unchanged.
+    for flags, options in option_specs:
+        parser.add_argument(*flags, **options)
+    return parser.parse_args()
+
+
+def main():
+ """
+ Command-line entry point: discover test files and emit the resulting list.
+
+ The list goes to the --output file when given, otherwise to stdout; the
+ metadata is optionally written as JSON and summarized in verbose mode.
+
+ Raises:
+ FileNotFoundError: If the test directory does not exist.
+ """
+ args = parse_args()
+
+ test_dir = Path(args.test_dir).resolve()
+ if not test_dir.is_dir():
+ raise FileNotFoundError(f"Test directory not found: {test_dir}")
+
+ # Execute discovery
+ selected_files, metadata = discover_test_files(
+ test_dir=test_dir,
+ test_type=args.test_type,
+ case_paths_config=args.case_paths_config,
+ )
+
+ # Output test file list
+ # One path per line; an empty selection produces no output at all.
+ output_content = "\n".join(selected_files) + ("\n" if selected_files else "")
+
+ if args.output:
+ output_path = Path(args.output).resolve()
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+ output_path.write_text(output_content, encoding="utf-8")
+ if args.verbose:
+ print(f"Written {len(selected_files)} test files to: {output_path}")
+ else:
+ sys.stdout.write(output_content)
+
+ # Output metadata
+ if args.metadata_output:
+ metadata_path = Path(args.metadata_output).resolve()
+ metadata_path.parent.mkdir(parents=True, exist_ok=True)
+ metadata_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")
+ if args.verbose:
+ print(f"Written metadata to: {metadata_path}")
+
+ # Verbose summary
+ # NOTE(review): these prints also go to stdout, so without --output they
+ # follow the file list on the same stream - confirm that is intended.
+ if args.verbose:
+ print(f"\nDiscovery Summary:")
+ print(f" Test directory: {test_dir}")
+ print(f" Test type: {args.test_type}")
+ print(f" Total files scanned: {metadata['total_files']}")
+ print(f" After type filter: {metadata['type_selected']} selected, {metadata['type_excluded']} excluded")
+ if args.case_paths_config:
+ print(f" Whitelist entries: {metadata['whitelist_entries']}")
+ print(f" Blacklist entries: {metadata['blacklist_entries']}")
+ print(f" After rules filter: {metadata['rules_selected']} selected, {metadata['rules_excluded']} excluded")
+ print(f" Final selected files: {len(selected_files)}")
+
+
+# Run the discovery only when executed as a script, not when imported.
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/.github/scripts/generate_npu_full_test_report.py b/.github/scripts/generate_npu_full_test_report.py
new file mode 100644
index 0000000000..55822ec30a
--- /dev/null
+++ b/.github/scripts/generate_npu_full_test_report.py
@@ -0,0 +1,674 @@
+#!/usr/bin/env python3
+"""
+Generate a consolidated markdown/json report for the NPU full test workflow.
+"""
+
+import argparse
+import json
+import re
+from collections import Counter
+from pathlib import Path
+from typing import Dict, List, Tuple, Optional
+
+# Import aggregation function from parse_test_results.py
+import parse_test_results
+
+
+def parse_args():
+    """Parse the command-line options of the consolidated report generator."""
+    option_specs = [
+        ("--reports-root", {"required": True, "help": "Root directory containing shard report files"}),
+        ("--output-markdown", {"required": True, "help": "Path to write markdown report"}),
+        ("--output-json", {"required": True, "help": "Path to write JSON report"}),
+        ("--pytorch-version", {"required": True, "help": "PyTorch version string"}),
+        ("--torch-npu-whl", {"required": True, "help": "torch_npu wheel URL"}),
+        ("--patch-count", {"default": "N/A", "help": "Applied patch count"}),
+        ("--shard-matrix-json", {"required": True, "help": "JSON array of requested shard ids"}),
+        ("--docker-image", {"default": "N/A", "help": "Docker image used for test execution"}),
+        ("--runner", {"default": "N/A", "help": "Runner machine type"}),
+        ("--special-reports-root", {"help": "Root directory containing special test report files"}),
+        (
+            "--expected-special-tests-json",
+            {"default": "[]", "help": "JSON array of expected special test names"},
+        ),
+        ("--cases-summary", {"help": "Path to cases_collection_summary.json for file discovery stats"}),
+    ]
+
+    parser = argparse.ArgumentParser(description="Generate consolidated NPU full test report")
+    # Options are registered in table order so the --help listing is unchanged.
+    for flag, options in option_specs:
+        parser.add_argument(flag, **options)
+    return parser.parse_args()
+
+
+def load_json_file(path: Path) -> Dict:
+    """
+    Load a JSON object from ``path`` without ever raising on bad input.
+
+    Unreadable, malformed, truncated or non-object files produce a warning
+    on stdout and an empty dict, so callers can always use ``.get``.
+
+    Args:
+        path: Location of the JSON file.
+
+    Returns:
+        The parsed JSON object, or ``{}`` when it cannot be loaded.
+    """
+    # Read once; the text is reused for the diagnostics of a decode error.
+    try:
+        with open(path, "r", encoding="utf-8") as f:
+            content = f.read()
+    except Exception as e:
+        print(f"Warning: Failed to load {path}: {e}")
+        return {}
+
+    try:
+        payload = json.loads(content)
+    except json.JSONDecodeError as e:
+        print(f"Warning: Invalid JSON in {path}: {e}")
+        print(f" File size: {len(content)} bytes")
+        # Show context around error position to help diagnose truncation
+        start = max(0, e.pos - 100)
+        end = min(len(content), e.pos + 100)
+        print(f" Context around error (pos {e.pos}): ...{content[start:end]}...")
+        return {}
+    except Exception as e:
+        print(f"Warning: Failed to load {path}: {e}")
+        return {}
+
+    # Valid JSON may still be a list or scalar; the declared contract is a dict.
+    if not isinstance(payload, dict):
+        print(f"Warning: Expected a JSON object in {path}, got {type(payload).__name__}")
+        return {}
+    return payload
+
+
+def parse_requested_shards(raw: str) -> List[Tuple[str, int]]:
+    """
+    Parse shard identifiers from JSON array.
+
+    Supports formats:
+    - Integers: [1, 2, 3] -> [("regular", 1), ("regular", 2), ("regular", 3)]
+    - Type-prefixed: ["dist-1", "reg-2"] -> [("distributed", 1), ("regular", 2)]
+    - Numeric strings: ["4"] -> [("regular", 4)]
+
+    Entries that cannot be interpreted (unknown prefix, non-numeric text,
+    booleans, other JSON types) are skipped. Invalid JSON or a non-array
+    value yields an empty list.
+
+    Returns list of unique (shard_type, shard_number) tuples sorted by type
+    then number.
+    """
+    try:
+        value = json.loads(raw)
+    except json.JSONDecodeError:
+        return []
+
+    if not isinstance(value, list):
+        return []
+
+    prefix_to_type = {"dist": "distributed", "reg": "regular"}
+    result = set()
+    for item in value:
+        # bool is a subclass of int: JSON true/false must not become shard 1/0.
+        if isinstance(item, bool):
+            continue
+        if isinstance(item, int):
+            # Plain integer, assume "regular" type
+            result.add(("regular", item))
+            continue
+        if not isinstance(item, str):
+            continue
+
+        if "-" in item:
+            # Parse type-prefixed format: "dist-1", "reg-2"
+            type_prefix, num_str = item.split("-", 1)
+            shard_type = prefix_to_type.get(type_prefix)
+            if shard_type is None:
+                # Unknown prefix, skip
+                continue
+        else:
+            # String without prefix, try to parse as int
+            shard_type, num_str = "regular", item
+
+        try:
+            result.add((shard_type, int(num_str)))
+        except ValueError:
+            continue
+
+    # Sort by type then number
+    return sorted(result)
+
+
+def parse_expected_special_tests(raw: str) -> List[str]:
+    """
+    Parse a JSON array of special test names.
+
+    Returns the unique, non-empty string entries in sorted order; invalid
+    JSON or a non-array value yields an empty list.
+    """
+    try:
+        decoded = json.loads(raw)
+    except json.JSONDecodeError:
+        return []
+
+    if not isinstance(decoded, list):
+        return []
+
+    return sorted({name for name in decoded if isinstance(name, str) and name})
+
+
+def load_text_lines(path: Path) -> List[str]:
+    """Read a UTF-8 text file and return its non-blank lines, stripped of surrounding whitespace."""
+    kept = []
+    with open(path, "r", encoding="utf-8") as handle:
+        for raw_line in handle:
+            text = raw_line.strip()
+            if text:
+                kept.append(text)
+    return kept
+
+
+def get_int_value(payload: Dict, *keys: str) -> int:
+    """
+    Return the first value among ``keys`` that converts to ``int``.
+
+    Keys that are missing or whose value cannot be converted are skipped;
+    0 is returned when no key qualifies.
+    """
+    for name in keys:
+        if name in payload:
+            try:
+                return int(payload[name])
+            except (TypeError, ValueError):
+                pass
+    return 0
+
+
+def discover_shard_files(
+    reports_root: Path,
+) -> Tuple[
+    Dict[Tuple[str, int], Path],  # stats_files
+    Dict[Tuple[str, int], Path],  # info_files
+    Dict[Tuple[str, int], Path],  # cases_files
+]:
+    """
+    Find the stats, info and case-level report files of every shard.
+
+    The directory tree below ``reports_root`` is searched recursively for
+    files named ``shard_{type}-{number}_{suffix}.json`` where ``type`` is
+    ``dist`` or ``reg``, for example:
+        - shard_dist-1_stats.json
+        - shard_reg-1_info.json
+        - shard_dist-1_cases.json (case-level results)
+
+    Returns:
+        Three dicts (stats, info, cases) keyed by (shard_type, shard_number),
+        with shard_type spelled out as "distributed" or "regular".
+    """
+    type_names = {"dist": "distributed", "reg": "regular"}
+
+    def scan(suffix: str) -> Dict[Tuple[str, int], Path]:
+        """Collect the files with the given suffix, keyed by shard identity."""
+        name_pattern = re.compile(r"shard_(dist|reg)-(\d+)_" + suffix)
+        found = {}
+        for candidate in reports_root.rglob(f"shard_*_{suffix}.json"):
+            # The stem is the file name without its extension.
+            matched = name_pattern.match(candidate.stem)
+            if matched is None:
+                continue
+            shard_key = (type_names[matched.group(1)], int(matched.group(2)))
+            found[shard_key] = candidate
+        return found
+
+    return scan("stats"), scan("info"), scan("cases")
+
+
+def build_file_to_shards_map(cases_shards_dir: Path) -> Dict[str, List[str]]:
+    """
+    Map every test file to the shards that contain at least one of its cases.
+
+    Reads all ``*_cases_shard_*.json`` files directly inside
+    ``cases_shards_dir`` (for example distributed_cases_shard_1.json).
+
+    Args:
+        cases_shards_dir: Directory containing the shard JSON files; a falsy
+            or non-existent directory yields an empty mapping.
+
+    Returns:
+        Dict from file path with a leading "test/" removed to the list of
+        shard IDs (e.g., ["dist-1", "reg-2"]), sorted with "dist" shards
+        first and then by shard number.
+    """
+    mapping = {}
+
+    if not cases_shards_dir or not cases_shards_dir.exists():
+        return mapping
+
+    for shard_file in cases_shards_dir.glob("*_cases_shard_*.json"):
+        try:
+            data = load_json_file(shard_file)
+            is_distributed = data.get("test_type", "regular") == "distributed"
+            shard_id = f"{'dist' if is_distributed else 'reg'}-{data.get('shard', 0)}"
+
+            for case in data.get("cases", []):
+                file_path = case.get("file", "")
+                if not file_path:
+                    continue
+                # Keys are stored without the leading "test/" for consistency.
+                key = file_path[5:] if file_path.startswith("test/") else file_path
+                shard_ids = mapping.setdefault(key, [])
+                if shard_id not in shard_ids:
+                    shard_ids.append(shard_id)
+        except Exception as e:
+            print(f"Warning: Failed to parse shard file {shard_file}: {e}")
+            continue
+
+    def shard_order(shard_id: str):
+        """Order "dist" shards before the others, then by shard number."""
+        return (0 if shard_id.startswith("dist") else 1, int(shard_id.split("-")[1]))
+
+    for shard_ids in mapping.values():
+        shard_ids.sort(key=shard_order)
+
+    return mapping
+
+
+def get_shard_status(stats: Dict, present: bool) -> str:
+    """
+    Derive the status label of one shard from its stats.
+
+    Checks are made in priority order: MISSING, TIMEOUT, INCOMPLETE, ERROR,
+    FAILED, NO TESTS and finally PASSED.
+    """
+    if not present:
+        return "MISSING"
+
+    for flag, label in (("timed_out", "TIMEOUT"), ("incomplete", "INCOMPLETE")):
+        if stats.get(flag):
+            return label
+
+    for counter, label in (("errors", "ERROR"), ("failed", "FAILED")):
+        if stats.get(counter, 0) > 0:
+            return label
+
+    return "NO TESTS" if stats.get("total", 0) == 0 else "PASSED"
+
+
+def get_overall_status(status_counts: Counter) -> str:
+    """
+    Reduce the per-shard status counts to one overall result.
+
+    Any missing, timed-out, incomplete, errored or failed shard makes the
+    result FAILED; otherwise it is PASSED when at least one shard passed
+    and NO TESTS when none did.
+    """
+    failing_labels = ("MISSING", "TIMEOUT", "INCOMPLETE", "ERROR", "FAILED")
+    for label in failing_labels:
+        if status_counts[label] > 0:
+            return "FAILED"
+    return "PASSED" if status_counts["PASSED"] > 0 else "NO TESTS"
+
+
+def format_duration(seconds: float) -> str:
+    """
+    Format a duration in seconds as ``"1h 2m 3.4s"``.
+
+    The hour part is omitted when it is zero, and the minute part as well
+    when both are zero.
+    """
+    total = float(seconds)
+    hours = int(total // 3600)
+    minutes = int((total % 3600) // 60)
+    tail = f"{total % 60:.1f}s"
+
+    if hours > 0:
+        return f"{hours}h {minutes}m {tail}"
+    if minutes > 0:
+        return f"{minutes}m {tail}"
+    return tail
+
+
+def sanitize_markdown_cell(value: str) -> str:
+    """
+    Make ``value`` safe to place inside a markdown table cell.
+
+    Pipes are escaped so they do not end the cell, and newlines are replaced
+    with ``<br>`` because a table row must stay on a single line.
+    """
+    return value.replace("|", "\\|").replace("\n", "<br>")
+
+
+def render_table(headers: List[str], rows: List[List[str]]) -> List[str]:
+    """Render a markdown table as a list of lines: header, separator, then one line per row."""
+
+    def as_line(cells: List[str]) -> str:
+        return "| " + " | ".join(cells) + " |"
+
+    separator = ["---"] * len(headers)
+    return [as_line(headers), as_line(separator), *(as_line(row) for row in rows)]
+
+
+def discover_special_test_files(reports_root: Optional[Path]) -> Dict[str, Path]:
+    """
+    Find special test report files and index them by test name.
+
+    Searches ``reports_root`` recursively for ``special_test_*.json`` files
+    and reads the ``name`` field of each one.
+
+    Args:
+        reports_root: Directory to search; ``None`` or a non-existent
+            directory yields an empty result.
+
+    Returns:
+        Dict from test name to report file path. Files that cannot be
+        loaded, are not JSON objects, or lack a non-empty string ``name``
+        are left out.
+    """
+    if reports_root is None or not reports_root.exists():
+        return {}
+
+    special_files = {}
+    for path in reports_root.rglob("special_test_*.json"):
+        try:
+            payload = load_json_file(path)
+        except Exception:
+            continue
+        # Valid JSON can still be a list or scalar, which has no "name" field.
+        if not isinstance(payload, dict):
+            continue
+        name = payload.get("name")
+        if isinstance(name, str) and name:
+            special_files[name] = path
+    return special_files
+
+
+def main():
+ args = parse_args()
+ reports_root = Path(args.reports_root)
+ output_markdown = Path(args.output_markdown)
+ output_json = Path(args.output_json)
+ requested_shards = parse_requested_shards(args.shard_matrix_json)
+ expected_special_tests = parse_expected_special_tests(args.expected_special_tests_json)
+ special_reports_root = Path(args.special_reports_root) if args.special_reports_root else None
+
+ # Load cases collection summary for file discovery stats
+ cases_summary_data = None
+ file_discovery_stats = {
+ "total_files_scanned": 0,
+ "distributed_files": 0,
+ "regular_files": 0,
+ }
+ if args.cases_summary:
+ cases_summary_path = Path(args.cases_summary)
+ if cases_summary_path.exists():
+ cases_summary_data = load_json_file(cases_summary_path)
+ # Extract file discovery stats (正交: total = distributed + regular)
+ if cases_summary_data:
+ file_discovery_stats["total_files_scanned"] = cases_summary_data.get("total_files_scanned", 0)
+ file_discovery_stats["distributed_files"] = cases_summary_data.get("distributed_files", 0)
+ file_discovery_stats["regular_files"] = cases_summary_data.get("regular_files", 0)
+
+ stats_files, info_files, cases_files = discover_shard_files(reports_root)
+ special_test_files = discover_special_test_files(special_reports_root)
+ shard_ids = requested_shards or sorted(set(stats_files) | set(info_files) | set(cases_files))
+
+ # Build file to shards mapping from cases-shards directory
+ cases_shards_dir = Path(args.cases_summary).parent if args.cases_summary else None
+ file_to_shards_map = build_file_to_shards_map(cases_shards_dir)
+
+ status_counts = Counter()
+ totals = {
+ "total": 0,
+ "passed": 0,
+ "failed": 0,
+ "errors": 0,
+ "skipped": 0,
+ "timeout": 0,
+ "duration": 0.0,
+ }
+ shard_rows = []
+ selection_modes = set()
+ cases_results = {} # Store case-level results for each shard
+
+ for shard_type, shard_num in shard_ids:
+ shard_key = (shard_type, shard_num)
+ stats_path = stats_files.get(shard_key)
+ info_path = info_files.get(shard_key)
+ cases_path = cases_files.get(shard_key)
+ stats = load_json_file(stats_path) if stats_path else {}
+ info = load_json_file(info_path) if info_path else {}
+
+ # Load case-level results if available
+ cases_data = load_json_file(cases_path) if cases_path else {}
+ if cases_data:
+ cases_results[shard_key] = cases_data
+ # Override stats with case-level data
+ stats["total"] = cases_data.get("total_cases", 0)
+ stats["passed"] = cases_data.get("passed", 0)
+ stats["failed"] = cases_data.get("failed", 0)
+ stats["errors"] = cases_data.get("errors", 0)
+ stats["skipped"] = cases_data.get("skipped", 0)
+ stats["timeout"] = cases_data.get("timeout", 0)
+ stats["duration"] = cases_data.get("duration", 0.0)
+ # Update totals (正交累加: total = passed + failed + errors + skipped + timeout)
+ totals["total"] += cases_data.get("total_cases", 0)
+ totals["passed"] += cases_data.get("passed", 0)
+ totals["failed"] += cases_data.get("failed", 0)
+ totals["errors"] += cases_data.get("errors", 0)
+ totals["skipped"] += cases_data.get("skipped", 0)
+ totals["timeout"] += cases_data.get("timeout", 0)
+ totals["duration"] += cases_data.get("duration", 0.0)
+
+ present = bool(stats_path or cases_path)
+
+ if info.get("selection_mode"):
+ selection_modes.add(str(info.get("selection_mode")))
+
+ status = get_shard_status(stats, present)
+ status_counts[status] += 1
+
+ # Convert shard_type to display prefix ("distributed" -> "dist", "regular" -> "reg")
+ shard_prefix = "dist" if shard_type == "distributed" else "reg"
+ shard_rows.append(
+ {
+ "shard": f"{shard_prefix}-{shard_num}", # "dist-1" or "reg-1"
+ "shard_type": shard_type,
+ "shard_num": shard_num,
+ "status": status,
+ "total": int(stats.get("total", 0)),
+ "passed": int(stats.get("passed", 0)),
+ "failed": int(stats.get("failed", 0)),
+ "skipped": int(stats.get("skipped", 0)),
+ "errors": int(stats.get("errors", 0)),
+ "timeout": int(stats.get("timeout", 0)),
+ "duration": float(stats.get("duration", 0.0)),
+ }
+ )
+
+ overall_status = get_overall_status(status_counts)
+ whl_name = Path(args.torch_npu_whl).name
+ received_reports = len(stats_files)
+ expected_reports = len(shard_ids)
+ selection_mode_display = ", ".join(sorted(selection_modes)) if selection_modes else "-"
+
+ # Show all shards in the detail table
+ sorted_shards = sorted(shard_rows, key=lambda row: (row["shard_type"], row["shard_num"]))
+ special_test_names = expected_special_tests or sorted(special_test_files)
+ special_test_rows = []
+ special_status_counts = Counter()
+
+ for test_name in special_test_names:
+ payload = load_json_file(special_test_files[test_name]) if test_name in special_test_files else {}
+ status = str(payload.get("status", "MISSING"))
+ special_status_counts[status] += 1
+ special_test_rows.append(
+ {
+ "name": test_name,
+ "group": str(payload.get("group", "-")),
+ "status": status,
+ "duration": float(payload.get("duration", 0.0)),
+ "returncode": payload.get("returncode", "-"),
+ "note": str(payload.get("note", "") or "-"),
+ }
+ )
+
+ if any(row["status"] != "PASSED" for row in special_test_rows):
+ overall_status = "FAILED"
+
+ include_special_tests = bool(special_test_names or special_test_rows)
+
+ # Build Selection row content based on available data
+ if cases_summary_data:
+ # Use file discovery stats from cases_collection_summary.json
+ total_scanned = file_discovery_stats["total_files_scanned"]
+ dist_files = file_discovery_stats["distributed_files"]
+ reg_files = file_discovery_stats["regular_files"]
+ selection_content = (
+ f"扫描发现 {total_scanned} 个测试文件 "
+ f"(distributed: {dist_files}, regular: {reg_files})"
+ )
+ else:
+ # Fallback to original selection mode display
+ selection_content = selection_mode_display
+
+ # Extract planned cases count from cases_collection_summary.json
+ planned_total_cases = 0
+ planned_dist_cases = 0
+ planned_reg_cases = 0
+ if cases_summary_data:
+ planned_total_cases = cases_summary_data.get("total_cases", 0)
+ planned_dist_cases = cases_summary_data.get("distributed", {}).get("cases_summary", {}).get("total_cases", 0)
+ planned_reg_cases = cases_summary_data.get("regular", {}).get("cases_summary", {}).get("total_cases", 0)
+
+ overview_rows = [
+ ["Overall result", overall_status],
+ ["PyTorch", f"`v{args.pytorch_version}`"],
+ ["torch_npu", f"`{whl_name}`"],
+ ["Patches applied", str(args.patch_count)],
+ ["Docker image", f"`{args.docker_image}`"],
+ ["Runner", f"`{args.runner}`"],
+ ["Shards", f"{received_reports} / {expected_reports} reported"],
+ ["Selection", selection_content],
+ [
+ "实际执行用例",
+ (
+ f"{totals['total']} total; {totals['passed']} passed; {totals['failed']} failed; "
+ f"{totals['errors']} errors; {totals['skipped']} skipped; "
+ f"{totals['timeout']} timeout"
+ ),
+ ],
+ ]
+ # Add planned cases count row if available
+ if planned_total_cases > 0:
+ overview_rows.append([
+ "规划用例总数",
+ f"{planned_total_cases} (distributed: {planned_dist_cases}, regular: {planned_reg_cases})",
+ ])
+ overview_rows.append(["Duration", format_duration(totals["duration"])])
+ if include_special_tests:
+ overview_rows.append(["Special tests expected", str(len(special_test_names))])
+
+ markdown_lines = [
+ "# PyTorch NPU Full Test Summary",
+ "",
+ "## Overview",
+ ]
+ markdown_lines.extend(
+ render_table(
+ ["Item", "Value"],
+ overview_rows,
+ )
+ )
+
+ # Add case-level statistics table if available
+ if cases_results:
+ markdown_lines.extend(["", "## 用例级执行统计"])
+ markdown_lines.extend(
+ render_table(
+ ["Shard", "总用例", "通过", "失败", "错误", "跳过", "超时", "Duration"],
+ [
+ [
+ f"{row['shard']}",
+ str(row["total"]),
+ str(row["passed"]),
+ str(row["failed"]),
+ str(row["errors"]),
+ str(row.get("skipped", 0)),
+ str(row.get("timeout", 0)),
+ format_duration(row["duration"]),
+ ]
+ for row in sorted_shards
+ if (row["shard_type"], row["shard_num"]) in cases_results
+ ],
+ )
+ )
+
+ # Add file-level statistics table
+ file_stats = parse_test_results.aggregate_all_cases_by_file(cases_results)
+
+ if file_stats:
+ # Sort files by total cases descending
+ sorted_files = sorted(
+ file_stats.values(),
+ key=lambda x: (-x["total"], x["file"])
+ )
+
+ markdown_lines.extend(["", "## 测试文件结果汇总"])
+
+ file_rows = []
+ for fs in sorted_files: # Show all files
+ failed_total = fs["failed"] + fs["errors"] + fs["timeout"]
+ fail_rate = f"{(failed_total / fs['total'] * 100):.1f}%" if fs["total"] > 0 else "0%"
+ # Get shard info for this file
+ file_path = fs["file"]
+ # Normalize file path for lookup (remove leading "test/")
+ lookup_path = file_path
+ if lookup_path.startswith("test/"):
+ lookup_path = lookup_path[5:]
+ shards_for_file = file_to_shards_map.get(lookup_path, [])
+ shard_info = ", ".join(shards_for_file) if shards_for_file else "-"
+ file_rows.append([
+ sanitize_markdown_cell(fs["file"]),
+ shard_info,
+ str(fs["total"]),
+ str(fs["passed"]),
+ str(fs["failed"]),
+ str(fs["errors"]),
+ str(fs["skipped"]),
+ str(fs["timeout"]),
+ fail_rate,
+ ])
+
+ markdown_lines.extend(
+ render_table(
+ ["测试文件", "分片", "总用例", "通过", "失败", "错误", "跳过", "超时", "失败率"],
+ file_rows,
+ )
+ )
+
+ if include_special_tests:
+ markdown_lines.extend(["", "## Special Test Results"])
+ markdown_lines.extend(
+ render_table(
+ ["Test", "Group", "Status", "Duration", "Return Code", "Note"],
+ [
+ [
+ row["name"],
+ row["group"],
+ row["status"],
+ format_duration(row["duration"]),
+ str(row["returncode"]),
+ sanitize_markdown_cell(row["note"]),
+ ]
+ for row in special_test_rows
+ ] or [["-", "-", "-", "0.0s", "-", "-"]],
+ )
+ )
+
+ report_json = {
+ "overall_status": overall_status,
+ "requested_shards": shard_ids,
+ "reports_collected": received_reports,
+ "patch_count": args.patch_count,
+ "pytorch_version": args.pytorch_version,
+ "torch_npu_whl": whl_name,
+ "docker_image": args.docker_image,
+ "runner": args.runner,
+ "status_counts": dict(status_counts),
+ "totals": totals,
+ "file_discovery_stats": file_discovery_stats,
+ "planned_cases": {
+ "total": planned_total_cases,
+ "distributed": planned_dist_cases,
+ "regular": planned_reg_cases,
+ },
+ "shards": shard_rows,
+ }
+
+ # Add full cases summary if available
+ if cases_summary_data:
+ report_json["cases_collection_summary"] = cases_summary_data
+
+ # Add case-level results if available
+ if cases_results:
+ report_json["cases_results"] = {
+ "shards": {
+ f"{shard_type}-{shard_num}": data
+ for (shard_type, shard_num), data in cases_results.items()
+ },
+ }
+
+ # Add file-level aggregation
+ file_stats = parse_test_results.aggregate_all_cases_by_file(cases_results)
+ # Add shard info to file stats
+ file_stats_with_shards = {}
+ for file_path, stats in file_stats.items():
+ # Normalize file path for lookup
+ lookup_path = file_path
+ if lookup_path.startswith("test/"):
+ lookup_path = lookup_path[5:]
+ shards_for_file = file_to_shards_map.get(lookup_path, [])
+ stats["shards"] = shards_for_file
+ file_stats_with_shards[file_path] = stats
+ report_json["file_level_stats"] = dict(sorted(
+ file_stats_with_shards.items(),
+ key=lambda x: (-x[1]["total"], x[0])
+ ))
+
+ # Add list of files with failures
+ failed_files = parse_test_results.get_files_with_failures(file_stats)
+ report_json["files_with_failures"] = failed_files
+
+ if include_special_tests:
+ report_json["special_tests"] = {
+ "expected": special_test_names,
+ "status_counts": dict(special_status_counts),
+ "results": special_test_rows,
+ }
+
+ output_markdown.write_text("\n".join(markdown_lines) + "\n", encoding="utf-8")
+ output_json.write_text(json.dumps(report_json, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
+
+ print(f"Generated markdown report: {output_markdown}")
+ print(f"Generated json report: {output_json}")
+
+
+if __name__ == "__main__":  # true only when this file is executed as a script
+    main()
diff --git a/.github/scripts/parse_test_results.py b/.github/scripts/parse_test_results.py
new file mode 100644
index 0000000000..46768b56d2
--- /dev/null
+++ b/.github/scripts/parse_test_results.py
@@ -0,0 +1,796 @@
+#!/usr/bin/env python3
+"""
+Parse test results from JUnit XML files and pytest logs.
+
+This script provides utilities for:
+ - Parsing JUnit XML reports
+ - Aggregating test statistics
+ - Analyzing pytest log files
+ - Generating result reports (JSON, text)
+
+Usage as module:
+ from parse_test_results import (
+ parse_junit_xml,
+ aggregate_junit_stats,
+ analyze_pytest_log,
+ finalize_stats,
+ save_stats_file,
+ save_info_file,
+ print_stats_summary,
+ )
+
+Usage as CLI:
+ python parse_test_results.py \
+ --report-dir test-reports \
+ --shard 1 \
+ --shard-type distributed \
+        --output-stats parsed-results/shard_dist-1_stats.json
+"""
+
+import argparse
+import json
+import os
+import re
+import signal
+import sys
+import xml.etree.ElementTree as ET
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+
+# ==============================================================================
+# JUnit XML Parsing
+# ==============================================================================
+
+
+def parse_junit_xml(xml_file: str) -> Dict:
+    """
+    Read one JUnit XML report and sum the counters of its <testsuite> elements.
+
+    Args:
+        xml_file: Location of the JUnit XML report on disk
+
+    Returns:
+        Dict with keys: total, passed, failed, skipped, errors, duration.
+        All zeros if the file is absent; on a parse failure a warning is
+        printed and the counters accumulated so far are returned.
+    """
+    # (result key, XML attribute, converter), read in this order per suite
+    attribute_map = (
+        ("total", "tests", int),
+        ("failed", "failures", int),
+        ("skipped", "skipped", int),
+        ("errors", "errors", int),
+        ("duration", "time", float),
+    )
+    counters = dict.fromkeys(("total", "passed", "failed", "skipped", "errors"), 0)
+    counters["duration"] = 0.0
+
+    # A missing report simply yields the zeroed counters
+    if os.path.exists(xml_file):
+        try:
+            for suite in ET.parse(xml_file).getroot().iter("testsuite"):
+                for key, attribute, convert in attribute_map:
+                    counters[key] += convert(suite.get(attribute, 0))
+            # "passed" is not read from the XML; it is derived from the other counters
+            unsuccessful = counters["failed"] + counters["skipped"] + counters["errors"]
+            counters["passed"] = counters["total"] - unsuccessful
+        except Exception as exc:
+            print(f"Warning: Failed to parse XML report {xml_file}: {exc}")
+
+    return counters
+
+
+def aggregate_junit_stats(report_roots: List[Path], pattern: str = "*.xml") -> Dict:
+    """
+    Sum JUnit statistics over every XML report found under the given roots.
+
+    Each root is searched recursively; a report reachable through more than
+    one root (same resolved path) is only counted once.
+
+    Args:
+        report_roots: Directories to search; non-existent ones are skipped
+        pattern: Glob pattern selecting the XML files (default: "*.xml")
+
+    Returns:
+        Dict with summed total, passed, failed, skipped, errors, duration,
+        plus xml_files_count (number of distinct reports parsed)
+    """
+    counter_keys = ("total", "passed", "failed", "skipped", "errors")
+    aggregate = {name: 0 for name in counter_keys}
+    aggregate["duration"] = 0.0
+
+    # Identities of the reports already folded into the aggregate
+    visited = set()
+    existing_roots = (root for root in report_roots if root.exists())
+    for root in existing_roots:
+        for report in root.rglob(pattern):
+            # Resolve to an absolute path for de-duplication; fall back to
+            # the path as given when resolution fails
+            try:
+                identity = str(report.resolve())
+            except OSError:
+                identity = str(report)
+
+            if identity not in visited:
+                visited.add(identity)
+                parsed = parse_junit_xml(str(report))
+                for name in aggregate:
+                    aggregate[name] += parsed[name]
+
+    aggregate["xml_files_count"] = len(visited)
+    return aggregate
+
+
+def parse_shard_xml_files(report_dir: Path, shard: int, shard_type: str = "regular") -> Dict:
+    """
+    Collect JUnit statistics from the XML reports belonging to one shard.
+
+    Reports are matched by name: shard_<prefix>-<shard>_pytest*.xml.
+
+    Args:
+        report_dir: Directory holding the shard's XML reports
+        shard: Shard number
+        shard_type: "distributed" or "regular"
+
+    Returns:
+        Aggregated stats plus junit_generated / junit_xml_files; all zeros
+        with junit_generated=False when report_dir has no matching report
+    """
+    shard_label = f"{get_shard_type_prefix(shard_type)}-{shard}"
+    name_glob = f"shard_{shard_label}_pytest*.xml"
+
+    # Only direct children of report_dir decide whether reports exist
+    matching_reports = sorted(report_dir.glob(name_glob))
+    if matching_reports:
+        shard_stats = aggregate_junit_stats([report_dir], name_glob)
+        shard_stats["junit_generated"] = True
+        shard_stats["junit_xml_files"] = len(matching_reports)
+        return shard_stats
+
+    # No report at all: hand back a zeroed record flagged as not generated
+    empty = dict.fromkeys(("total", "passed", "failed", "skipped", "errors"), 0)
+    empty["duration"] = 0.0
+    empty["junit_generated"] = False
+    empty["junit_xml_files"] = 0
+    return empty
+
+
+# ==============================================================================
+# Log Analysis
+# ==============================================================================
+
+
+def analyze_pytest_log(log_file: Path, returncode: int) -> Dict:
+    """
+    Scan a pytest log for collection problems.
+
+    Args:
+        log_file: pytest log to inspect
+        returncode: Exit code of the pytest process
+
+    Returns:
+        Dict with zero_item_test_files, startup_failures, import_failures and
+        test_failures (test_failures is always left at 0 by this function).
+        All zeros when the log is missing or unreadable.
+    """
+    def count_matching_lines(line_pattern: str, text: str) -> int:
+        """Count the matches of a line-anchored pattern in the log text."""
+        return len(re.findall(line_pattern, text, flags=re.MULTILINE))
+
+    result = dict.fromkeys(
+        ("zero_item_test_files", "startup_failures", "import_failures", "test_failures"), 0
+    )
+    if not log_file.exists():
+        return result
+
+    try:
+        text = log_file.read_text(encoding="utf-8", errors="replace")
+    except OSError:
+        return result
+
+    # Exit code 5 or either of these phrases marks a "no tests collected" run
+    nothing_ran = any(phrase in text for phrase in ("collected 0 items", "no tests ran"))
+    if returncode == 5 or nothing_ran:
+        result["zero_item_test_files"] = 1
+
+    import_errors = count_matching_lines(r"^ImportError while importing test module", text)
+    collection_errors = count_matching_lines(r"^ERROR collecting ", text)
+
+    # Startup failures are collection errors excluding import errors, never below 0
+    result["import_failures"] = import_errors
+    result["startup_failures"] = max(collection_errors - import_errors, 0)
+    return result
+
+
+# ==============================================================================
+# Stats Processing
+# ==============================================================================
+
+
+def create_empty_stats() -> Dict:
+    """Return a fresh, zero-initialised statistics dict."""
+    blank = {}
+    # Test-case counters
+    for name in ("total", "passed", "failed", "skipped", "errors"):
+        blank[name] = 0
+    blank["duration"] = 0.0
+    # JUnit report bookkeeping
+    blank["junit_generated"] = False
+    blank["junit_xml_files"] = 0
+    # Log-derived failure metrics
+    log_metric_names = ("zero_item_test_files", "startup_failures", "import_failures", "test_failures")
+    for name in log_metric_names:
+        blank[name] = 0
+
+    return blank
+
+
+def create_shard_info(shard: int, num_shards: int, timestamp: str) -> Dict:
+    """Build the default info record for a shard; every counter starts at zero."""
+    zeroed_counters = (
+        "total_files",
+        "selected_test_files",
+        "shard_files",
+        "path_filtered_out_files",
+        "excluded_test_files",
+        "disabled_count",
+        "whitelist_entries",
+        "blacklist_entries",
+    )
+    log_counters = ("zero_item_test_files", "startup_failures", "import_failures", "test_failures")
+
+    # Keys are inserted in a fixed order: identity, selection counters, JUnit, log metrics
+    record = {"shard": shard, "num_shards": num_shards, "selection_mode": "pytest_direct"}
+    record.update(dict.fromkeys(zeroed_counters, 0))
+    record["junit_generated"] = False
+    record["junit_xml_files"] = 0
+    record.update(dict.fromkeys(log_counters, 0))
+    record["timestamp"] = timestamp
+    return record
+
+
+def finalize_stats(base_stats: Dict, returncode: int, duration: float, error_message: str = "") -> Dict:
+    """
+    Produce a copy of base_stats completed with exit-code and timing details.
+
+    The duration becomes the larger of the recorded and the measured value.
+    Non-zero codes also record crash, incomplete and error-message details.
+
+    Args:
+        base_stats: Statistics gathered so far (not modified)
+        returncode: Process return code
+        duration: Measured execution time in seconds
+        error_message: Optional message, stored only for non-zero return codes
+
+    Returns:
+        New stats dict
+    """
+    finalized = {**base_stats}
+    recorded_duration = float(finalized.get("duration", 0.0))
+    finalized["duration"] = max(recorded_duration, duration)
+    finalized["returncode"] = returncode if returncode != 0 else 0
+    if returncode == 0:
+        return finalized
+
+    if returncode < 0:
+        # A negative code is treated as "terminated by signal <abs(code)>"
+        signal_number = -returncode
+        try:
+            crash_signal = signal.Signals(signal_number).name
+        except ValueError:
+            crash_signal = f"SIG{signal_number}"
+        finalized.update(crashed=True, crash_signal=crash_signal)
+
+    if finalized.get("total", 0) == 0:
+        # Nothing ran: report at least one error and flag the result incomplete
+        finalized["errors"] = max(finalized.get("errors", 0), 1)
+        finalized["incomplete"] = True
+
+    if error_message:
+        finalized["error_message"] = error_message
+    return finalized
+
+
+def get_shard_status(stats: Dict, has_xml: bool) -> str:
+    """
+    Map a shard's statistics to a single status label.
+
+    Conditions are checked in priority order; the first match wins.
+
+    Args:
+        stats: Statistics dict
+        has_xml: Whether XML files were generated
+
+    Returns:
+        One of: MISSING, CRASHED, TIMEOUT, INCOMPLETE, ERROR, FAILED,
+        NO_TESTS, PASSED
+    """
+    if not has_xml:
+        return "MISSING"
+
+    # Boolean markers are checked before any counter
+    flag_labels = (("crashed", "CRASHED"), ("timed_out", "TIMEOUT"), ("incomplete", "INCOMPLETE"))
+    for flag, label in flag_labels:
+        if stats.get(flag):
+            return label
+
+    # Then the first non-zero problem counter decides
+    counter_labels = (("errors", "ERROR"), ("failed", "FAILED"))
+    for counter, label in counter_labels:
+        if stats.get(counter, 0) > 0:
+            return label
+
+    # No problems recorded: distinguish an empty run from a clean pass
+    if stats.get("total", 0) == 0:
+        return "NO_TESTS"
+    return "PASSED"
+
+
+# ==============================================================================
+# Utility Functions
+# ==============================================================================
+
+
+def get_shard_type_prefix(shard_type: str) -> str:
+    """Return the file-name prefix: "dist" for "distributed", otherwise "reg"."""
+    return ("reg", "dist")[shard_type == "distributed"]
+
+
+def get_shard_log_file(report_dir: Path, shard: int, shard_type: str = "regular") -> Path:
+    """Return report_dir/test_shard_<prefix>-<shard>.log (nothing is created on disk)."""
+    short_type = get_shard_type_prefix(shard_type)
+    return report_dir.joinpath(f"test_shard_{short_type}-{shard}.log")
+
+
+def load_disabled_testcases_count(json_file: str) -> int:
+    """Return the number of top-level entries in a disabled-testcases JSON file.
+
+    Gives 0 for an empty/missing path or when the JSON root is not a dict/list.
+    """
+    if json_file and os.path.exists(json_file):
+        with open(json_file, encoding="utf-8") as handle:
+            payload = json.load(handle)
+        if isinstance(payload, (dict, list)):
+            return len(payload)
+    return 0
+
+
+# ==============================================================================
+# File Save Functions
+# ==============================================================================
+
+
+def save_stats_file(report_dir: str, shard: int, stats: Dict, shard_type: str = "regular") -> str:
+    """Write stats as JSON to <report_dir>/shard_<prefix>-<shard>_stats.json; return the path."""
+    file_name = f"shard_{get_shard_type_prefix(shard_type)}-{shard}_stats.json"
+    os.makedirs(report_dir, exist_ok=True)
+    destination = os.path.join(report_dir, file_name)
+    with open(destination, "w", encoding="utf-8") as handle:
+        json.dump(stats, handle, indent=2)
+    return destination
+
+
+def save_info_file(report_dir: str, shard: int, info: Dict, shard_type: str = "regular") -> str:
+    """Write info as JSON to <report_dir>/shard_<prefix>-<shard>_info.json; return the path."""
+    file_name = f"shard_{get_shard_type_prefix(shard_type)}-{shard}_info.json"
+    os.makedirs(report_dir, exist_ok=True)
+    destination = os.path.join(report_dir, file_name)
+    with open(destination, "w", encoding="utf-8") as handle:
+        json.dump(info, handle, indent=2)
+    return destination
+
+
+def save_test_plan_file(report_dir: str, shard: int, planned_tests: List[str], shard_type: str = "regular") -> str:
+    """Write one planned test target per line to shard_<prefix>-<shard>_planned_test_files.txt."""
+    file_name = f"shard_{get_shard_type_prefix(shard_type)}-{shard}_planned_test_files.txt"
+    os.makedirs(report_dir, exist_ok=True)
+    destination = os.path.join(report_dir, file_name)
+    with open(destination, "w", encoding="utf-8") as handle:
+        # Every entry is newline-terminated, including the last one
+        handle.writelines(f"{target}\n" for target in planned_tests)
+    return destination
+
+
+def save_excluded_test_files_file(report_dir: str, shard: int, excluded_files: List[str], shard_type: str = "regular") -> str:
+    """Write one excluded test file per line to shard_<prefix>-<shard>_excluded_test_files.txt."""
+    file_name = f"shard_{get_shard_type_prefix(shard_type)}-{shard}_excluded_test_files.txt"
+    os.makedirs(report_dir, exist_ok=True)
+    destination = os.path.join(report_dir, file_name)
+    with open(destination, "w", encoding="utf-8") as handle:
+        # Every entry is newline-terminated, including the last one
+        handle.writelines(f"{target}\n" for target in excluded_files)
+    return destination
+
+
+def save_missing_files_file(report_dir: str, shard: int, missing_files: List[str], shard_type: str = "regular") -> str:
+    """Write one missing file (crashed, no XML) per line to shard_<prefix>-<shard>_missing_files.txt."""
+    file_name = f"shard_{get_shard_type_prefix(shard_type)}-{shard}_missing_files.txt"
+    os.makedirs(report_dir, exist_ok=True)
+    destination = os.path.join(report_dir, file_name)
+    with open(destination, "w", encoding="utf-8") as handle:
+        # Every entry is newline-terminated, including the last one
+        handle.writelines(f"{file_path}\n" for file_path in missing_files)
+    return destination
+
+
+def save_cases_file(report_dir: str, shard: int, cases_data: Dict, shard_type: str = "regular") -> str:
+    """Write case-level results (non-ASCII kept as-is) to shard_<prefix>-<shard>_cases.json."""
+    file_name = f"shard_{get_shard_type_prefix(shard_type)}-{shard}_cases.json"
+    os.makedirs(report_dir, exist_ok=True)
+    destination = os.path.join(report_dir, file_name)
+    with open(destination, "w", encoding="utf-8") as handle:
+        json.dump(cases_data, handle, indent=2, ensure_ascii=False)
+    return destination
+
+
+def load_cases_file(report_dir: Path, shard: int, shard_type: str = "regular") -> Dict:
+    """Read shard_<prefix>-<shard>_cases.json; return {} if it is missing or cannot be loaded."""
+    source = report_dir / f"shard_{get_shard_type_prefix(shard_type)}-{shard}_cases.json"
+    if source.exists():
+        try:
+            with open(source, encoding="utf-8") as handle:
+                loaded = json.load(handle)
+        except Exception as e:
+            print(f"Warning: Failed to load cases file {source}: {e}")
+        else:
+            return loaded
+    return {}
+
+
+# ==============================================================================
+# Case Aggregation by File
+# ==============================================================================
+
+
+def aggregate_cases_by_file(cases_list: List[Dict]) -> Dict[str, Dict]:
+    """
+    Aggregate case results by test file.
+
+    Cases are grouped by their source file and counted per status. Cases
+    that failed, errored or timed out are additionally collected with
+    their details for reporting.
+
+    The source file is taken from the case's "file" entry. When that entry
+    is missing or empty, the part of "nodeid" before the first "::" is used
+    instead, and "unknown" when the nodeid has no "::" either.
+
+    Args:
+        cases_list: List of case result dicts with "nodeid", "file", "status"
+            and optionally "duration" / "message" keys. A missing status is
+            treated as "error", a missing duration as 0.0.
+
+    Returns:
+        Dict mapping test file path -> aggregated stats
+        Each entry contains:
+            - file: test file path
+            - total: total cases in file (whatever their status)
+            - passed, failed, errors, timeout, skipped: counts
+            - failed_cases: list of failed/error/timeout cases with details
+            - duration: total execution time for file
+    """
+    # Status value -> counter key, for the statuses that are tallied
+    counter_for_status = {
+        "passed": "passed",
+        "failed": "failed",
+        "error": "errors",
+        "timeout": "timeout",
+        "skipped": "skipped",
+    }
+    # Statuses whose cases are listed in failed_cases
+    reported_statuses = ("failed", "error", "timeout")
+
+    file_stats = {}
+
+    for case in cases_list:
+        nodeid = case.get("nodeid")
+        # "or" treats a missing "file" like an empty one, so the nodeid
+        # fallback below applies in both situations
+        test_file = case.get("file") or ""
+        if not test_file:
+            node_text = nodeid or ""
+            test_file = node_text.split("::")[0] if "::" in node_text else "unknown"
+
+        status = case.get("status", "error")
+        duration = case.get("duration", 0.0)
+
+        stats = file_stats.setdefault(
+            test_file,
+            {
+                "file": test_file,
+                "total": 0,
+                "passed": 0,
+                "failed": 0,
+                "errors": 0,
+                "timeout": 0,
+                "skipped": 0,
+                "failed_cases": [],
+                "duration": 0.0,
+            },
+        )
+        stats["total"] += 1
+        stats["duration"] += duration
+
+        counter = counter_for_status.get(status) if isinstance(status, str) else None
+        if counter is not None:
+            stats[counter] += 1
+        if status in reported_statuses:
+            # Timeouts carry a synthesized message; other failures keep theirs
+            if status == "timeout":
+                message = f"Timeout after {duration}s"
+            else:
+                message = case.get("message", "")
+            stats["failed_cases"].append(
+                {"nodeid": nodeid, "status": status, "message": message, "duration": duration}
+            )
+
+    return file_stats
+
+
+def aggregate_all_cases_by_file(cases_results: Dict) -> Dict[str, Dict]:
+    """
+    Aggregate all cases from multiple shards by test file.
+
+    Each shard's cases are grouped with aggregate_cases_by_file() and the
+    per-file results are then summed across shards.
+
+    Args:
+        cases_results: Dict mapping shard_key -> cases_data (from
+            shard_*_cases.json); only the "cases" list of each value is read
+
+    Returns:
+        Dict mapping test file -> aggregated stats across all shards, with
+        each file's failed_cases sorted by nodeid
+    """
+    summed_keys = ("total", "passed", "failed", "errors", "timeout", "skipped", "duration")
+    all_file_stats = {}
+
+    # The shard keys are irrelevant here; only the per-shard payloads matter
+    for cases_data in cases_results.values():
+        file_stats = aggregate_cases_by_file(cases_data.get("cases", []))
+
+        for test_file, stats in file_stats.items():
+            existing = all_file_stats.setdefault(
+                test_file,
+                {
+                    "file": test_file,
+                    "total": 0,
+                    "passed": 0,
+                    "failed": 0,
+                    "errors": 0,
+                    "timeout": 0,
+                    "skipped": 0,
+                    "failed_cases": [],
+                    "duration": 0.0,
+                },
+            )
+            for key in summed_keys:
+                existing[key] += stats[key]
+            existing["failed_cases"].extend(stats["failed_cases"])
+
+    # A case recorded without a nodeid is stored with nodeid=None; map it to
+    # "" so that sorting never compares None with a string
+    for merged in all_file_stats.values():
+        merged["failed_cases"].sort(key=lambda entry: entry.get("nodeid") or "")
+
+    return all_file_stats
+
+
+def get_files_with_failures(file_stats: Dict[str, Dict]) -> List[Dict]:
+    """
+    Select the test files that recorded at least one failure, error or timeout.
+
+    Args:
+        file_stats: Dict from aggregate_all_cases_by_file()
+
+    Returns:
+        The matching per-file stats dicts (the same objects, not copies),
+        ordered by file name
+    """
+    problem_counters = ("failed", "errors", "timeout")
+    affected = [
+        entry for entry in file_stats.values()
+        if any(entry[name] > 0 for name in problem_counters)
+    ]
+    return sorted(affected, key=lambda entry: entry["file"])
+
+
+# ==============================================================================
+# Summary Printing
+# ==============================================================================
+
+
+def print_stats_summary(shard: int, stats: Dict, shard_type: str = "regular") -> None:
+    """Write a framed, human-readable summary of a shard's statistics to stdout."""
+    rule = "=" * 60
+    shard_label = f"{get_shard_type_prefix(shard_type)}-{shard}"
+    print(f"\n{rule}")
+    print(f"Test Results for Shard {shard_label}")
+    print(rule)
+    counter_rows = (("Total:    ", "total"), ("Passed:   ", "passed"), ("Failed:   ", "failed"),
+                    ("Skipped:  ", "skipped"), ("Errors:   ", "errors"))
+    for caption, key in counter_rows:
+        print(f"{caption}{stats[key]}")
+    print(f"Duration: {stats['duration']:.2f}s")
+    if stats.get("missing_files_count"):
+        print(f"Missing files: {stats['missing_files_count']}")
+    if stats.get("crashed"):
+        print(f"Crash signal: {stats.get('crash_signal', 'unknown')}")
+    print(rule)
+
+
+def create_result_summary(stats: Dict, shard: int, shard_type: str) -> str:
+    """
+    Render a multi-line text summary (status, counts, duration) for one shard.
+    The missing-files line is only added when that count is non-zero.
+    """
+    shard_label = f"{get_shard_type_prefix(shard_type)}-{shard}"
+    status = get_shard_status(stats, stats.get("junit_generated", False))
+
+    summary = [f"Shard {shard_label} Results:", f"  Status: {status}"]
+    counter_captions = (("Total", "total"), ("Passed", "passed"), ("Failed", "failed"), ("Errors", "errors"))
+    for caption, key in counter_captions:
+        summary.append(f"  {caption}: {stats.get(key, 0)}")
+    summary.append(f"  Duration: {stats.get('duration', 0.0):.2f}s")
+
+    missing_count = stats.get("missing_files_count")
+    if missing_count:
+        summary.append(f"  Missing: {missing_count}")
+
+    return "\n".join(summary)
+
+
+# ==============================================================================
+# High-Level Parsing Functions
+# ==============================================================================
+
+
+def parse_shard_results(
+    report_dir: Path,
+    shard: int,
+    shard_type: str,
+    returncode: int,
+    duration: float,
+    missing_files: Optional[List[str]] = None,
+) -> Tuple[Dict, Dict]:
+    """
+    Parse all results for a shard and return (stats, log_metrics).
+
+    This is the main entry point for result parsing.
+
+    A return code of 5 with zero parsed tests (nothing collected) counts as
+    success: it is finalized as return code 0, so the shard is not flagged
+    incomplete and no synthetic error is added.
+
+    Args:
+        report_dir: Directory containing test reports
+        shard: Shard number
+        shard_type: "distributed" or "regular"
+        returncode: pytest process return code
+        duration: Execution duration
+        missing_files: List of files that crashed (no XML generated)
+
+    Returns:
+        Tuple of (stats_dict, log_metrics_dict)
+    """
+    missing_files = missing_files or []
+
+    # Parse JUnit XML files and attach per-file isolation metadata
+    stats = parse_shard_xml_files(report_dir, shard, shard_type)
+    stats["per_file_isolation"] = True
+    stats["missing_files_count"] = len(missing_files)
+
+    # The log analysis needs the raw return code to detect empty runs
+    log_file = get_shard_log_file(report_dir, shard, shard_type)
+    log_metrics = analyze_pytest_log(log_file, returncode)
+
+    # Normalize "no tests collected" to success BEFORE finalizing; doing it
+    # afterwards would leave incomplete=True and errors>=1 on a "successful" shard
+    no_tests_collected = returncode == 5 and stats.get("total", 0) == 0
+    stats = finalize_stats(stats, 0 if no_tests_collected else returncode, duration)
+
+    # Merge log metrics
+    log_metrics["test_failures"] = stats.get("failed", 0) + stats.get("errors", 0)
+    log_metrics["missing_files_count"] = len(missing_files)
+    stats.update(log_metrics)
+
+    return stats, log_metrics
+
+
+def generate_shard_reports(
+    report_dir: str,
+    shard: int,
+    shard_type: str,
+    stats: Dict,
+    info: Dict,
+    missing_files: List[str] = None,
+) -> Dict[str, str]:
+    """
+    Write the report files of one shard to disk.
+
+    The stats and info JSON files are always written; the missing-files
+    list only when missing_files is non-empty.
+
+    Args:
+        report_dir: Output directory
+        shard: Shard number
+        shard_type: "distributed" or "regular"
+        stats: Statistics dict
+        info: Info dict
+        missing_files: List of missing/crashed files
+
+    Returns:
+        Dict mapping report type ("stats", "info", optionally "missing") to file path
+    """
+    written = {
+        "stats": save_stats_file(report_dir, shard, stats, shard_type),
+        "info": save_info_file(report_dir, shard, info, shard_type),
+    }
+
+    if missing_files:
+        missing_path = save_missing_files_file(report_dir, shard, missing_files, shard_type)
+        written["missing"] = missing_path
+
+    return written
+
+
+# ==============================================================================
+# CLI Interface
+# ==============================================================================
+
+
+def parse_args():
+    """Define the command-line interface and parse the process arguments."""
+    cli = argparse.ArgumentParser(description="Parse test results from JUnit XML files")
+    add = cli.add_argument
+    add("--report-dir", type=str, required=True, help="Directory containing test reports")
+    add("--shard", type=int, required=True, help="Shard number")
+    shard_type_options = dict(
+        type=str, choices=["distributed", "regular"],
+        default="regular", help="Shard type",
+    )
+    add("--shard-type", **shard_type_options)
+    add("--returncode", type=int, default=0, help="pytest return code")
+    add("--duration", type=float, default=0.0, help="Execution duration in seconds")
+    add("--output-stats", type=str, help="Output file for stats JSON")
+    add("--verbose", "-v", action="store_true", help="Verbose output")
+
+    return cli.parse_args()
+
+
+def main():
+    """
+    CLI entry point: parse one shard's results, optionally save them, and
+    exit with the return code stored in the resulting stats.
+    """
+    args = parse_args()
+
+    report_dir = Path(args.report_dir).resolve()
+    if not report_dir.exists():
+        print(f"Error: Report directory not found: {report_dir}")
+        sys.exit(1)
+
+    # missing_files is not exposed on the command line, so its default applies
+    stats, _log_metrics = parse_shard_results(
+        report_dir,
+        args.shard,
+        args.shard_type,
+        args.returncode,
+        args.duration,
+    )
+
+    if args.output_stats:
+        destination = Path(args.output_stats)
+        destination.parent.mkdir(parents=True, exist_ok=True)
+        destination.write_text(json.dumps(stats, indent=2), encoding="utf-8")
+        print(f"Stats saved to: {destination}")
+
+    if args.verbose:
+        print(json.dumps(stats, indent=2))
+
+    print_stats_summary(args.shard, stats, args.shard_type)
+    sys.exit(stats.get("returncode", 0))
+
+
+if __name__ == "__main__":  # true only when this file is executed as a script
+    main()
\ No newline at end of file
diff --git a/.github/scripts/run_npu_test_shard.py b/.github/scripts/run_npu_test_shard.py
new file mode 100644
index 0000000000..d85bed23c7
--- /dev/null
+++ b/.github/scripts/run_npu_test_shard.py
@@ -0,0 +1,1294 @@
+#!/usr/bin/env python3
+"""
+Run PyTorch NPU tests via per-case isolation pytest execution.
+
+This script executes pre-collected test cases or specified test files
+with per-case subprocess isolation for crash safety.
+
+Execution modes:
+ - Pre-collected cases (--cases-json): Execute cases from JSON file
+ - Custom test files (--test-files): Execute specified test files
+
+Each case runs in its own pytest subprocess for isolation:
+ - NPU kernel crashes won't cascade to other cases
+ - Results recorded in cases.json file
+
+Test types:
+ - distributed: Serial execution (one case at a time)
+ - regular: Concurrent execution (multiple workers)
+
+Usage:
+ # Pre-collected cases mode (primary usage):
+ python run_npu_test_shard.py \
+ --cases-json distributed_cases_shard_1.json \
+ --test-dir /path/to/pytorch/test \
+ --disabled-testcases /path/to/disabled_testcases.json \
+ --report-dir test-reports \
+ --timeout 1200 \
+ --max-workers 64 \
+ --verbose
+
+ # Custom test files mode:
+ python run_npu_test_shard.py \
+ --test-files test_meta.py,test_nn.py \
+ --test-dir /path/to/pytorch/test \
+ --disabled-testcases /path/to/disabled_testcases.json \
+ --report-dir test-reports \
+ --timeout 1200 \
+ --max-workers 4 \
+ --verbose
+
+Note: Shard discovery mode (--shard/--num-shards/--test-type) has been removed.
+ Use collect_all_cases.py for case discovery and sharding.
+"""
+
+import argparse
+import dataclasses
+import importlib.util
+import json
+import os
+import subprocess
+import sys
+import threading
+import xml.etree.ElementTree as ET
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from datetime import datetime
+from pathlib import Path
+from queue import Queue
+from time import monotonic
+from typing import Dict, List, Tuple
+
+import collect_all_cases
+
+
+# ==============================================================================
+# Import Result Parser Module
+# ==============================================================================
+
+
+def load_parse_test_results_module(script_dir: Path):
+    """
+    Import ``parse_test_results.py`` from ``script_dir`` by file path.
+
+    Args:
+        script_dir: Directory expected to contain ``parse_test_results.py``.
+
+    Returns:
+        The executed module object (created under the name ``parse_test_results``).
+
+    Raises:
+        FileNotFoundError: If the file is not present in ``script_dir``.
+    """
+    source_file = script_dir / "parse_test_results.py"
+    if source_file.exists():
+        module_spec = importlib.util.spec_from_file_location("parse_test_results", str(source_file))
+        loaded_module = importlib.util.module_from_spec(module_spec)
+        # Run the module body so its functions become attributes of the module object.
+        module_spec.loader.exec_module(loaded_module)
+        return loaded_module
+
+    raise FileNotFoundError(f"parse_test_results.py not found at {source_file}")
+
+
+# ==============================================================================
+# Data Classes
+# ==============================================================================
+
+
+@dataclasses.dataclass()
+class CaseExecutionTask:
+    """A single test case scheduled for execution in its own pytest subprocess."""
+
+    # Position of the case in the execution plan; appears in log/XML file
+    # names and in the recorded result.
+    case_idx: int
+    # pytest node id of the case (may carry a leading "test/").
+    nodeid: str
+    # Test file the case belongs to (may carry a leading "test/").
+    test_file: str
+    # Index of the owning file; the visible caller passes 0 for pre-collected cases.
+    file_idx: int
+
+
+@dataclasses.dataclass()
+class ConcurrentExecutionConfig:
+    """Settings shared by all concurrently executed cases."""
+
+    # Upper bound on simultaneously running pytest subprocesses.
+    max_workers: int = dataclasses.field(default=4)
+    # Per-case timeout in seconds; forwarded to pytest only when greater than 0.
+    per_case_timeout: int = dataclasses.field(default=1200)
+    # When true pytest runs with -vv, otherwise with -v.
+    verbose: bool = dataclasses.field(default=False)
+
+
+# ==============================================================================
+# Case Log Saving Functions
+# ==============================================================================
+
+
+def sanitize_nodeid_for_filename(nodeid: str) -> str:
+    """
+    Convert a pytest nodeid into a string that is safe to use in a file name.
+
+    Path separators, brackets, spaces, dots and the characters that are
+    invalid on NTFS / rejected by GitHub Actions artifact upload
+    (" : < > | * ? \\r \\n) are replaced. Leading underscores are removed,
+    runs of underscores are collapsed, and the result is capped at 200
+    characters.
+
+    Args:
+        nodeid: pytest node id, e.g. "test/test_nn.py::TestNN::test_x[cpu]".
+
+    Returns:
+        The sanitized name, or "unknown_case" when nothing is left.
+    """
+    # "::" has to be handled before the per-character table; otherwise each
+    # colon would be expanded to "_colon_" on its own.
+    safe_name = nodeid.replace("::", "_")
+
+    # None of the replacement strings contains a character that is itself a
+    # key, so a single translate() pass is equivalent to chained replace().
+    char_replacements = {
+        "/": "_",
+        "\\": "_",
+        "(": "_",
+        ")": "_",
+        "[": "_",
+        "]": "_",
+        "<": "_lt_",
+        ">": "_gt_",
+        '"': "_quot_",
+        "|": "_pipe_",
+        "*": "_star_",
+        "?": "_q_",
+        ":": "_colon_",
+        " ": "_",
+        ".": "_",
+        # Line breaks were documented as invalid but previously passed through.
+        "\r": "_",
+        "\n": "_",
+    }
+    safe_name = safe_name.translate(str.maketrans(char_replacements))
+
+    # Remove leading underscores and collapse multiple underscores
+    safe_name = safe_name.lstrip("_")
+    while "__" in safe_name:
+        safe_name = safe_name.replace("__", "_")
+
+    # Truncate if too long (max 200 chars)
+    safe_name = safe_name[:200]
+
+    return safe_name or "unknown_case"
+
+
+def save_case_log(
+    report_dir: Path,
+    shard: int,
+    shard_type: str,
+    nodeid: str,
+    case_idx: int,
+    status: str,
+    stdout: str,
+    stderr: str,
+    duration: float,
+    returncode: int,
+    command: str,
+) -> Path:
+    """
+    Write the execution log of one test case to ``<report_dir>/cases_logs``.
+
+    The file holds a metadata header (shard, case index, nodeid, status,
+    duration, return code, command) followed by the captured stdout and
+    stderr; an empty stream is written as "(empty)".
+
+    Args:
+        report_dir: Root report directory; ``cases_logs`` is created beneath it.
+        shard: Shard number, used in the file name and header.
+        shard_type: "distributed" selects the "dist" prefix, anything else "reg".
+        nodeid: pytest node id of the case.
+        case_idx: Index of the case, used in the file name and header.
+        status: Final status string recorded in the header.
+        stdout: Captured standard output.
+        stderr: Captured standard error.
+        duration: Execution time in seconds.
+        returncode: Exit code recorded in the header.
+        command: Command line recorded in the header.
+
+    Returns:
+        Path to the saved log file
+    """
+    heavy_rule = "=" * 80
+    light_rule = "-" * 80
+    shard_prefix = "reg" if shard_type != "distributed" else "dist"
+
+    # Resolve the destination: cases_logs/<prefix>-<shard>_<idx>_<safe nodeid>.log
+    logs_root = report_dir / "cases_logs"
+    logs_root.mkdir(parents=True, exist_ok=True)
+    target = logs_root / f"{shard_prefix}-{shard}_{case_idx}_{sanitize_nodeid_for_filename(nodeid)}.log"
+
+    def stream_section(title: str, text: str) -> List[str]:
+        # Title, underline, captured text (or placeholder), trailing blank line.
+        return [title, light_rule, text or "(empty)", ""]
+
+    metadata = [
+        f"Shard: {shard_prefix}-{shard}",
+        f"Case Index: {case_idx}",
+        f"Nodeid: {nodeid}",
+        f"Status: {status}",
+        f"Duration: {duration:.2f}s",
+        f"Return Code: {returncode}",
+        f"Command: {command}",
+    ]
+
+    document = [heavy_rule, "CASE LOG", heavy_rule]
+    document.extend(metadata)
+    document.extend([heavy_rule, ""])
+    document.extend(stream_section("STDOUT:", stdout))
+    document.extend(stream_section("STDERR:", stderr))
+    document.append(heavy_rule)
+
+    target.write_text("\n".join(document), encoding="utf-8")
+    return target
+
+
+class ConcurrentResultAggregator:
+    """
+    Lock-protected collector of per-case results.
+
+    Keeps every result dict, a counter per status, and the first non-zero
+    return code that was reported.
+    """
+
+    def __init__(self):
+        self._lock = threading.Lock()
+        self._cases_list: List[Dict] = []
+        self._worst_returncode: int = 0
+        self._passed_count: int = 0
+        self._failed_count: int = 0
+        self._error_count: int = 0
+        self._skipped_count: int = 0
+        self._timeout_count: int = 0
+        self._total_cases: int = 0
+
+    def add_case_result(self, case_result: Dict) -> None:
+        """
+        Record one case result and update the counters.
+
+        A missing or unrecognised ``status`` counts as an error; a missing
+        ``returncode`` is treated as 1.
+        """
+        counter_by_status = {
+            "passed": "_passed_count",
+            "failed": "_failed_count",
+            "skipped": "_skipped_count",
+            "timeout": "_timeout_count",
+        }
+        with self._lock:
+            self._cases_list.append(case_result)
+            self._total_cases += 1
+
+            # Anything outside the table (including "error") lands in the error bucket.
+            outcome = case_result.get("status", "error")
+            counter_name = counter_by_status.get(outcome, "_error_count")
+            setattr(self, counter_name, getattr(self, counter_name) + 1)
+
+            # Only the first non-zero return code is kept; later ones do not replace it.
+            exit_code = case_result.get("returncode", 1)
+            if exit_code != 0 and self._worst_returncode == 0:
+                self._worst_returncode = exit_code
+
+    def get_sorted_cases(self) -> List[Dict]:
+        """Return a new list of the recorded results ordered by ``case_idx`` (missing -> 0)."""
+        with self._lock:
+            ordered = list(self._cases_list)
+            ordered.sort(key=lambda entry: entry.get("case_idx", 0))
+            return ordered
+
+    def get_summary(self) -> Dict:
+        """Return the counters and the tracked return code as a dict."""
+        with self._lock:
+            summary = {"total_cases": self._total_cases}
+            summary["passed_count"] = self._passed_count
+            summary["failed_count"] = self._failed_count
+            summary["error_count"] = self._error_count
+            summary["skipped_count"] = self._skipped_count
+            summary["timeout_count"] = self._timeout_count
+            summary["worst_returncode"] = self._worst_returncode
+            return summary
+
+
+class ProgressTracker:
+    """Counts completed tasks under a lock and prints one progress line per completion."""
+
+    def __init__(self, total_tasks: int):
+        self._total_tasks = total_tasks
+        self._completed_tasks = 0
+        self._lock = threading.Lock()
+        self._start_time = monotonic()
+
+    def mark_completed(self, nodeid: str, status: str, duration: float) -> None:
+        """
+        Register one finished task and print a progress line.
+
+        Args:
+            nodeid: Node id of the finished case (shortened to 60 chars for display).
+            status: Result status; unknown values are shown as "[?]".
+            duration: Execution time of the case in seconds.
+        """
+        tag_by_status = {
+            "passed": "[PASS]",
+            "failed": "[FAIL]",
+            "error": "[ERR]",
+            "timeout": "[TIME]",
+            "skipped": "[SKIP]",
+        }
+        with self._lock:
+            self._completed_tasks += 1
+            done = self._completed_tasks
+            total = self._total_tasks
+            seconds_running = monotonic() - self._start_time
+            percent_done = (done / total) * 100
+
+            tag = tag_by_status.get(status, "[?]")
+
+            # Keep long node ids from wrapping the console line.
+            if len(nodeid) > 60:
+                shown_nodeid = nodeid[:60] + "..."
+            else:
+                shown_nodeid = nodeid
+
+            line = (
+                f"[{done}/{total}] {percent_done:.1f}% "
+                f"{tag} {shown_nodeid} ({duration:.1f}s) "
+                f"[elapsed: {seconds_running:.0f}s]"
+            )
+            print(line, flush=True)
+
+
+# ==============================================================================
+# JUnit XML Parsing for Accurate Status Detection
+# ==============================================================================
+
+
+def _classify_junit_testcase(testcase) -> Dict:
+    """Map one <testcase> element to {"status", "message"} from its first matching child."""
+    # Checked in this order: <skipped>, then <failure>, then <error>.
+    for child_tag, status in (("skipped", "skipped"), ("failure", "failed"), ("error", "error")):
+        child = testcase.find(child_tag)
+        if child is not None:
+            return {"status": status, "message": child.get("message", "")}
+    # No failure/error/skipped = passed
+    return {"status": "passed", "message": ""}
+
+
+def parse_junit_xml_status(xml_file: Path) -> Dict:
+    """
+    Parse a JUnit XML report and derive a single test status from it.
+
+    When the report contains several <testcase> elements, the first one that
+    failed or errored determines the result, so a failure is never hidden
+    behind an earlier passing testcase. If none failed, the status of the
+    first testcase is reported.
+
+    Args:
+        xml_file: Path to the JUnit XML file.
+
+    Returns:
+        Dict: {"status": "passed" | "skipped" | "failed" | "error" | "no_xml", "message": str}
+        "no_xml" is returned when the file is missing or cannot be parsed;
+        "error" with "No testcase in XML" when the report has no <testcase>.
+    """
+    if not xml_file.exists():
+        return {"status": "no_xml", "message": "XML file not generated"}
+
+    try:
+        root = ET.parse(str(xml_file)).getroot()
+
+        first_result = None
+        for testcase in root.iter("testcase"):
+            result = _classify_junit_testcase(testcase)
+            if result["status"] in ("failed", "error"):
+                return result
+            if first_result is None:
+                first_result = result
+
+        if first_result is not None:
+            return first_result
+
+        return {"status": "error", "message": "No testcase in XML"}
+
+    except Exception:
+        # A truncated or malformed report is treated like a missing one.
+        return {"status": "no_xml", "message": "XML parse failed"}
+
+
+# ==============================================================================
+# Utility Functions
+# ==============================================================================
+
+
+def strip_test_prefix_and_suffix(test_path: str) -> str:
+    """
+    Return ``test_path`` without a leading "test/" and without a trailing ".py".
+
+    Either part is removed only if present; the prefix is stripped first.
+    """
+    prefix, suffix = "test/", ".py"
+    start = len(prefix) if test_path.startswith(prefix) else 0
+    trimmed = test_path[start:]
+    return trimmed[: -len(suffix)] if trimmed.endswith(suffix) else trimmed
+
+
+def load_installed_torch_root() -> str:
+    """
+    Locate the directory that contains the installed ``torch`` package.
+
+    Returns:
+        The parent of the directory holding ``torch/__init__.py``, or "" when
+        torch cannot be imported or located (a warning is printed).
+    """
+    torch_root = ""
+    try:
+        import torch
+
+        package_dir = Path(torch.__file__).resolve().parent
+        torch_root = str(package_dir.parent)
+    except Exception as exc:
+        print(f"Warning: Failed to import torch: {exc}")
+    return torch_root
+
+
+# ==============================================================================
+# Concurrent Case Execution
+# ==============================================================================
+
+
+def _timeout_output_to_text(captured) -> str:
+    """Return output attached to a ``subprocess.TimeoutExpired`` as text ("" if none)."""
+    if captured is None:
+        return ""
+    if isinstance(captured, bytes):
+        # Partial output on timeout can arrive as bytes even in text mode.
+        return captured.decode("utf-8", errors="replace")
+    return str(captured)
+
+
+def run_single_case_concurrent(
+    task: CaseExecutionTask,
+    test_dir: Path,
+    merged_env: Dict[str, str],
+    config: ConcurrentExecutionConfig,
+    result_aggregator: ConcurrentResultAggregator,
+    progress_tracker: ProgressTracker,
+    log_queue: Queue,
+    report_dir: Path,
+    shard: int,
+    shard_type: str,
+) -> Dict:
+    """
+    Execute a single test case in its own pytest subprocess.
+
+    Intended to run inside ThreadPoolExecutor worker threads: every call
+    spawns an independent ``python -m pytest <nodeid>`` process, so a crash
+    of the child does not take down this process or other running cases.
+    The final status comes from the JUnit XML the child writes; a missing or
+    unparsable XML is reported as "error", a subprocess timeout as "timeout".
+
+    Failures while running the subprocess are converted into an "error"
+    result instead of being raised.
+
+    Args:
+        task: Case execution task with nodeid and metadata
+        test_dir: PyTorch test directory (used as the subprocess cwd)
+        merged_env: Base environment variables for the subprocess
+        config: Execution configuration (timeout, verbosity)
+        result_aggregator: Thread-safe result collector
+        progress_tracker: Thread-safe progress tracker
+        log_queue: Queue for log messages consumed by the log writer
+        report_dir: Report root; XML goes to ``junit_xmls``, logs to ``cases_logs``
+        shard: Shard number used in report file names
+        shard_type: "distributed" selects the "dist" file prefix, otherwise "reg"
+
+    Returns:
+        Dict with the case result (nodeid, status, duration, returncode,
+        message, command, file, case_idx)
+    """
+    start_time = monotonic()
+    original_nodeid = task.nodeid
+
+    # pytest runs with cwd=test_dir, so the node id must be relative to it.
+    case_nodeid = original_nodeid[5:] if original_nodeid.startswith("test/") else original_nodeid
+
+    # Generate XML file path with descriptive name
+    prefix = "dist" if shard_type == "distributed" else "reg"
+    safe_case_name = sanitize_nodeid_for_filename(original_nodeid)
+    xml_file = report_dir / "junit_xmls" / f"{prefix}-{shard}_{task.case_idx}_{safe_case_name}.xml"
+
+    command = [
+        sys.executable,
+        "-m",
+        "pytest",
+        "--color=no",
+        "-ra",
+        "--tb=short",
+        case_nodeid,
+        f"--junitxml={xml_file}",
+        "--junit-prefix=",
+    ]
+
+    timeout_enabled = config.per_case_timeout > 0
+    if timeout_enabled:
+        command.append(f"--timeout={config.per_case_timeout}")
+    command.append("-vv" if config.verbose else "-v")
+
+    command_str = " ".join(command)
+
+    # A disabled per-case timeout (<= 0) must also disable the subprocess
+    # timeout; otherwise every case would be killed after the 30s buffer.
+    subprocess_timeout = config.per_case_timeout + 30 if timeout_enabled else None
+
+    # Build per-case environment with test file directory in PYTHONPATH
+    # This enables imports of sibling modules (e.g., 'from model_registry import MLPModule')
+    case_env = merged_env.copy()
+    test_file_rel = task.test_file[5:] if task.test_file.startswith("test/") else task.test_file
+    test_file_dir = test_dir / Path(test_file_rel).parent
+
+    existing_pythonpath = case_env.get("PYTHONPATH", "")
+    if existing_pythonpath:
+        # os.pathsep matches how build_execution_env joins PYTHONPATH entries.
+        case_env["PYTHONPATH"] = str(test_file_dir) + os.pathsep + existing_pythonpath
+    else:
+        case_env["PYTHONPATH"] = str(test_file_dir)
+
+    def build_result(status: str, duration: float, returncode: int, message) -> Dict:
+        # Single place that fixes the key order of a recorded case result.
+        return {
+            "nodeid": original_nodeid,
+            "status": status,
+            "duration": duration,
+            "returncode": returncode,
+            "message": message,
+            "command": command_str,
+            "file": task.test_file,
+            "case_idx": task.case_idx,
+        }
+
+    def write_case_log(status: str, stdout: str, stderr: str, duration: float, returncode: int) -> None:
+        save_case_log(
+            report_dir=report_dir,
+            shard=shard,
+            shard_type=shard_type,
+            nodeid=original_nodeid,
+            case_idx=task.case_idx,
+            status=status,
+            stdout=stdout,
+            stderr=stderr,
+            duration=duration,
+            returncode=returncode,
+            command=command_str,
+        )
+
+    # Print start log to stdout (before execution), truncating the nodeid for display
+    display_nodeid = original_nodeid[:70] + "..." if len(original_nodeid) > 70 else original_nodeid
+    print(f"[{task.case_idx}] Starting: {display_nodeid}", flush=True)
+
+    # Log start
+    log_queue.put({
+        "type": "case_start",
+        "case_idx": task.case_idx,
+        "nodeid": original_nodeid,
+        "file": task.test_file,
+        "command": command_str,
+    })
+
+    # Execute subprocess - CRITICAL: catch ALL exceptions
+    try:
+        result = subprocess.run(
+            command,
+            cwd=str(test_dir),
+            env=case_env,  # Use per-case environment with test file directory in PYTHONPATH
+            capture_output=True,
+            text=True,
+            encoding="utf-8",
+            errors="replace",
+            timeout=subprocess_timeout,
+        )
+
+        duration = monotonic() - start_time
+        returncode = result.returncode
+
+        # Parse JUnit XML for status
+        # - Has XML: use XML status
+        # - No XML: error
+        xml_result = parse_junit_xml_status(xml_file)
+        xml_status = xml_result.get("status")
+
+        if xml_status == "no_xml":
+            status = "error"
+            message = xml_result.get("message")
+        else:
+            status = xml_status
+            message = xml_result.get("message", "")
+
+        # Save logs for all cases
+        write_case_log(status, result.stdout, result.stderr, duration, returncode)
+        case_result = build_result(status, duration, returncode, message)
+
+    except subprocess.TimeoutExpired as timeout_exc:
+        # Timeout -> no XML, status = timeout
+        duration = monotonic() - start_time
+        case_result = build_result(
+            "timeout", duration, -1, f"Timeout after {config.per_case_timeout}s"
+        )
+
+        # Keep whatever the child printed before it was killed; it is usually
+        # the only hint about where the case hung.
+        partial_stdout = _timeout_output_to_text(timeout_exc.stdout)
+        partial_stderr = _timeout_output_to_text(timeout_exc.stderr)
+        write_case_log(
+            "timeout",
+            partial_stdout or "(process timed out, no output captured)",
+            partial_stderr or "(process timed out, no output captured)",
+            duration,
+            -1,
+        )
+
+    except Exception as e:
+        # Any other exception - return result, don't raise
+        duration = monotonic() - start_time
+        case_result = build_result("error", duration, 1, f"Unexpected error: {str(e)[:200]}")
+
+        # Save error case log
+        write_case_log("error", "(exception occurred before execution)", str(e), duration, 1)
+
+    # Log finish
+    log_queue.put({
+        "type": "case_finish",
+        "case_idx": task.case_idx,
+        "nodeid": original_nodeid,
+        "status": case_result["status"],
+        "duration": case_result["duration"],
+        "message": case_result["message"][:200] if case_result["message"] else "",
+    })
+
+    # Update aggregator (thread-safe)
+    result_aggregator.add_case_result(case_result)
+
+    # Update progress (thread-safe)
+    progress_tracker.mark_completed(original_nodeid, case_result["status"], duration)
+
+    return case_result
+
+
+def log_writer_thread(log_queue: Queue, log_file: Path, stop_event: threading.Event) -> None:
+    """
+    Drain ``log_queue`` into ``log_file`` until asked to stop.
+
+    All writes to the log file happen in this one function, so concurrent
+    workers only have to put entries on the queue. The loop keeps running
+    until ``stop_event`` is set AND the queue is empty, so entries queued
+    before the stop request are still written.
+
+    Recognised entry types (``entry["type"]``): "header" and "summary" write
+    ``entry["content"]`` verbatim; "case_start" and "case_finish" write a
+    formatted block. Other entries are ignored.
+
+    Args:
+        log_queue: Queue of dict log entries.
+        log_file: Destination file; opened for writing (truncated).
+        stop_event: Set by the producer side once no more entries will follow.
+    """
+    # Local import: only ``Queue`` is imported from ``queue`` at module level.
+    from queue import Empty
+
+    with log_file.open("w", encoding="utf-8") as log_handle:
+        while not stop_event.is_set() or not log_queue.empty():
+            try:
+                log_entry = log_queue.get(timeout=0.5)
+            except Empty:
+                # Nothing arrived within the poll interval; re-check the stop condition.
+                continue
+
+            entry_type = log_entry.get("type")
+            if entry_type in ("header", "summary"):
+                log_handle.write(log_entry.get("content", ""))
+            elif entry_type == "case_start":
+                log_handle.write(f"\n[{log_entry['case_idx']}] {log_entry['nodeid']}\n")
+                log_handle.write(f" File: {log_entry.get('file', '')}\n")
+                log_handle.write(f" Command: {log_entry.get('command', '')}\n")
+            elif entry_type == "case_finish":
+                status_str = log_entry.get("status", "")
+                duration_str = f"{log_entry.get('duration', 0):.2f}s"
+                log_handle.write(f" Status: {status_str}, Duration: {duration_str}\n")
+                if log_entry.get("message"):
+                    log_handle.write(f" Message: {log_entry['message']}\n")
+            else:
+                continue
+
+            # Flush after every entry so the file is useful while tests are still running.
+            log_handle.flush()
+
+
+def run_tests_with_tasks_concurrent(
+    tasks: List[CaseExecutionTask],
+    shard: int,
+    test_dir: Path,
+    report_dir: Path,
+    env_updates: Dict[str, str],
+    timeout: int,
+    verbose: bool,
+    shard_type: str,
+    max_workers: int,
+    result_module,
+    quick_test: int = None,
+) -> Tuple[int, float, List[Dict]]:
+    """
+    Execute pre-collected test cases with concurrent per-case isolation.
+
+    Every task is submitted to a thread pool; each worker thread runs the
+    case in its own pytest subprocess via ``run_single_case_concurrent``.
+    A background thread writes the shard log from a queue.
+
+    Args:
+        tasks: List of CaseExecutionTask objects (pre-collected cases)
+        shard: Shard number
+        test_dir: PyTorch test directory
+        report_dir: Report output directory
+        env_updates: Environment variable updates applied on top of os.environ
+        timeout: Per-case timeout in seconds
+        verbose: Verbose output
+        shard_type: "distributed" or "regular"
+        max_workers: Maximum concurrent subprocesses
+        result_module: parse_test_results module (provides get_shard_log_file)
+        quick_test: Maximum number of cases to execute; None, 0 or a negative
+            value means all cases
+
+    Returns:
+        Tuple of (worst_returncode, duration, cases_list_sorted)
+    """
+    start = monotonic()
+    log_file = result_module.get_shard_log_file(report_dir, shard, shard_type)
+    rule = "=" * 80
+
+    # Create junit_xmls directory for XML reports
+    junit_xml_dir = report_dir / "junit_xmls"
+    junit_xml_dir.mkdir(parents=True, exist_ok=True)
+
+    merged_env = os.environ.copy()
+    merged_env.update(env_updates)
+
+    config = ConcurrentExecutionConfig(
+        max_workers=max_workers,
+        per_case_timeout=timeout,
+        verbose=verbose,
+    )
+
+    # Quick test: limit the number of cases BEFORE the log header is built so
+    # the logged total matches what is executed. Non-positive limits are
+    # ignored; a negative slice bound would silently drop cases from the end.
+    quick_test_applied = quick_test is not None and 0 < quick_test < len(tasks)
+    if quick_test_applied:
+        tasks = tasks[:quick_test]
+    total_cases = len(tasks)
+
+    # Thread-safe result aggregator
+    result_aggregator = ConcurrentResultAggregator()
+
+    # Log queue and writer thread
+    log_queue = Queue()
+    stop_event = threading.Event()
+    log_thread = threading.Thread(
+        target=log_writer_thread,
+        args=(log_queue, log_file, stop_event),
+        daemon=True,
+    )
+
+    # Write log header. The rule is interpolated through f-strings: with bare
+    # adjacent literals around `"=" * 80`, implicit concatenation binds
+    # tighter than `*` and the neighbouring lines get repeated 80 times.
+    log_queue.put({
+        "type": "header",
+        "content": (
+            f"{rule}\n"
+            f"Pre-collected cases concurrent execution ({shard_type} shard)\n"
+            f"{rule}\n"
+            f"Total cases: {total_cases}\n"
+            f"Max concurrent workers: {max_workers}\n"
+            "Execution mode: concurrent subprocess, each case isolated\n"
+            f"{rule}\n\n"
+        ),
+    })
+
+    log_thread.start()
+
+    try:
+        if quick_test_applied:
+            print(f"\nQuick test mode: executing only {quick_test} cases", flush=True)
+
+        print(f"\n{rule}", flush=True)
+        print(f"Pre-collected cases: {total_cases} cases", flush=True)
+        print(f"Execution mode: {max_workers} workers concurrent, each case in subprocess", flush=True)
+        print(f"{rule}\n", flush=True)
+
+        print(f"Phase 1: Executing {total_cases} pre-collected cases...", flush=True)
+
+        # Concurrent execution via ThreadPoolExecutor
+        progress_tracker = ProgressTracker(total_cases)
+
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            # Submit all tasks
+            future_to_task = {
+                executor.submit(
+                    run_single_case_concurrent,
+                    task,
+                    test_dir,
+                    merged_env,
+                    config,
+                    result_aggregator,
+                    progress_tracker,
+                    log_queue,
+                    report_dir,
+                    shard,
+                    shard_type,
+                ): task
+                for task in tasks
+            }
+
+            # Wait for completion (as_completed gives results as they finish)
+            for future in as_completed(future_to_task):
+                task = future_to_task[future]
+                try:
+                    # Result already collected in aggregator
+                    _ = future.result()
+                except Exception as e:
+                    # Safety net: the worker is expected to convert failures
+                    # into results itself; record an error if it did not.
+                    case_result = {
+                        "nodeid": task.nodeid,
+                        "status": "error",
+                        "duration": 0.0,
+                        "returncode": 1,
+                        "message": f"Future error: {str(e)[:200]}",
+                        "file": task.test_file,
+                        "case_idx": task.case_idx,
+                    }
+                    result_aggregator.add_case_result(case_result)
+                    progress_tracker.mark_completed(task.nodeid, "error", 0.0)
+
+        elapsed = monotonic() - start
+        summary = result_aggregator.get_summary()
+
+        log_queue.put({
+            "type": "summary",
+            "content": (
+                f"\n{rule}\n"
+                f"Summary: {summary['total_cases']} cases executed\n"
+                f" Passed: {summary['passed_count']}\n"
+                f" Failed: {summary['failed_count']}\n"
+                f" Errors: {summary['error_count']}\n"
+                f" Timeout: {summary['timeout_count']}\n"
+                f" Skipped: {summary['skipped_count']}\n"
+                f" Duration: {elapsed:.2f}s\n"
+                f" Concurrent workers: {max_workers}\n"
+                f"{rule}\n"
+            ),
+        })
+    finally:
+        # Stop the log thread even when execution is interrupted, so queued
+        # entries get a chance to reach the log file.
+        stop_event.set()
+        log_thread.join(timeout=5)
+
+    # Print final summary
+    print(f"\n{rule}", flush=True)
+    print(f"Summary: {summary['total_cases']} cases executed", flush=True)
+    print(f" Passed: {summary['passed_count']}", flush=True)
+    print(f" Failed: {summary['failed_count']}", flush=True)
+    print(f" Errors: {summary['error_count']}", flush=True)
+    print(f" Timeout: {summary['timeout_count']}", flush=True)
+    print(f" Skipped: {summary['skipped_count']}", flush=True)
+    print(f" Duration: {elapsed:.2f}s", flush=True)
+    print(f"{rule}", flush=True)
+
+    return summary["worst_returncode"], elapsed, result_aggregator.get_sorted_cases()
+
+
+def build_execution_env(
+    test_dir: Path,
+    script_dir: Path,
+    disabled_testcases_file: str,
+    shard: int,
+    shard_type: str,
+) -> Dict[str, str]:
+    """
+    Assemble the environment-variable overrides for test subprocesses.
+
+    PYTHONPATH is built in this order: script directory, installed torch
+    root (if found), repository root (parent of ``test_dir``), ``test_dir``,
+    then any PYTHONPATH already present in the current environment.
+
+    Args:
+        test_dir: PyTorch test directory.
+        script_dir: Directory of the runner scripts.
+        disabled_testcases_file: Path to the disabled-testcases JSON, or a
+            falsy value to leave DISABLED_TESTS_FILE unset.
+        shard: Accepted for signature compatibility; not used here.
+        shard_type: Accepted for signature compatibility; not used here.
+
+    Returns:
+        Dict of variables to overlay on the process environment.
+    """
+    search_path = [str(script_dir)]
+
+    torch_root = load_installed_torch_root()
+    if torch_root:
+        search_path.append(torch_root)
+
+    search_path.append(str(test_dir.parent))
+    search_path.append(str(test_dir))
+
+    inherited_pythonpath = os.environ.get("PYTHONPATH", "")
+    if inherited_pythonpath:
+        search_path.append(inherited_pythonpath)
+
+    # Note: Do NOT set CI=true here, as some test files have conditional
+    # test generation logic like:
+    #   if not (IS_CI and torch.cuda.is_available()):
+    #       globals().update(generate_tests(...))
+    # Setting CI=true would prevent test case generation in those files.
+    overrides = {"PYTHONPATH": os.pathsep.join(search_path)}
+    overrides["PYTORCH_TEST_NPU"] = "1"
+    overrides["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "1"
+    overrides["NO_TD"] = "1"
+    overrides["PYTHONUNBUFFERED"] = "1"
+
+    # Use PyTorch's built-in DISABLED_TESTS_FILE mechanism for skipping test
+    # cases; the disabled_testcases.json format is similar to
+    # .pytorch-disabled-tests.json.
+    if disabled_testcases_file:
+        overrides["DISABLED_TESTS_FILE"] = os.path.abspath(disabled_testcases_file)
+
+    return overrides
+
+
+def clean_existing_junit_xml(report_dir: Path) -> None:
+    """
+    Delete every ``*.xml`` file found recursively under ``report_dir``.
+
+    Does nothing when ``report_dir`` does not exist.
+    """
+    if report_dir.exists():
+        for stale_xml in report_dir.rglob("*.xml"):
+            # missing_ok: a file that vanished in the meantime is not an error.
+            stale_xml.unlink(missing_ok=True)
+
+
+def remove_existing_file(path: Path) -> None:
+    """Delete ``path`` if it exists; a missing file is silently ignored."""
+    try:
+        path.unlink()
+    except FileNotFoundError:
+        pass
+
+
+# ==============================================================================
+# Test Files Input Parser
+# ==============================================================================
+
+
+def parse_test_files_input(test_files_str: str, test_dir: Path) -> List[str]:
+    """
+    Turn a comma-separated list of test files into "test/"-prefixed paths.
+
+    Each entry is stripped, prefixed with "test/" when missing, has a doubled
+    "test/test/" prefix reduced to one, and is checked for existence relative
+    to the parent of ``test_dir``. An entry without ".py" that does not exist
+    as given is retried with ".py" appended. Empty entries are skipped.
+
+    Args:
+        test_files_str: Comma-separated test file paths (e.g., "test_meta.py,test_nn.py")
+        test_dir: Path to PyTorch test directory
+
+    Returns:
+        List of standardized test file paths (e.g., ["test/test_meta.py", "test/test_nn.py"])
+
+    Raises:
+        FileNotFoundError: If any specified test file does not exist
+    """
+    repo_root = test_dir.parent
+    normalized: List[str] = []
+
+    for raw_entry in test_files_str.split(","):
+        entry = raw_entry.strip()
+        if not entry:
+            continue
+
+        # Normalize to exactly one leading "test/".
+        candidate = entry if entry.startswith("test/") else "test/" + entry
+        if candidate.startswith("test/test/"):
+            candidate = candidate[5:]
+
+        if not (repo_root / candidate).exists():
+            if candidate.endswith(".py"):
+                raise FileNotFoundError(f"Test file not found: {candidate}")
+
+            # Extension omitted by the caller: retry with ".py".
+            candidate_with_ext = candidate + ".py"
+            if not (repo_root / candidate_with_ext).exists():
+                raise FileNotFoundError(f"Test file not found: {candidate} or {candidate_with_ext}")
+            candidate = candidate_with_ext
+
+        normalized.append(candidate)
+
+    return normalized
+
+
+# ==============================================================================
+# CLI
+# ==============================================================================
+
+
+def parse_args():
+    """
+    Parse and validate the runner's command line arguments.
+
+    Validation performed after parsing (each failure exits via
+    ``parser.error``):
+      - at least one of --test-files / --cases-json must be given
+      - --max-workers must be >= 1 (values above 128 only print a warning)
+      - --timeout must not be negative (0 means no per-case timeout is
+        passed to pytest)
+      - --quick-test must not be negative
+
+    Returns:
+        argparse.Namespace with the parsed options.
+    """
+    parser = argparse.ArgumentParser(
+        description="Run PyTorch NPU tests via per-case isolation pytest execution"
+    )
+    parser.add_argument("--test-files", type=str, help="Comma-separated test file paths to run directly (e.g., 'test_meta.py,test_nn.py')")
+    parser.add_argument("--cases-json", type=str, help="Path to pre-collected cases JSON file")
+    parser.add_argument("--test-dir", type=str, required=True, help="Path to PyTorch test directory")
+    parser.add_argument("--disabled-testcases", type=str, help="Path to disabled_testcases.json")
+    parser.add_argument("--report-dir", type=str, default="test-reports", help="Directory for reports")
+    parser.add_argument("--timeout", type=int, default=1200, help="Per-case timeout in seconds (default: 1200 = 20 minutes)")
+    parser.add_argument(
+        "--max-workers",
+        type=int,
+        default=4,
+        help="Maximum concurrent workers for regular tests (default: 4). Each worker runs one pytest subprocess.",
+    )
+    parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
+    parser.add_argument("--quick-test", type=int, default=None, help="Quick test mode: execute only N cases for fast verification (default: None, run all cases)")
+    args = parser.parse_args()
+
+    # Validate required arguments: must specify either --test-files or --cases-json
+    if not args.test_files and not args.cases_json:
+        parser.error("Either --test-files or --cases-json must be specified")
+
+    # Validate max_workers
+    if args.max_workers < 1:
+        parser.error("--max-workers must be at least 1")
+    if args.max_workers > 128:
+        print(f"WARNING: --max-workers={args.max_workers} is very high, may cause resource contention")
+
+    # A negative timeout would turn into a nonsensical subprocess timeout.
+    if args.timeout < 0:
+        parser.error("--timeout must not be negative")
+
+    # A negative limit would be used as a slice bound and silently drop cases
+    # from the END of the case list instead of limiting the run.
+    if args.quick_test is not None and args.quick_test < 0:
+        parser.error("--quick-test must not be negative")
+
+    return args
+
+
+def main():
+ """Main entry point."""
+ args = parse_args()
+
+ # Resolve paths
+ test_dir = Path(args.test_dir).resolve()
+ if not test_dir.is_dir():
+ raise FileNotFoundError(f"Test directory not found: {test_dir}")
+
+ repo_root = test_dir.parent
+ script_dir = Path(__file__).resolve().parent
+ report_dir = Path(args.report_dir).resolve()
+ report_dir.mkdir(parents=True, exist_ok=True)
+
+ # Load modules
+ result_module = load_parse_test_results_module(script_dir)
+
+ timestamp = datetime.now().isoformat()
+
+ # ==========================================================================
+ # Mode: Direct execution of specified test files
+ # ==========================================================================
+ if args.test_files:
+ print("=" * 80)
+ print("Custom Test Files Execution Mode")
+ print("=" * 80)
+
+ # Parse test files input
+ planned_tests = parse_test_files_input(args.test_files, test_dir)
+
+ # Use fixed shard number for custom mode
+ shard = 1
+ num_shards = 1
+ shard_type = "custom"
+
+ print(f"Test files specified: {len(planned_tests)}")
+ print(f"Test directory: {test_dir}")
+ print(f"Execution mode: concurrent ({args.max_workers} workers, per-case subprocess isolation)")
+ if args.disabled_testcases:
+ disabled_count = result_module.load_disabled_testcases_count(args.disabled_testcases)
+ print(f"Disabled testcase entries: {disabled_count}")
+ print(f"\n{'=' * 80}\n")
+
+ for index, target in enumerate(planned_tests, 1):
+ display_name = strip_test_prefix_and_suffix(target)
+ print(f" [{index:03d}] {display_name}")
+
+ # Create info dict for custom mode
+ info = result_module.create_shard_info(shard, num_shards, timestamp)
+ info["selection_mode"] = "custom_files"
+ info["shard_type"] = shard_type
+ info["shard_files"] = len(planned_tests)
+ info["total_files"] = len(planned_tests)
+ info["selected_test_files"] = len(planned_tests)
+ if args.disabled_testcases:
+ info["disabled_count"] = result_module.load_disabled_testcases_count(args.disabled_testcases)
+
+ # Save test plan
+ result_module.save_test_plan_file(str(report_dir), shard, planned_tests, shard_type)
+
+ # Clean old files
+ clean_existing_junit_xml(report_dir)
+ remove_existing_file(result_module.get_shard_log_file(report_dir, shard, shard_type))
+
+ # Build execution env
+ env_updates = build_execution_env(
+ test_dir, script_dir, args.disabled_testcases, shard, shard_type
+ )
+
+ # Execute tests (custom mode uses concurrent execution by default)
+ cases_list = []
+ if planned_tests:
+ # Phase 1: Collect all test cases using collect_all_cases module
+ print("\nPhase 1: Collecting test cases...")
+ error_log_dir = report_dir / "collection_errors"
+ collected_cases = collect_all_cases.collect_all_cases(
+ planned_tests,
+ test_dir,
+ error_log_dir,
+ parallel=16,
+ )
+
+ # Apply quick_test limit if specified
+ if args.quick_test and len(collected_cases) > args.quick_test:
+ collected_cases = collected_cases[:args.quick_test]
+ print(f" Quick test mode: using only {args.quick_test} cases")
+
+ total_cases = len(collected_cases)
+ print(f"\nPhase 2: Executing {total_cases} cases with {args.max_workers} workers")
+
+ # Build CaseExecutionTask list
+ tasks = []
+ for i, case in enumerate(collected_cases, 1):
+ tasks.append(CaseExecutionTask(
+ case_idx=i,
+ nodeid=case["nodeid"],
+ test_file=case["file"],
+ file_idx=0, # Not needed for pre-collected cases
+ ))
+
+ # Phase 2: Execute cases using run_tests_with_tasks_concurrent
+ returncode, duration, cases_list = run_tests_with_tasks_concurrent(
+ tasks,
+ shard,
+ test_dir,
+ report_dir,
+ env_updates,
+ args.timeout,
+ args.verbose,
+ shard_type,
+ args.max_workers,
+ result_module,
+ args.quick_test,
+ )
+ info["per_case_isolation"] = True
+ info["concurrent_workers"] = args.max_workers
+ info["returncode"] = returncode
+ info["duration"] = duration
+ else:
+ returncode = 0
+ duration = 0.0
+
+ # Build cases.json data
+ passed_count = sum(1 for c in cases_list if c["status"] == "passed")
+ failed_count = sum(1 for c in cases_list if c["status"] == "failed")
+ error_count = sum(1 for c in cases_list if c["status"] == "error")
+ timeout_count = sum(1 for c in cases_list if c["status"] == "timeout")
+ skipped_count = sum(1 for c in cases_list if c["status"] == "skipped")
+
+ cases_data = {
+ "shard": shard,
+ "shard_type": shard_type,
+ "execution_mode": "concurrent",
+ "concurrent_workers": args.max_workers,
+ "total_cases": len(cases_list),
+ "passed": passed_count,
+ "failed": failed_count,
+ "errors": error_count,
+ "timeout": timeout_count,
+ "skipped": skipped_count,
+ "duration": duration,
+ "cases": cases_list,
+ }
+
+ # Save cases.json
+ result_module.save_cases_file(str(report_dir), shard, cases_data, shard_type)
+
+ # Save info and stats
+ result_module.save_info_file(str(report_dir), shard, info, shard_type)
+
+ stats = {
+ "total": len(cases_list),
+ "passed": passed_count,
+ "failed": failed_count,
+ "skipped": skipped_count,
+ "errors": error_count,
+ "timeout": timeout_count,
+ "duration": duration,
+ "returncode": returncode,
+ "per_case_isolation": True,
+ }
+
+ result_module.save_stats_file(str(report_dir), shard, stats, shard_type)
+
+ # Print summary
+ result_module.print_stats_summary(shard, stats, shard_type)
+
+ # Exit with 0 to allow step to succeed and report generation to proceed
+ # The actual test results are recorded in cases.json
+ sys.exit(0)
+
+ # ==========================================================================
+ # Mode: Pre-collected cases JSON execution
+ # ==========================================================================
+ if args.cases_json:
+ print("=" * 80)
+ print("Pre-collected Cases Execution Mode")
+ print("=" * 80)
+
+ cases_file = Path(args.cases_json).resolve()
+ if not cases_file.exists():
+ raise FileNotFoundError(f"Cases JSON file not found: {cases_file}")
+
+ cases_data = json.loads(cases_file.read_text(encoding="utf-8"))
+
+ shard = cases_data["shard"]
+ num_shards = cases_data["num_shards"]
+ shard_type = cases_data.get("test_type", "regular")
+ planned_cases = cases_data["cases"]
+ total_cases = len(planned_cases)
+
+ print(f"Cases JSON: {cases_file}")
+ print(f"Shard: {shard}/{num_shards}")
+ print(f"Test type: {shard_type}")
+ print(f"Total cases: {total_cases}")
+ print(f"Test directory: {test_dir}")
+
+ # Execution mode based on test_type
+ if shard_type == "distributed":
+ print(f"Execution mode: SERIAL (per-case subprocess isolation)")
+ else:
+ print(f"Execution mode: CONCURRENT ({args.max_workers} workers, per-case subprocess isolation)")
+
+ if args.disabled_testcases:
+ disabled_count = result_module.load_disabled_testcases_count(args.disabled_testcases)
+ print(f"Disabled testcase entries: {disabled_count}")
+
+ print(f"\n{'=' * 80}\n")
+
+ # Create info dict for cases-json mode
+ info = result_module.create_shard_info(shard, num_shards, timestamp)
+ info["selection_mode"] = "cases_json"
+ info["shard_type"] = shard_type
+ info["cases_json_file"] = str(cases_file)
+ info["total_cases"] = total_cases
+ info["per_case_isolation"] = True
+ if args.disabled_testcases:
+ info["disabled_count"] = result_module.load_disabled_testcases_count(args.disabled_testcases)
+
+ # Clean old files
+ clean_existing_junit_xml(report_dir)
+ remove_existing_file(result_module.get_shard_log_file(report_dir, shard, shard_type))
+
+ # Build execution env
+ env_updates = build_execution_env(
+ test_dir, script_dir, args.disabled_testcases, shard, shard_type
+ )
+
+ # Convert cases to CaseExecutionTask format
+ tasks = []
+ for i, case in enumerate(planned_cases, 1):
+ tasks.append(CaseExecutionTask(
+ case_idx=i,
+ nodeid=case["nodeid"],
+ test_file=case.get("file", ""),
+ file_idx=0,
+ ))
+
+ # Execute tests based on shard_type
+ cases_list = []
+ if tasks:
+ # Determine execution mode and worker count
+ if shard_type == "distributed":
+ # Distributed: serial execution (1 worker)
+ effective_workers = 1
+ print(f"\nExecution mode: SERIAL (distributed tests require sequential execution)")
+ else:
+ # Regular: concurrent execution
+ effective_workers = args.max_workers
+ print(f"\nExecution mode: CONCURRENT ({effective_workers} workers)")
+
+ # Execute tasks directly using the new function
+ returncode, duration, cases_list = run_tests_with_tasks_concurrent(
+ tasks,
+ shard,
+ test_dir,
+ report_dir,
+ env_updates,
+ args.timeout,
+ args.verbose,
+ shard_type,
+ effective_workers,
+ result_module,
+ args.quick_test,
+ )
+ info["execution_mode"] = "serial" if effective_workers == 1 else "concurrent"
+ info["concurrent_workers"] = effective_workers
+
+ info["returncode"] = returncode
+ info["duration"] = duration
+ else:
+ print("No cases to execute.")
+ returncode = 0
+ duration = 0.0
+
+ # Build cases.json data
+ passed_count = sum(1 for c in cases_list if c["status"] == "passed")
+ failed_count = sum(1 for c in cases_list if c["status"] == "failed")
+ error_count = sum(1 for c in cases_list if c["status"] == "error")
+ timeout_count = sum(1 for c in cases_list if c["status"] == "timeout")
+ skipped_count = sum(1 for c in cases_list if c["status"] == "skipped")
+
+ output_cases_data = {
+ "shard": shard,
+ "shard_type": shard_type,
+ "execution_mode": info.get("execution_mode", "unknown"),
+ "concurrent_workers": info.get("concurrent_workers", 1),
+ "total_cases": len(cases_list),
+ "passed": passed_count,
+ "failed": failed_count,
+ "errors": error_count,
+ "timeout": timeout_count,
+ "skipped": skipped_count,
+ "duration": duration,
+ "cases": cases_list,
+ }
+
+ # Save cases.json
+ result_module.save_cases_file(str(report_dir), shard, output_cases_data, shard_type)
+
+ # Save info and stats
+ result_module.save_info_file(str(report_dir), shard, info, shard_type)
+
+ stats = {
+ "total": len(cases_list),
+ "passed": passed_count,
+ "failed": failed_count,
+ "skipped": skipped_count,
+ "errors": error_count,
+ "timeout": timeout_count,
+ "duration": duration,
+ "returncode": returncode,
+ "per_case_isolation": True,
+ }
+
+ result_module.save_stats_file(str(report_dir), shard, stats, shard_type)
+
+ # Print summary
+ result_module.print_stats_summary(shard, stats, shard_type)
+
+ # Exit with 0 to allow step to succeed and report generation to proceed
+ # The actual test results are recorded in cases.json
+ sys.exit(0)
+
+ # No valid mode specified (should not reach here due to argument validation)
+ print("ERROR: Either --test-files or --cases-json must be specified")
+ sys.exit(1)
+
+
+if __name__ == "__main__":  # entry point: call main() only when this file is run as a script
+ main()
\ No newline at end of file
diff --git a/.github/workflows/_torch-npu-upstream-build.yml b/.github/workflows/_torch-npu-upstream-build.yml
new file mode 100644
index 0000000000..b63ddc8f3a
--- /dev/null
+++ b/.github/workflows/_torch-npu-upstream-build.yml
@@ -0,0 +1,416 @@
+name: Build PyTorch and torch_npu (with cache)
+# Reusable workflow: builds upstream PyTorch and torch_npu wheels with pip/ccache caching and uploads them as artifacts.
+on:
+  workflow_call:
+    inputs:
+      pytorch_ref:
+        required: true
+        type: string
+        description: 'PyTorch branch, tag, or commit SHA to build'
+      torch_npu_ref:
+        required: true
+        type: string
+        description: 'torch_npu branch, tag, or commit SHA to build'
+      python_version:
+        required: true
+        type: string
+        default: '3.11'
+      docker_image:
+        required: true
+        type: string
+        default: 'quay.io/kerer/pytorch:manylinux-cann9.0.0-beta.2-20260428'
+        description: 'Docker image URL to use for build'
+    outputs:
+      docker-image:
+        description: 'Full Docker image URL'
+        value: ${{ inputs.docker_image }}
+      torch-wheel:
+        description: 'PyTorch wheel artifact name'
+        value: 'torch-wheel-main'
+      torch-npu-wheel:
+        description: 'torch_npu wheel artifact name'
+        value: 'torch-npu-wheel-main'
+      pytorch-src:
+        description: 'PyTorch source and test code artifact name'
+        value: 'pytorch-src-main'
+      pytorch-version:
+        description: 'PyTorch version string'
+        value: ${{ jobs.build.outputs.pytorch-version }}
+
+env:
+  # Cache version; change this value to force every cache to be refreshed
+  CACHE_VERSION: 'v2'
+  # GitHub proxy URL (speeds up git clone; leave empty to disable the proxy)
+  GH_PROXY_URL: 'https://gh-proxy.test.osinfra.cn'
+  # PyPI cache URL (speeds up pip downloads)
+  PYPI_CACHE_URL: 'http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple'
+
+jobs:
+  build:
+    runs-on: linux-aarch64-a3-16
+    timeout-minutes: 240
+    outputs:
+      pytorch-version: ${{ steps.get_version.outputs.pytorch_version }}
+
+    container:
+      image: ${{ inputs.docker_image }}
+      options: --user root
+
+    steps:
+      - name: Display Docker image
+        run: |
+          echo "Using Docker image: ${{ inputs.docker_image }}"
+
+      - name: Setup CANN environment  # NOTE(review): each run step starts a fresh shell, so nothing sourced here reaches later steps (they re-source the scripts themselves)
+        run: |
+          source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true
+          source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true
+
+      - name: Configure git proxy for faster clone
+        run: |
+          # Configure git URL rewrites to go through the proxy (speeds up clone and submodules)
+          if [ -n "${{ env.GH_PROXY_URL }}" ]; then
+            git config --global url."${{ env.GH_PROXY_URL }}/https://github.com/".insteadOf "https://github.com/"
+            git config --global url."${{ env.GH_PROXY_URL }}/https://gitlab.com/".insteadOf "https://gitlab.com/"
+            echo "Git proxy configured:"
+            git config --global --list | grep url
+          else
+            echo "No proxy configured, using direct connection"
+          fi
+
+      - name: Clone upstream PyTorch with submodules
+        id: clone_pytorch
+        run: |
+          # Use the proxy to speed up git clone (when GH_PROXY_URL is set)
+          PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
+          if [ -n "${{ env.GH_PROXY_URL }}" ]; then
+            PYTORCH_REPO="${{ env.GH_PROXY_URL }}/${PYTORCH_REPO}"
+            echo "Using proxy: ${PYTORCH_REPO}"
+          fi
+
+          # Check out the requested ref (branch, tag, or commit)
+          PYTORCH_REF="${{ inputs.pytorch_ref }}"
+          echo "Cloning PyTorch with ref: ${PYTORCH_REF}"
+
+          # Shallow-clone first, then fetch the requested ref, and finally check it out
+          git clone --depth=1 "${PYTORCH_REPO}" pytorch-src
+          cd pytorch-src
+          git fetch --depth=1 origin "${PYTORCH_REF}"
+          git checkout FETCH_HEAD  # the fetch above records the ref only in FETCH_HEAD; a non-default branch or tag name would not resolve locally
+
+          # Initialize submodules
+          git submodule update --init --recursive
+
+          PYTORCH_SHA=$(git rev-parse HEAD)
+          PYTORCH_SHA_SHORT=$(git rev-parse --short HEAD)
+          echo "pytorch_sha=${PYTORCH_SHA}" >> $GITHUB_OUTPUT
+          echo "pytorch_sha_short=${PYTORCH_SHA_SHORT}" >> $GITHUB_OUTPUT
+          echo "Cloned PyTorch commit: ${PYTORCH_SHA}"
+          echo "Submodules downloaded:"
+          ls -la third_party/ | head -20
+
+      - name: Checkout torch_npu
+        uses: actions/checkout@v4
+        with:
+          repository: Ascend/pytorch
+          ref: ${{ inputs.torch_npu_ref }}
+          path: torch_npu-src
+          submodules: recursive
+
+      # ==================== pip cache configuration ====================
+      # The pip cache speeds up dependency downloads and does not affect the build output
+      # The cache key is based on the requirements-build.txt hash (dependencies change rarely)
+      - name: Get pip cache key
+        id: pip_key
+        run: |
+          REQUIREMENTS_HASH=$(cd pytorch-src && sha256sum requirements-build.txt | cut -d' ' -f1)
+          echo "cache_key=${{ env.CACHE_VERSION }}-pip-${{ inputs.python_version }}-${REQUIREMENTS_HASH}" >> $GITHUB_OUTPUT
+
+      - name: Restore pip cache
+        uses: actions/cache/restore@v4
+        with:
+          path: /root/.cache/pip
+          key: ${{ steps.pip_key.outputs.cache_key }}
+          restore-keys: |
+            ${{ env.CACHE_VERSION }}-pip-${{ inputs.python_version }}-
+            ${{ env.CACHE_VERSION }}-pip-
+
+      - name: Setup pip cache directory
+        run: |
+          mkdir -p /root/.cache/pip
+
+      - name: Configure pip index URL
+        run: |
+          # Point pip at the PyPI cache to speed up downloads
+          if [ -n "${{ env.PYPI_CACHE_URL }}" ]; then
+            pip${{ inputs.python_version }} config set global.index-url ${{ env.PYPI_CACHE_URL }}
+            pip${{ inputs.python_version }} config set global.trusted-host "cache-service.nginx-pypi-cache.svc.cluster.local"
+            echo "pip index-url configured: ${{ env.PYPI_CACHE_URL }}"
+          else
+            echo "No PyPI cache URL configured, using default"
+          fi
+
+      - name: Upgrade pip and setuptools
+        run: |
+          export PIP_CACHE_DIR=/root/.cache/pip
+          # Upgrade pip and setuptools first to avoid compatibility problems with old package versions
+          pip${{ inputs.python_version }} install --upgrade pip setuptools wheel
+
+      # ==================== ccache cache configuration ====================
+      # ccache is what really speeds up compilation (can save 30-60 minutes)
+      # Note: every PyTorch clone is a new commit, so the cache key does not include the PyTorch SHA
+      # The cache key relies on the torch_npu SHA and the requirements-build.txt hash
+      - name: Get ccache key
+        id: ccache_key
+        run: |
+          # ccache cache key: torch_npu SHA + requirements hash
+          # The PyTorch SHA changes on every run (--depth=1 clone of the latest commit), so it is left out of the cache key
+          TORCH_NPU_SHA=$(cd torch_npu-src && git rev-parse HEAD)
+          REQUIREMENTS_HASH=$(cd pytorch-src && sha256sum requirements-build.txt | cut -d' ' -f1)
+          echo "cache_key=${{ env.CACHE_VERSION }}-ccache-${REQUIREMENTS_HASH}-${TORCH_NPU_SHA}" >> $GITHUB_OUTPUT
+          # partial_key restores a cache built with the same requirements (different torch_npu version)
+          echo "partial_key=${{ env.CACHE_VERSION }}-ccache-${REQUIREMENTS_HASH}-" >> $GITHUB_OUTPUT
+          # base_key restores any cache that shares the same CACHE_VERSION
+          echo "base_key=${{ env.CACHE_VERSION }}-ccache-" >> $GITHUB_OUTPUT
+
+      - name: Restore ccache
+        uses: actions/cache/restore@v4
+        with:
+          path: /root/.cache/ccache
+          key: ${{ steps.ccache_key.outputs.cache_key }}
+          restore-keys: |
+            ${{ steps.ccache_key.outputs.partial_key }}
+            ${{ steps.ccache_key.outputs.base_key }}
+
+      - name: Setup ccache
+        run: |
+          # Install ccache (not preinstalled in the manylinux image)
+          yum install -y ccache
+
+          # Create the ccache config directory (absolute path)
+          CCACHE_DIR_PATH="/root/.cache/ccache"
+          mkdir -p "$CCACHE_DIR_PATH"
+
+          # Write the config file directly (absolute path)
+          cat > "$CCACHE_DIR_PATH/ccache.conf" << 'EOF'
+          max_size = 20G
+          cache_dir = /root/.cache/ccache
+          compression = true
+          compression_level = 6
+          EOF
+
+          # Use symlinks so that ccache masquerades as gcc/g++
+          mkdir -p /usr/local/bin
+          ln -sf /usr/bin/ccache /usr/local/bin/gcc
+          ln -sf /usr/bin/ccache /usr/local/bin/g++
+          ln -sf /usr/bin/ccache /usr/local/bin/cc
+          ln -sf /usr/bin/ccache /usr/local/bin/c++
+
+          # Put the symlinks first on PATH
+          echo "PATH=/usr/local/bin:$PATH" >> $GITHUB_ENV
+
+          # Set CCACHE_DIR (absolute path, no ~)
+          echo "CCACHE_DIR=$CCACHE_DIR_PATH" >> $GITHUB_ENV
+
+          # Set the compiler environment variables so CMake/Ninja use ccache
+          echo "CC=/usr/local/bin/gcc" >> $GITHUB_ENV
+          echo "CXX=/usr/local/bin/g++" >> $GITHUB_ENV
+
+          echo "=== ccache Configuration ==="
+          CCACHE_DIR="$CCACHE_DIR_PATH" ccache --show-config
+
+          echo ""
+          echo "=== Config File Contents ==="
+          cat "$CCACHE_DIR_PATH/ccache.conf"
+
+          echo ""
+          echo "=== Cache Directory ==="
+          ls -la "$CCACHE_DIR_PATH/"
+
+          echo ""
+          echo "=== Symbolic Links ==="
+          ls -la /usr/local/bin/gcc /usr/local/bin/g++
+
+          echo ""
+          echo "=== ccache Statistics (before build) ==="
+          CCACHE_DIR="$CCACHE_DIR_PATH" ccache --show-stats
+
+      # ==================== Build PyTorch ====================
+      - name: Build PyTorch wheel
+        run: |
+          source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true
+
+          export PIP_CACHE_DIR=/root/.cache/pip
+          cd pytorch-src
+
+          # Install build dependencies (the pip cache has been restored, which speeds up downloads)
+          pip${{ inputs.python_version }} install --upgrade pip setuptools wheel
+          pip${{ inputs.python_version }} install -r requirements-build.txt
+
+          # Set build environment variables
+          export MAX_JOBS=128
+          export USE_CUDA=0
+          export USE_CUDNN=0
+          export USE_DISTRIBUTED=1
+          export CMAKE_BUILD_TYPE=Release
+          export USE_OPENMP=1
+          export USE_MKLDNN=0
+
+          # Make sure ccache is used (CMake honors the CC/CXX environment variables)
+          export CC=/usr/local/bin/gcc
+          export CXX=/usr/local/bin/g++
+          export CCACHE_DIR=/root/.cache/ccache
+
+          # Reset ccache statistics (a new build is starting)
+          ccache --zero-stats
+
+          python${{ inputs.python_version }} setup.py build bdist_wheel
+
+          echo "PyTorch wheel built:"
+          ls -la dist/
+
+          echo ""
+          echo "=== ccache Statistics (after PyTorch build) ==="
+          ccache --show-stats
+
+      # ==================== Build torch_npu ====================
+      - name: Install PyTorch wheel and build dependencies
+        run: |
+          source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true
+          source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true
+
+          export PIP_CACHE_DIR=/root/.cache/pip
+
+          echo "=== Installing built PyTorch wheel ==="
+          pip${{ inputs.python_version }} install pytorch-src/dist/*.whl
+
+          echo ""
+          echo "=== Verifying PyTorch installation ==="
+          python${{ inputs.python_version }} -c "import torch; print(f'torch version: {torch.__version__}')"
+
+      - name: Get PyTorch version
+        id: get_version
+        run: |
+          source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true
+          PYTORCH_VERSION=$(python${{ inputs.python_version }} -c "import torch; print(torch.__version__)")
+          echo "pytorch_version=${PYTORCH_VERSION}" >> $GITHUB_OUTPUT
+          echo "PyTorch version: ${PYTORCH_VERSION}"
+
+      - name: Install torch_npu build dependencies
+        run: |
+          source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true
+          source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true
+
+          export PIP_CACHE_DIR=/root/.cache/pip
+          pip${{ inputs.python_version }} install --upgrade pip setuptools wheel
+          pip${{ inputs.python_version }} install cmake ninja numpy packaging pyyaml requests six typing-extensions
+
+          cd torch_npu-src
+
+          # Show ccache statistics (dependency installation stage)
+          echo ""
+          echo "=== ccache Statistics (before torch_npu build) ==="
+          CCACHE_DIR=/root/.cache/ccache ccache --show-stats
+
+      - name: Build torch_npu wheel
+        run: |
+          source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true
+          source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true
+
+          cd torch_npu-src
+
+          export MAX_JOBS=128
+
+          # Make sure ccache is used
+          export CC=/usr/local/bin/gcc
+          export CXX=/usr/local/bin/g++
+          export CCACHE_DIR=/root/.cache/ccache
+
+          # Disable the torchair build (upstream PyTorch main API changes cause compatibility problems)
+          bash ci/build.sh --python=${{ inputs.python_version }} --disable_torchair
+
+          echo "torch_npu wheel built:"
+          ls -la dist/
+
+          echo ""
+          echo "=== ccache Statistics (after torch_npu build) ==="
+          ccache --show-stats
+
+      # ==================== Save caches (skipped when the key step never produced a key) ====================
+      - name: Save pip cache
+        if: always() && steps.pip_key.outputs.cache_key != ''
+        uses: actions/cache/save@v4
+        with:
+          path: /root/.cache/pip
+          key: ${{ steps.pip_key.outputs.cache_key }}
+
+      - name: Save ccache
+        if: always() && steps.ccache_key.outputs.cache_key != ''
+        uses: actions/cache/save@v4
+        with:
+          path: /root/.cache/ccache
+          key: ${{ steps.ccache_key.outputs.cache_key }}
+
+      - name: Display cache save status
+        if: always()
+        run: |
+          echo "=== Cache Saved ==="
+          echo "pip cache key: ${{ steps.pip_key.outputs.cache_key }}"
+          PIP_CACHE_SIZE=$(du -sh /root/.cache/pip 2>/dev/null | cut -f1)
+          echo "pip cache size: ${PIP_CACHE_SIZE}"
+          echo ""
+          echo "ccache key: ${{ steps.ccache_key.outputs.cache_key }}"
+          CCACHE_SIZE=$(du -sh /root/.cache/ccache 2>/dev/null | cut -f1)
+          echo "ccache size: ${CCACHE_SIZE}"
+
+      # ==================== Package and upload ====================
+      - name: Package PyTorch source and build artifacts
+        run: |
+          # Package the whole pytorch-src directory (test sources and build outputs included)
+          # Exclude unnecessary files to reduce the size:
+          # - the .git directory (takes the most space)
+          # - intermediate build products under build/ (CMakeFiles, .o files, etc.)
+          # - dist/*.whl (already uploaded separately as an artifact)
+
+          echo "=== PyTorch source directory size ==="
+          du -sh pytorch-src/
+
+          echo ""
+          echo "=== Build artifacts location ==="
+          ls -la pytorch-src/build/lib.*/torch/*.so 2>/dev/null | head -5 || echo "No .so files found in build/lib"
+          ls -la pytorch-src/torch/_C.so 2>/dev/null || echo "No _C.so in torch/"
+
+          echo ""
+          echo "=== Creating archive (excluding large unnecessary files) ==="
+          tar -czf pytorch-src.tar.gz \
+            --exclude='pytorch-src/.git' \
+            --exclude='pytorch-src/build/CMakeFiles' \
+            --exclude='pytorch-src/build/*.o' \
+            --exclude='pytorch-src/build/**/*.o' \
+            --exclude='pytorch-src/dist/*.whl' \
+            pytorch-src
+
+          echo ""
+          echo "=== Archive size ==="
+          ls -lh pytorch-src.tar.gz
+
+      - name: Upload PyTorch wheel
+        uses: actions/upload-artifact@v4
+        with:
+          name: torch-wheel-main
+          path: pytorch-src/dist/*.whl
+          retention-days: 7
+
+      - name: Upload torch_npu wheel
+        uses: actions/upload-artifact@v4
+        with:
+          name: torch-npu-wheel-main
+          path: torch_npu-src/dist/*.whl
+          retention-days: 7
+
+      - name: Upload PyTorch source and build artifacts
+        uses: actions/upload-artifact@v4
+        with:
+          name: pytorch-src-main
+          path: pytorch-src.tar.gz
+          retention-days: 7
\ No newline at end of file
diff --git a/.github/workflows/_torch-npu-upstream-collect.yml b/.github/workflows/_torch-npu-upstream-collect.yml
new file mode 100644
index 0000000000..bf57ab38c7
--- /dev/null
+++ b/.github/workflows/_torch-npu-upstream-collect.yml
@@ -0,0 +1,163 @@
+name: Torch NPU Upstream Collect
+# Reusable workflow: collects the test cases, writes per-shard case files, and exposes the shard matrices as outputs.
+on:
+  workflow_call:
+    inputs:
+      python_version:
+        required: true
+        type: string
+        description: Python version to use
+      torch_wheel_artifact:
+        required: true
+        type: string
+        description: Name of the torch wheel artifact from build
+      torch_npu_wheel_artifact:
+        required: true
+        type: string
+        description: Name of the torch_npu wheel artifact from build
+      pytorch_src_artifact:
+        required: true
+        type: string
+        description: Name of the pytorch source artifact from build
+      docker_image:
+        required: true
+        type: string
+        description: Docker image to use
+      distributed_shards:
+        required: false
+        type: string
+        default: '2'
+        description: Number of shards for distributed tests
+      regular_shards:
+        required: false
+        type: string
+        default: '5'
+        description: Number of shards for regular tests
+    outputs:
+      distributed_matrix:
+        description: Distributed shard matrix JSON
+        value: ${{ jobs.collect.outputs.distributed_matrix }}
+      regular_matrix:
+        description: Regular shard matrix JSON
+        value: ${{ jobs.collect.outputs.regular_matrix }}
+      distributed_shards:
+        description: Number of distributed shards
+        value: ${{ jobs.collect.outputs.distributed_shards }}
+      regular_shards:
+        description: Number of regular shards
+        value: ${{ jobs.collect.outputs.regular_shards }}
+      total_cases:
+        description: Total number of test cases
+        value: ${{ jobs.collect.outputs.total_cases }}
+
+jobs:
+  collect:
+    runs-on: linux-aarch64-a3-16
+    timeout-minutes: 120
+    container:
+      image: ${{ inputs.docker_image }}
+      options: --user root
+    outputs:
+      distributed_matrix: ${{ steps.collect_and_shard.outputs.distributed_matrix }}
+      regular_matrix: ${{ steps.collect_and_shard.outputs.regular_matrix }}
+      distributed_shards: ${{ steps.collect_and_shard.outputs.distributed_shards }}
+      regular_shards: ${{ steps.collect_and_shard.outputs.regular_shards }}
+      total_cases: ${{ steps.collect_and_shard.outputs.total_cases }}
+
+    steps:
+      - name: Setup NPU test environment
+        uses: kerer-ai/pytorch/.github/actions/setup-npu-test-env@dev_master
+        with:
+          python_version: ${{ inputs.python_version }}
+          torch_wheel_artifact: ${{ inputs.torch_wheel_artifact }}
+          torch_npu_wheel_artifact: ${{ inputs.torch_npu_wheel_artifact }}
+          pytorch_src_artifact: ${{ inputs.pytorch_src_artifact }}
+
+      - name: Collect all test cases and shard
+        id: collect_and_shard
+        run: |
+          source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true
+          source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true
+
+          PYTHON=python${{ inputs.python_version }}
+          cd pytorch-src
+
+          # Set the BACKEND environment variable to avoid a KeyError while collecting distributed tests
+          # The value is hccl (the NPU distributed backend); when it is not gloo/nccl the test classes are not defined
+          # Result: pytest collects 0 cases and does not raise an error
+          export BACKEND="hccl"
+
+          # Case-level sharding
+          DISTRIBUTED_SHARDS='${{ inputs.distributed_shards }}'
+          REGULAR_SHARDS='${{ inputs.regular_shards }}'
+
+          echo "=== Collecting all test cases ==="
+          echo "Distributed shards: ${DISTRIBUTED_SHARDS}"
+          echo "Regular shards: ${REGULAR_SHARDS}"
+
+          $PYTHON ../ascend_pytorch/.github/scripts/collect_all_cases.py \
+            --test-dir test \
+            --distributed-shards ${DISTRIBUTED_SHARDS} \
+            --regular-shards ${REGULAR_SHARDS} \
+            --output-dir cases_shards \
+            --error-log-dir collection_errors \
+            --parallel 16 \
+            2>&1 | tee /tmp/collect_cases.log  # NOTE(review): no pipefail is set in this step, so tee hides the collector's exit status - confirm this is intended
+
+          # Verify output
+          echo "=== Generated shard files ==="
+          ls -la cases_shards/
+
+          echo "=== Collection summary ==="
+          cat cases_shards/cases_collection_summary.json
+
+          # Extract total cases from summary (same interpreter that ran the collector)
+          TOTAL_CASES=$($PYTHON -c "import json; d=json.load(open('cases_shards/cases_collection_summary.json')); print(d['total_cases'])")
+
+          # Build shard matrices
+          DIST_SHARDS=$(seq 1 ${DISTRIBUTED_SHARDS} | tr '\n' ',' | sed 's/,$//')
+          REG_SHARDS=$(seq 1 ${REGULAR_SHARDS} | tr '\n' ',' | sed 's/,$//')
+
+          echo "distributed_matrix=[${DIST_SHARDS}]" >> $GITHUB_OUTPUT
+          echo "distributed_shards=${DISTRIBUTED_SHARDS}" >> $GITHUB_OUTPUT
+          echo "regular_matrix=[${REG_SHARDS}]" >> $GITHUB_OUTPUT
+          echo "regular_shards=${REGULAR_SHARDS}" >> $GITHUB_OUTPUT
+          echo "total_cases=${TOTAL_CASES}" >> $GITHUB_OUTPUT
+
+          echo "=== Shard configuration ==="
+          echo "Distributed tests: ${DISTRIBUTED_SHARDS} shards (case-level, serial execution, linux-aarch64-a3-16)"
+          echo "Regular tests: ${REGULAR_SHARDS} shards (case-level, 64 workers, linux-aarch64-a3-16)"
+          echo "Total cases: ${TOTAL_CASES}"
+
+          # Package error logs if any
+          if [ -d "collection_errors" ] && [ "$(ls -A collection_errors 2>/dev/null)" ]; then
+            echo "=== Packaging collection error logs ==="
+            tar -czf collection_errors.tar.gz collection_errors/
+            echo "Error logs packaged: collection_errors.tar.gz"
+            ls -la collection_errors.tar.gz
+          fi
+
+      - name: Upload cases shard JSONs
+        uses: actions/upload-artifact@v4
+        with:
+          name: cases-shards
+          path: pytorch-src/cases_shards/
+          retention-days: 7
+
+      - name: Upload collection error logs
+        if: always()
+        uses: actions/upload-artifact@v4
+        with:
+          name: collection-error-logs
+          path: pytorch-src/collection_errors.tar.gz
+          if-no-files-found: ignore
+          retention-days: 30
+
+      - name: Upload collect logs
+        if: always()
+        uses: actions/upload-artifact@v4
+        with:
+          name: collect-cases-logs
+          path: /tmp/collect_cases.log
+          if-no-files-found: warn
+          retention-days: 30
\ No newline at end of file
diff --git a/.github/workflows/_torch-npu-upstream-report.yml b/.github/workflows/_torch-npu-upstream-report.yml
new file mode 100644
index 0000000000..d75248e5dc
--- /dev/null
+++ b/.github/workflows/_torch-npu-upstream-report.yml
@@ -0,0 +1,111 @@
+name: Torch NPU Upstream Report
+# Reusable workflow: downloads the shard test reports and generates one consolidated Markdown/JSON summary.
+on:
+  workflow_call:
+    inputs:
+      python_version:
+        required: true
+        type: string
+        description: Python version to use
+      pytorch_version:
+        required: true
+        type: string
+        description: PyTorch version string
+      torch_npu_wheel_name:
+        required: false
+        type: string
+        default: 'source-build.whl'
+        description: Name of the torch_npu wheel file
+      docker_image:
+        required: true
+        type: string
+        description: Docker image used for tests
+      distributed_matrix:
+        required: false
+        type: string
+        default: '[]'
+        description: Distributed shard matrix JSON
+      regular_matrix:
+        required: false
+        type: string
+        default: '[]'
+        description: Regular shard matrix JSON
+
+jobs:
+  generate_report:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+        with:
+          repository: ${{ github.repository }}
+          ref: ${{ github.ref }}
+          fetch-depth: 1
+
+      - name: Setup Python ${{ inputs.python_version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ inputs.python_version }}
+
+      - name: Download distributed shard reports
+        uses: actions/download-artifact@v4
+        with:
+          pattern: test-reports-dist-*
+          path: all-test-reports
+          merge-multiple: true
+
+      - name: Download regular shard reports
+        uses: actions/download-artifact@v4
+        with:
+          pattern: test-reports-reg-*
+          path: all-test-reports
+          merge-multiple: true
+
+      - name: Download custom test reports
+        uses: actions/download-artifact@v4
+        with:
+          name: test-reports-custom
+          path: all-test-reports
+          merge-multiple: true
+        continue-on-error: true
+
+      - name: Download cases collection summary
+        uses: actions/download-artifact@v4
+        with:
+          name: cases-shards
+          path: cases-shards
+        continue-on-error: true
+
+      - name: Generate consolidated summary
+        run: |
+          PYTHON=python
+          REPORT_MD=npu-full-test-summary.md
+          REPORT_JSON=npu-full-test-summary.json
+
+          # Combine both shard matrices for reporting (JSON is passed via argv so it is never spliced into Python source)
+          DIST_MATRIX='${{ inputs.distributed_matrix }}'
+          REG_MATRIX='${{ inputs.regular_matrix }}'
+          COMBINED_MATRIX=$(python3 -c "import sys,json; dist=json.loads(sys.argv[1]); reg=json.loads(sys.argv[2]); print(json.dumps(['dist-'+str(s) for s in dist]+['reg-'+str(s) for s in reg]))" "${DIST_MATRIX}" "${REG_MATRIX}")
+
+          $PYTHON .github/scripts/generate_npu_full_test_report.py \
+            --reports-root all-test-reports \
+            --output-markdown ${REPORT_MD} \
+            --output-json ${REPORT_JSON} \
+            --pytorch-version "${{ inputs.pytorch_version }}" \
+            --torch-npu-whl "${{ inputs.torch_npu_wheel_name }}" \
+            --shard-matrix-json "${COMBINED_MATRIX}" \
+            --docker-image "${{ inputs.docker_image }}" \
+            --runner "linux-aarch64-a3-16 (distributed, serial), linux-aarch64-a3-16 (regular, 64 workers), linux-aarch64-a3-16 (custom)" \
+            --cases-summary cases-shards/cases_collection_summary.json
+
+          cat ${REPORT_MD} >> $GITHUB_STEP_SUMMARY
+
+      - name: Upload consolidated summary
+        if: always()
+        uses: actions/upload-artifact@v4
+        with:
+          name: npu-full-test-summary
+          path: |
+            npu-full-test-summary.md
+            npu-full-test-summary.json
+          retention-days: 30
\ No newline at end of file
diff --git a/.github/workflows/_torch-npu-upstream-test-custom.yml b/.github/workflows/_torch-npu-upstream-test-custom.yml
new file mode 100644
index 0000000000..027eea0fb9
--- /dev/null
+++ b/.github/workflows/_torch-npu-upstream-test-custom.yml
@@ -0,0 +1,114 @@
+name: Torch NPU Upstream Test Custom
+# Reusable workflow: runs a caller-supplied, comma-separated list of test files and uploads the resulting reports.
+on:
+  workflow_call:
+    inputs:
+      python_version:
+        required: true
+        type: string
+        description: Python version to use
+      torch_wheel_artifact:
+        required: true
+        type: string
+        description: Name of the torch wheel artifact from build
+      torch_npu_wheel_artifact:
+        required: true
+        type: string
+        description: Name of the torch_npu wheel artifact from build
+      pytorch_src_artifact:
+        required: true
+        type: string
+        description: Name of the pytorch source artifact from build
+      docker_image:
+        required: true
+        type: string
+        description: Docker image to use
+      test_files:
+        required: true
+        type: string
+        description: Test files to run (comma-separated)
+
+jobs:
+  run_tests:
+    name: test_custom
+    runs-on: linux-aarch64-a3-16
+    timeout-minutes: 1200
+    container:
+      image: ${{ inputs.docker_image }}
+      options: --user root
+
+    steps:
+      - name: Setup NPU test environment
+        uses: kerer-ai/pytorch/.github/actions/setup-npu-test-env@dev_master
+        with:
+          python_version: ${{ inputs.python_version }}
+          torch_wheel_artifact: ${{ inputs.torch_wheel_artifact }}
+          torch_npu_wheel_artifact: ${{ inputs.torch_npu_wheel_artifact }}
+          pytorch_src_artifact: ${{ inputs.pytorch_src_artifact }}
+
+      - name: Run custom test files
+        id: run_tests
+        run: |
+          source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true
+          source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true
+
+          REPORT_DIR=test-reports
+          mkdir -p ${REPORT_DIR}
+
+          # Custom test files: each case runs in its own isolated execution
+          python${{ inputs.python_version }} ascend_pytorch/.github/scripts/run_npu_test_shard.py \
+            --test-files "${{ inputs.test_files }}" \
+            --test-dir pytorch-src/test \
+            --report-dir ${REPORT_DIR} \
+            --timeout 600 \
+            --verbose \
+            2>&1 | tee /tmp/test_custom.log
+
+          TEST_STATUS=${PIPESTATUS[0]}
+          echo "status=${TEST_STATUS}" >> $GITHUB_OUTPUT
+          # The test status is only recorded, not used as the exit code, so report generation can still run
+
+      - name: Package and upload test reports
+        if: always()
+        run: |
+          # Bundle the JUnit XML directory into one compressed archive
+          if [ -d "test-reports/junit_xmls" ]; then
+            echo "=== Compressing junit XMLs ==="
+            XML_COUNT=$(find test-reports/junit_xmls -type f -name "*.xml" | wc -l)
+            echo "Found ${XML_COUNT} XML files"
+            tar -czf test-reports/junit_xmls.tar.gz -C test-reports junit_xmls
+            rm -rf test-reports/junit_xmls
+            echo "JUnit XMLs compressed"
+          fi
+
+          # Bundle the failed-case logs into one compressed archive
+          if [ -d "test-reports/failed_cases_logs" ]; then
+            echo "=== Compressing failed cases logs ==="
+            tar -czf test-reports/failed_cases_logs.tar.gz -C test-reports failed_cases_logs
+            rm -rf test-reports/failed_cases_logs
+            echo "Failed cases logs compressed"
+          fi
+
+      - name: Upload test reports
+        if: always()
+        uses: actions/upload-artifact@v4
+        with:
+          name: test-reports-custom
+          path: test-reports/
+          retention-days: 30
+
+      - name: Compress and upload error logs
+        if: failure()
+        run: |
+          mkdir -p error-logs
+          cp /tmp/test_custom.log error-logs/ 2>/dev/null || true
+          tar -czf error-logs-custom.tar.gz error-logs/
+          echo "Error logs compressed"
+
+      - name: Upload error logs
+        if: failure()
+        uses: actions/upload-artifact@v4
+        with:
+          name: error-logs-custom
+          path: error-logs-custom.tar.gz
+          retention-days: 30
\ No newline at end of file
diff --git a/.github/workflows/_torch-npu-upstream-test-dist.yml b/.github/workflows/_torch-npu-upstream-test-dist.yml
new file mode 100644
index 0000000000..77eb07d0ec
--- /dev/null
+++ b/.github/workflows/_torch-npu-upstream-test-dist.yml
@@ -0,0 +1,132 @@
+name: Torch NPU Upstream Test Distributed
+
+on:
+  workflow_call:
+    inputs:
+      python_version:
+        description: Python version to use
+        type: string
+        required: true
+      torch_wheel_artifact:
+        description: Name of the torch wheel artifact from build
+        type: string
+        required: true
+      torch_npu_wheel_artifact:
+        description: Name of the torch_npu wheel artifact from build
+        type: string
+        required: true
+      pytorch_src_artifact:
+        description: Name of the pytorch source artifact from build
+        type: string
+        required: true
+      docker_image:
+        description: Docker image to use
+        type: string
+        required: true
+      distributed_matrix:  # JSON text; the job parses it with fromJson() for its matrix
+        description: Distributed shard matrix JSON
+        type: string
+        required: true
+      distributed_shards:  # only used in the job and step display names
+        description: Number of distributed shards
+        type: string
+        required: true
+
+jobs:
+  run_tests:
+    name: test_distributed (${{ matrix.shard }}/${{ inputs.distributed_shards }})
+    runs-on: linux-aarch64-a3-16
+    timeout-minutes: 1200  # 20 hours per shard
+    strategy:
+      fail-fast: false  # one failing shard does not cancel the remaining shards
+      max-parallel: 2
+      matrix:
+        shard: ${{ fromJson(inputs.distributed_matrix) }}
+    container:
+      image: ${{ inputs.docker_image }}
+      options: --user root
+
+    steps:
+      - name: Setup NPU test environment
+        uses: kerer-ai/pytorch/.github/actions/setup-npu-test-env@dev_master  # NOTE(review): mutable branch ref; consider pinning to a commit SHA
+        with:
+          pytorch_src_artifact: ${{ inputs.pytorch_src_artifact }}
+          torch_wheel_artifact: ${{ inputs.torch_wheel_artifact }}
+          torch_npu_wheel_artifact: ${{ inputs.torch_npu_wheel_artifact }}
+          python_version: ${{ inputs.python_version }}
+
+      - name: Download cases shard JSONs
+        uses: actions/download-artifact@v4
+        with:
+          path: cases-shards  # the run step reads cases-shards/distributed_cases_shard_<n>.json
+          name: cases-shards
+
+      - name: Run distributed shard ${{ matrix.shard }}/${{ inputs.distributed_shards }}
+        id: run_test
+        run: |
+          source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true
+          source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true
+
+          PYTHON=python${{ inputs.python_version }}
+          REPORT_DIR=test-reports
+          CASES_JSON="cases-shards/distributed_cases_shard_${{ matrix.shard }}.json"
+          mkdir -p "${REPORT_DIR}"
+
+          # Read the case count with the interpreter that runs the tests; the path is passed as argv, not spliced into code.
+          TOTAL_CASES=$("$PYTHON" -c 'import json, sys; print(json.load(open(sys.argv[1]))["total_cases"])' "${CASES_JSON}")
+
+          echo "=== Distributed Shard ${{ matrix.shard }} (Case-level) ==="
+          echo "Total cases: ${TOTAL_CASES}"
+          echo "Runner: linux-aarch64-a3-16 (16-card NPU)"
+          echo "Execution mode: SERIAL"
+
+          # Pre-collected cases, run serially. errexit is suspended so the exit code can be recorded.
+          # NOTE(review): --quick-test 100 looks like a debugging limit; confirm it is intended for full runs.
+          set +e
+          "$PYTHON" ascend_pytorch/.github/scripts/run_npu_test_shard.py \
+            --cases-json "${CASES_JSON}" \
+            --test-dir pytorch-src/test \
+            --report-dir "${REPORT_DIR}" \
+            --quick-test 100 \
+            --timeout 600 \
+            --verbose \
+            2>&1 | tee /tmp/test_shard_dist_${{ matrix.shard }}.log
+          TEST_STATUS=${PIPESTATUS[0]}
+          set -e
+
+          # Publish the status as a step output; the step itself succeeds so reports are still generated.
+          echo "status=${TEST_STATUS}" >> "$GITHUB_OUTPUT"
+
+      - name: Package and upload test reports
+        if: always()  # still archive reports when an earlier step failed
+        run: |
+          # JUnit XMLs: count them, tar+gzip the directory, then drop the originals.
+          if [[ -d test-reports/junit_xmls ]]; then
+            echo "=== Compressing junit XMLs ==="
+            xml_total=$(find test-reports/junit_xmls -type f -name "*.xml" | wc -l)
+            echo "Found ${xml_total} XML files"
+            tar -czf test-reports/junit_xmls.tar.gz -C test-reports junit_xmls
+            rm -rf test-reports/junit_xmls
+            echo "JUnit XMLs compressed: $(ls -lh test-reports/junit_xmls.tar.gz)"
+          fi
+
+          # Failed-case logs are archived the same way (no file count is printed).
+          if [[ -d test-reports/failed_cases_logs ]]; then
+            echo "=== Compressing failed cases logs ==="
+            tar -czf test-reports/failed_cases_logs.tar.gz -C test-reports failed_cases_logs
+            rm -rf test-reports/failed_cases_logs
+            echo "Failed cases logs compressed: $(ls -lh test-reports/failed_cases_logs.tar.gz)"
+          fi
+
+          # The per-shard cases JSON is not archived here; only its presence and size are logged.
+          if [[ -f "test-reports/shard_dist-${{ matrix.shard }}_cases.json" ]]; then
+            echo "Cases JSON exists: $(ls -lh test-reports/shard_dist-${{ matrix.shard }}_cases.json)"
+          fi
+
+      - name: Upload test reports
+        uses: actions/upload-artifact@v4
+        if: always()  # publish whatever is in test-reports/ regardless of earlier results
+        with:
+          path: test-reports/
+          name: test-reports-dist-${{ matrix.shard }}  # one artifact per shard
+          retention-days: 30
\ No newline at end of file
diff --git a/.github/workflows/_torch-npu-upstream-test-regular.yml b/.github/workflows/_torch-npu-upstream-test-regular.yml
new file mode 100644
index 0000000000..c9c61128c4
--- /dev/null
+++ b/.github/workflows/_torch-npu-upstream-test-regular.yml
@@ -0,0 +1,135 @@
+name: Torch NPU Upstream Test Regular
+
+on:
+  workflow_call:
+    inputs:
+      python_version:
+        description: Python version to use
+        type: string
+        required: true
+      torch_wheel_artifact:
+        description: Name of the torch wheel artifact from build
+        type: string
+        required: true
+      torch_npu_wheel_artifact:
+        description: Name of the torch_npu wheel artifact from build
+        type: string
+        required: true
+      pytorch_src_artifact:
+        description: Name of the pytorch source artifact from build
+        type: string
+        required: true
+      docker_image:
+        description: Docker image to use
+        type: string
+        required: true
+      regular_matrix:  # JSON text; the job parses it with fromJson() for its matrix
+        description: Regular shard matrix JSON
+        type: string
+        required: true
+      regular_shards:  # only used in the job and step display names
+        description: Number of regular shards
+        type: string
+        required: true
+
+jobs:
+  run_tests:
+    name: test_regular (${{ matrix.shard }}/${{ inputs.regular_shards }})
+    runs-on: linux-aarch64-a3-16
+    timeout-minutes: 1200  # 20 hours per shard
+    strategy:
+      fail-fast: false  # one failing shard does not cancel the remaining shards
+      max-parallel: 5
+      matrix:
+        shard: ${{ fromJson(inputs.regular_matrix) }}
+    container:
+      image: ${{ inputs.docker_image }}
+      options: --user root
+
+    steps:
+      - name: Setup NPU test environment
+        uses: kerer-ai/pytorch/.github/actions/setup-npu-test-env@dev_master  # NOTE(review): mutable branch ref; consider pinning to a commit SHA
+        with:
+          pytorch_src_artifact: ${{ inputs.pytorch_src_artifact }}
+          torch_wheel_artifact: ${{ inputs.torch_wheel_artifact }}
+          torch_npu_wheel_artifact: ${{ inputs.torch_npu_wheel_artifact }}
+          python_version: ${{ inputs.python_version }}
+
+      - name: Download cases shard JSONs
+        uses: actions/download-artifact@v4
+        with:
+          path: cases-shards  # the run step reads cases-shards/regular_cases_shard_<n>.json
+          name: cases-shards
+
+      - name: Run regular shard ${{ matrix.shard }}/${{ inputs.regular_shards }}
+        id: run_test
+        run: |
+          source /usr/local/Ascend/cann/set_env.sh 2>/dev/null || true
+          source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null || true
+
+          PYTHON=python${{ inputs.python_version }}
+          REPORT_DIR=test-reports
+          CASES_JSON="cases-shards/regular_cases_shard_${{ matrix.shard }}.json"
+          MAX_WORKERS=16  # single source for the log line and --max-workers (the log used to claim 64)
+          mkdir -p "${REPORT_DIR}"
+
+          # Read the case count with the interpreter that runs the tests; the path is passed as argv, not spliced into code.
+          TOTAL_CASES=$("$PYTHON" -c 'import json, sys; print(json.load(open(sys.argv[1]))["total_cases"])' "${CASES_JSON}")
+
+          echo "=== Regular Shard ${{ matrix.shard }} (Case-level) ==="
+          echo "Total cases: ${TOTAL_CASES}"
+          echo "Runner: linux-aarch64-a3-16 (16-card NPU)"
+          echo "Execution mode: CONCURRENT (${MAX_WORKERS} workers)"
+
+          # Pre-collected cases, run concurrently. errexit is suspended so the exit code can be recorded.
+          # NOTE(review): --quick-test 100 looks like a debugging limit; confirm it is intended for full runs.
+          set +e
+          "$PYTHON" ascend_pytorch/.github/scripts/run_npu_test_shard.py \
+            --cases-json "${CASES_JSON}" \
+            --test-dir pytorch-src/test \
+            --report-dir "${REPORT_DIR}" \
+            --timeout 600 \
+            --quick-test 100 \
+            --max-workers "${MAX_WORKERS}" \
+            --verbose \
+            2>&1 | tee /tmp/test_shard_reg_${{ matrix.shard }}.log
+          TEST_STATUS=${PIPESTATUS[0]}
+          set -e
+          # Publish the status as a step output; the step itself succeeds so reports are still generated.
+          echo "status=${TEST_STATUS}" >> "$GITHUB_OUTPUT"
+
+      - name: Package and upload test reports
+        if: always()  # still archive reports when an earlier step failed
+        run: |
+          # JUnit XMLs: count them, tar+gzip the directory, then drop the originals.
+          if [[ -d test-reports/junit_xmls ]]; then
+            echo "=== Compressing junit XMLs ==="
+            xml_total=$(find test-reports/junit_xmls -type f -name "*.xml" | wc -l)
+            echo "Found ${xml_total} XML files"
+            tar -czf test-reports/junit_xmls.tar.gz -C test-reports junit_xmls
+            rm -rf test-reports/junit_xmls
+            echo "JUnit XMLs compressed: $(ls -lh test-reports/junit_xmls.tar.gz)"
+          fi
+
+          # Failed-case logs: same procedure, counting every file in the directory.
+          if [[ -d test-reports/failed_cases_logs ]]; then
+            echo "=== Compressing failed cases logs ==="
+            failed_log_total=$(find test-reports/failed_cases_logs -type f | wc -l)
+            echo "Found ${failed_log_total} failed case log files"
+            tar -czf test-reports/failed_cases_logs.tar.gz -C test-reports failed_cases_logs
+            rm -rf test-reports/failed_cases_logs
+            echo "Failed cases logs compressed: $(ls -lh test-reports/failed_cases_logs.tar.gz)"
+          fi
+
+          # The per-shard cases JSON is not archived here; only its presence and size are logged.
+          if [[ -f "test-reports/shard_reg-${{ matrix.shard }}_cases.json" ]]; then
+            echo "Cases JSON exists: $(ls -lh test-reports/shard_reg-${{ matrix.shard }}_cases.json)"
+          fi
+
+      - name: Upload test reports
+        uses: actions/upload-artifact@v4
+        if: always()  # publish whatever is in test-reports/ regardless of earlier results
+        with:
+          path: test-reports/
+          name: test-reports-reg-${{ matrix.shard }}  # one artifact per shard
+          retention-days: 30
\ No newline at end of file
diff --git a/.github/workflows/_torch-npu-upstream-test.yml b/.github/workflows/_torch-npu-upstream-test.yml
new file mode 100644
index 0000000000..3b7ac687ce
--- /dev/null
+++ b/.github/workflows/_torch-npu-upstream-test.yml
@@ -0,0 +1,144 @@
+name: Torch NPU Upstream Test
+
+on:
+  workflow_call:
+    inputs:
+      python_version:
+        description: Python version to use
+        type: string
+        required: true
+      pytorch_ref:
+        description: PyTorch branch, tag, or commit SHA to build
+        type: string
+        required: false
+        default: 'fccc94ae83f61fe26559abc999797297196bac29'
+      torch_npu_ref:
+        description: torch_npu branch, tag, or commit SHA to build
+        type: string
+        required: false
+        default: 'master'
+      docker_image:
+        description: Docker image to use for all jobs
+        type: string
+        required: false
+        default: 'quay.io/kerer/pytorch:manylinux-cann9.0.0-beta.2-20260428'
+      distributed_shards:
+        description: Number of shards for distributed tests
+        type: string
+        required: false
+        default: '2'  # NOTE(review): the trigger workflow defaults this to '3'; confirm which is intended
+      regular_shards:
+        description: Number of shards for regular tests
+        type: string
+        required: false
+        default: '5'
+      test_files:  # empty selects the sharded distributed/regular jobs; non-empty selects the custom job
+        description: Test files to run directly (comma-separated)
+        type: string
+        required: false
+        default: ''
+
+defaults:
+  run:
+    shell: bash  # NOTE(review): every job below is a reusable-workflow call with no run steps; confirm this has any effect
+
+jobs:
+ # ============================================================================
+ # 1. Build PyTorch and torch_npu Wheels
+ # ============================================================================
+  build:
+    uses: ./.github/workflows/_torch-npu-upstream-build.yml
+    with:
+      python_version: ${{ inputs.python_version }}
+      docker_image: ${{ inputs.docker_image }}
+      pytorch_ref: ${{ inputs.pytorch_ref }}
+      torch_npu_ref: ${{ inputs.torch_npu_ref }}
+
+ # ============================================================================
+ # 2. Collect Test Cases (only when test_files is empty)
+ # ============================================================================
+  collect_cases:
+    # Sharded mode only: skipped when explicit test files were requested.
+    needs: [build]
+    if: inputs.test_files == ''
+    uses: ./.github/workflows/_torch-npu-upstream-collect.yml
+    with:
+      python_version: ${{ inputs.python_version }}
+      docker_image: ${{ inputs.docker_image }}
+      torch_wheel_artifact: ${{ needs.build.outputs.torch-wheel }}
+      torch_npu_wheel_artifact: ${{ needs.build.outputs.torch-npu-wheel }}
+      pytorch_src_artifact: ${{ needs.build.outputs.pytorch-src }}
+      distributed_shards: ${{ inputs.distributed_shards }}
+      regular_shards: ${{ inputs.regular_shards }}
+
+ # ============================================================================
+ # 3. Run Distributed Tests (only when test_files is empty)
+ # ============================================================================
+  test_distributed:
+    # Sharded mode only. The shard matrix and shard count come from the
+    # collect_cases outputs, not from this workflow's own inputs.
+    needs: [build, collect_cases]
+    if: inputs.test_files == ''
+    uses: ./.github/workflows/_torch-npu-upstream-test-dist.yml
+    with:
+      python_version: ${{ inputs.python_version }}
+      docker_image: ${{ inputs.docker_image }}
+      torch_wheel_artifact: ${{ needs.build.outputs.torch-wheel }}
+      torch_npu_wheel_artifact: ${{ needs.build.outputs.torch-npu-wheel }}
+      pytorch_src_artifact: ${{ needs.build.outputs.pytorch-src }}
+      distributed_matrix: ${{ needs.collect_cases.outputs.distributed_matrix }}
+      distributed_shards: ${{ needs.collect_cases.outputs.distributed_shards }}
+
+ # ============================================================================
+ # 4. Run Regular Tests (only when test_files is empty)
+ # ============================================================================
+  test_regular:
+    # Sharded mode only. The shard matrix and shard count come from the
+    # collect_cases outputs, not from this workflow's own inputs.
+    needs: [build, collect_cases]
+    if: inputs.test_files == ''
+    uses: ./.github/workflows/_torch-npu-upstream-test-regular.yml
+    with:
+      python_version: ${{ inputs.python_version }}
+      docker_image: ${{ inputs.docker_image }}
+      torch_wheel_artifact: ${{ needs.build.outputs.torch-wheel }}
+      torch_npu_wheel_artifact: ${{ needs.build.outputs.torch-npu-wheel }}
+      pytorch_src_artifact: ${{ needs.build.outputs.pytorch-src }}
+      regular_matrix: ${{ needs.collect_cases.outputs.regular_matrix }}
+      regular_shards: ${{ needs.collect_cases.outputs.regular_shards }}
+
+ # ============================================================================
+ # 5. Run Custom Tests (only when test_files is provided)
+ # ============================================================================
+  test_custom:
+    # Custom mode: runs only when explicit test files were requested.
+    needs: [build]
+    if: inputs.test_files != ''
+    uses: ./.github/workflows/_torch-npu-upstream-test-custom.yml
+    with:
+      python_version: ${{ inputs.python_version }}
+      docker_image: ${{ inputs.docker_image }}
+      torch_wheel_artifact: ${{ needs.build.outputs.torch-wheel }}
+      torch_npu_wheel_artifact: ${{ needs.build.outputs.torch-npu-wheel }}
+      pytorch_src_artifact: ${{ needs.build.outputs.pytorch-src }}
+      test_files: ${{ inputs.test_files }}
+
+ # ============================================================================
+ # 6. Generate Test Report
+ # ============================================================================
+  report:
+    # Waits for every test job but runs whenever the build succeeded, even if
+    # test jobs failed or were skipped. Only one of the two modes (sharded or
+    # custom) runs per invocation, so some of the needed jobs are always
+    # skipped. The matrix values fall back to '[]' when collect_cases produced
+    # no output (presumably because it was skipped - confirm).
+    needs: [build, collect_cases, test_distributed, test_regular, test_custom]
+    if: ${{ always() && needs.build.result == 'success' }}
+    uses: ./.github/workflows/_torch-npu-upstream-report.yml
+    with:
+      python_version: ${{ inputs.python_version }}
+      docker_image: ${{ inputs.docker_image }}
+      pytorch_version: ${{ needs.build.outputs.pytorch-version }}
+      torch_npu_wheel_name: ${{ needs.build.outputs.torch-npu-wheel }}
+      distributed_matrix: ${{ needs.collect_cases.outputs.distributed_matrix || '[]' }}
+      regular_matrix: ${{ needs.collect_cases.outputs.regular_matrix || '[]' }}
\ No newline at end of file
diff --git a/.github/workflows/build-docker-image.yml b/.github/workflows/build-docker-image.yml
new file mode 100644
index 0000000000..7df92cd072
--- /dev/null
+++ b/.github/workflows/build-docker-image.yml
@@ -0,0 +1,169 @@
+name: Build Docker Image
+
+on:
+  push:
+    branches: [dev_master]
+    paths:
+      - '.github/docker/pytorch-npu-builder.Dockerfile'
+      - '.github/workflows/build-docker-image.yml'
+      - '.github/scripts/build_image.sh'
+  schedule:
+    - cron: '0 2 * * 0'  # 02:00 UTC (10:00 Beijing) every Sunday
+  workflow_dispatch:
+    inputs:
+      cann_version:
+        description: 'CANN version (e.g., 9.0, 9.0.0-beta.2, 8.0)'
+        type: string
+        required: true
+        default: '9.0'
+      push_image:
+        description: 'Push image to registry'
+        type: boolean
+        required: true
+        default: true
+      force_build:
+        description: 'Force rebuild even if image exists'
+        type: boolean
+        required: false
+        default: false
+
+env:
+  REGISTRY: quay.io
+  QUAY_ORG: kerer
+  IMAGE_NAME: pytorch
+  CANN_STABLE: '9.0'  # quoted so it stays a string; used when no cann_version input is given
+
+jobs:
+  build:
+    permissions:
+      contents: read  # the workflow token only needs to read the repository
+    runs-on: ubuntu-22.04-arm
+    environment: QUAY_USERNAME
+
+    steps:
+      - uses: actions/checkout@v4
+        name: Checkout repository
+
+      - run: chmod +x .github/scripts/build_image.sh
+        name: Make script executable
+
+      - name: Determine build parameters
+        id: params
+        env:
+          EVENT_NAME: ${{ github.event_name }}
+          INPUT_PUSH_IMAGE: ${{ inputs.push_image }}
+          INPUT_FORCE_BUILD: ${{ inputs.force_build || 'false' }}
+          INPUT_CANN_VERSION: ${{ inputs.cann_version || env.CANN_STABLE }}
+        run: |
+          # Inputs arrive through env vars instead of being templated into the script text.
+          # Push policy: manual runs follow the push_image input; push/schedule runs always push.
+          case "${EVENT_NAME}" in
+            workflow_dispatch) PUSH_IMAGE="${INPUT_PUSH_IMAGE}" ;;
+            push | schedule) PUSH_IMAGE="true" ;;
+            *) PUSH_IMAGE="false" ;;
+          esac
+          FORCE_BUILD="${INPUT_FORCE_BUILD}"
+          CANN_VERSION="${INPUT_CANN_VERSION}"
+
+          {
+            echo "push_image=${PUSH_IMAGE}"
+            echo "force_build=${FORCE_BUILD}"
+            echo "cann_version=${CANN_VERSION}"
+          } >> "$GITHUB_OUTPUT"
+
+          echo "Build parameters:"
+          echo " Event: ${EVENT_NAME}"
+          echo " CANN version: ${CANN_VERSION}"
+          echo " Push image: ${PUSH_IMAGE}"
+          echo " Force build: ${FORCE_BUILD}"
+
+      - name: Setup Docker Buildx
+        uses: docker/setup-buildx-action@v3
+        with:
+          driver-opts: image=moby/buildkit:latest  # NOTE(review): floating tag; pin a version for reproducible builds
+          driver: docker-container
+
+      - name: Login to Quay.io
+        uses: docker/login-action@v3
+        if: steps.params.outputs.push_image == 'true'  # no registry login for build-only runs
+        with:
+          registry: ${{ env.REGISTRY }}
+          username: ${{ secrets.QUAY_USERNAME }}
+          password: ${{ secrets.QUAY_PASSWORD }}
+
+      - name: Build and push image
+        env:
+          QUAY_USERNAME: ${{ secrets.QUAY_USERNAME }}
+          QUAY_PASSWORD: ${{ secrets.QUAY_PASSWORD }}
+          SKIP_DOCKER_LOGIN: 'true'  # already logged in via login-action, so the script's own login is skipped
+        run: |
+          # Optional flags are collected in an array so each one stays a separate word.
+          extra_args=()
+          if [[ "${{ steps.params.outputs.push_image }}" == "true" ]]; then
+            extra_args+=(--push)
+          fi
+          if [[ "${{ steps.params.outputs.force_build }}" == "true" ]]; then
+            extra_args+=(--force)
+          fi
+
+          .github/scripts/build_image.sh \
+            --cann-version "${{ steps.params.outputs.cann_version }}" \
+            --registry "${{ env.REGISTRY }}" \
+            --quay-org "${{ env.QUAY_ORG }}" \
+            --image-name "${{ env.IMAGE_NAME }}" \
+            "${extra_args[@]}" \
+            --verbose
+
+      - name: Summary
+        if: always()
+        env:
+          CANN_INPUT: ${{ steps.params.outputs.cann_version }}
+          IMAGE_REPO: ${{ env.REGISTRY }}/${{ env.QUAY_ORG }}/${{ env.IMAGE_NAME }}
+        run: |
+          # Derive the MAJOR.MINOR CANN series ("9.0" from either "9.0" or "9.0.0-beta.2").
+          # An empty or non-matching version leaves the series empty instead of aborting
+          # the step, so the summary is still written after a failed build.
+          CANN_MAJOR=""
+          if [[ "${CANN_INPUT}" =~ ^[0-9]+\.[0-9]+ ]]; then
+            CANN_MAJOR="${BASH_REMATCH[0]}"
+          fi
+
+          # Everything below is appended to the job summary through one grouped redirect.
+          {
+            echo "## Docker Image Build Summary"
+            echo ""
+            echo "### Build Details"
+            echo ""
+            echo "| Property | Value |"
+            echo "|----------|-------|"
+            echo "| **Trigger** | ${{ github.event_name }} |"
+            echo "| **CANN Version** | ${CANN_INPUT} |"
+            echo "| **Python Versions** | 3.10, 3.11, 3.12, 3.13 (pre-installed) |"
+            echo "| **Registry** | ${IMAGE_REPO} |"
+            echo "| **Push Enabled** | ${{ steps.params.outputs.push_image }} |"
+            echo ""
+            echo "### Image Tags"
+            echo '```'
+            echo "${IMAGE_REPO}:cann${CANN_INPUT}"
+            echo "${IMAGE_REPO}:cann${CANN_MAJOR}"
+            if [[ "${CANN_MAJOR}" == "${{ env.CANN_STABLE }}" ]]; then
+              echo "${IMAGE_REPO}:latest"
+              echo "${IMAGE_REPO}:cann-latest"
+            fi
+            echo '```'
+            echo ""
+            echo "### Python Version Switch"
+            echo '```bash'
+            echo "# Inside container:"
+            echo "source /usr/local/bin/switch_python.sh 3.11"
+            echo "source /usr/local/bin/switch_python.sh 3.12"
+            echo '```'
+            echo ""
+            echo "### Build time: $(date -u +'%Y-%m-%d %H:%M:%S UTC')"
+          } >> "$GITHUB_STEP_SUMMARY"
+
+      - name: Cleanup on failure
+        if: failure()  # emits an error annotation only; nothing is actually cleaned up here
+        run: |
+          printf '%s\n' "::error::Build failed for CANN version ${{ steps.params.outputs.cann_version }}"
+          printf '%s\n' "Check the build logs for details"
\ No newline at end of file
diff --git a/.github/workflows/torch-npu-upstream-test-trigger.yml b/.github/workflows/torch-npu-upstream-test-trigger.yml
new file mode 100644
index 0000000000..203864cb3e
--- /dev/null
+++ b/.github/workflows/torch-npu-upstream-test-trigger.yml
@@ -0,0 +1,70 @@
+name: Torch NPU Upstream Main Test Trigger
+
+on:
+  push:
+    paths:
+      - '.github/workflows/torch-npu-upstream-test-trigger.yml'
+      - '.github/workflows/_torch-npu-upstream*.yml'
+      - '.github/scripts/*.py'
+      - 'test_upstream/**'
+    branches:
+      - main
+      - master
+      - 'release/**'
+  pull_request:
+    paths:
+      - '.github/workflows/torch-npu-upstream-test-trigger.yml'
+      - '.github/workflows/_torch-npu-upstream*.yml'
+      - '.github/scripts/*.py'
+      - 'test_upstream/**'
+  schedule:
+    - cron: '0 22 * * 1'  # Mondays 22:00 UTC, i.e. 06:00 the next day in Beijing
+  workflow_dispatch:
+    inputs:
+      python_version:
+        description: 'Python version (default 3.11)'
+        type: string
+        required: false
+        default: '3.11'
+      pytorch_ref:
+        description: 'PyTorch branch, tag, or commit SHA to build (default fccc94ae83f61fe26559abc999797297196bac29)'
+        type: string
+        required: false
+        default: 'fccc94ae83f61fe26559abc999797297196bac29'
+      torch_npu_ref:
+        description: 'torch_npu branch, tag, or commit SHA to build (default master)'
+        type: string
+        required: false
+        default: 'master'
+      docker_image:
+        description: 'Docker image to use for all jobs (default quay.io/kerer/pytorch:manylinux-cann9.0.0-beta.2-20260428)'
+        type: string
+        required: false
+        default: 'quay.io/kerer/pytorch:manylinux-cann9.0.0-beta.2-20260428'
+      distributed_shards:
+        description: 'Number of shards for distributed tests (default 3)'
+        type: string
+        required: false
+        default: '3'
+      regular_shards:
+        description: 'Number of shards for regular tests (default 5)'
+        type: string
+        required: false
+        default: '5'
+      test_files:
+        description: 'Test files to run directly (comma-separated, e.g., "test_meta.py,test_nn.py"). Skip shard assignment if set.'
+        type: string
+        required: false
+        default: ''
+
+jobs:
+  trigger_test:
+    uses: ./.github/workflows/_torch-npu-upstream-test.yml
+    with:  # each "|| <literal>" supplies the value when no workflow_dispatch input is present
+      python_version: ${{ github.event.inputs.python_version || '3.11' }}
+      docker_image: ${{ github.event.inputs.docker_image || 'quay.io/kerer/pytorch:manylinux-cann9.0.0-beta.2-20260428' }}
+      pytorch_ref: ${{ github.event.inputs.pytorch_ref || 'fccc94ae83f61fe26559abc999797297196bac29' }}
+      torch_npu_ref: ${{ github.event.inputs.torch_npu_ref || 'master' }}
+      distributed_shards: ${{ github.event.inputs.distributed_shards || '3' }}
+      regular_shards: ${{ github.event.inputs.regular_shards || '5' }}
+      test_files: ${{ github.event.inputs.test_files || '' }}
\ No newline at end of file
diff --git a/test_upstream/case_paths_ci.yml b/test_upstream/case_paths_ci.yml
new file mode 100644
index 0000000000..f7e50e31d6
--- /dev/null
+++ b/test_upstream/case_paths_ci.yml
@@ -0,0 +1,161 @@
+# Test file blacklist configuration for NPU CI
+#
+# This file defines test files/directories to exclude from collection.
+# Rules support:
+# - Exact path match: "distributed/test_c10d_nccl"
+# - Directory prefix match: "dynamo/cpython/3_13" (matches all files in that directory)
+# - Glob patterns: "distributed/*nccl*"
+#
+# Rule paths are relative to the test directory: "distributed/..." and
+# "test/distributed/..." are equivalent, because a missing "test/" prefix is added automatically.
+
+blacklist:
+ # ==============================================================================
+ # Python version specific tests (Python 3.13 syntax, incompatible with 3.11)
+ # ==============================================================================
+ - dynamo/cpython/3_13
+
+ # ==============================================================================
+ # Platform-specific tests (CUDA/XPU/MPS - not supported on NPU)
+ # ==============================================================================
+ # CUDA specific tests
+ - test_cuda_multigpu
+ - test_cuda_nvml_based_avail
+ - test_cuda_primary_ctx
+ - test_cuda_sanitizer
+ - test_cuda_trace
+ - test_jiterator
+ - test_varlen_attention
+
+ # CUDA inductor tests
+ - inductor/test_cuda_repro
+ - inductor/test_cudagraph_trees
+ - inductor/test_cudagraph_trees_expandable_segments
+ - inductor/test_gpu_select_algorithm
+ - inductor/test_triton_cpu_backend
+ - inductor/test_triton_heuristics
+ - inductor/test_pallas
+ - inductor/test_layout_optim
+ - inductor/test_op_dtype_prop
+ - inductor/test_autoheuristic
+ - inductor/test_b2b_gemm
+ - inductor/test_best_config
+ - inductor/test_coordinate_descent_tuner
+
+ # JIT CUDA tests
+ - jit/test_cuda
+
+ # Flash Attention tests (CUDA only)
+ - nn/attention/test_fa3
+ - nn/attention/test_fa4
+
+ # XPU (Intel) tests
+ - test_xpu
+ - test_xpu_expandable_segments
+ - xpu/test_conv
+ - xpu/test_fusion
+ - xpu/test_gemm
+
+ # MPS (Apple Metal) tests
+ - test_mps
+
+ # ==============================================================================
+ # ONNX tests with heavy/optional dependencies
+ # ==============================================================================
+ # HuggingFace transformers tests
+ # Issue: Test requires transformers package (HuggingFace) with no graceful import handling
+ # Problem: Test file directly imports `import transformers` without try-import or skipif
+ # - transformers package: ~10 MB + dependencies (tokenizers, safetensors, huggingface-hub)
+ # - Test downloads model: hf-internal-testing/tiny-random-gptj (small, but needs network)
+ # - PyTorch CI excludes all ONNX tests by default (only runs with --onnx flag)
+ # Impact: ModuleNotFoundError: No module named 'transformers'
+ # Note: PyTorch run_test.py: options.exclude.extend(onnx_tests) in default behavior
+ - onnx/exporter/test_hf_models_e2e
+
+ # ==============================================================================
+ # Tests requiring specific compiled libraries not available in NPU build
+ # ==============================================================================
+ - custom_operator/test_custom_ops # requires libcustom_ops.so
+
+ # ==============================================================================
+ # Custom device extension tests (require separate compilation)
+ # ==============================================================================
+ # torch_openreg tests
+ # Issue: torch_openreg is a separate extension that must be compiled AFTER PyTorch is installed
+ # Problem: PyTorch run_test.py calls install_cpp_extensions() before running these tests
+ # - Requires CMake build linking against installed PyTorch
+ # - setup.py shows: "-DPYTORCH_INSTALL_DIR=" + get_pytorch_dir()
+ # - By default, PyTorch CI EXCLUDES test_openreg (only runs with --openreg flag)
+ # Impact: ModuleNotFoundError: No module named 'torch_openreg'
+ # See: run_test.py - options.exclude.append("test_openreg") in default behavior
+ - cpp_extensions/open_registration_extension/torch_openreg/tests/test_memory
+
+ # ==============================================================================
+ # Architecture-specific tests (x86_64 only, no ARM64/aarch64 support)
+ # ==============================================================================
+ # TorchRec tests
+ # Issue: torchrec depends on fbgemm-gpu, which only has x86_64 wheels on PyPI
+ # Problem: fbgemm-gpu provides pre-built wheels for x86_64 only:
+ # - fbgemm_gpu-1.6.0-cp311-cp311-manylinux_2_28_x86_64.whl
+ # - No aarch64/ARM64 wheels available
+ # Impact: Cannot install fbgemm-gpu on ARM64 NPU runner (linux-aarch64-a3-16)
+ # Test uses graceful handling (NoTest fallback), but we blacklist for clarity
+ - dynamo/test_torchrec
+
+ # ==============================================================================
+ # PyTorch upstream design issues (missing __init__.py or internal API changes)
+ # ==============================================================================
+ # Quantization experimental tests
+ # Issue: torch/ao/quantization/experimental/ directory exists but has NO __init__.py
+ # Python cannot recognize it as a package, causing ModuleNotFoundError on import
+ # See: https://github.com/pytorch/pytorch/tree/main/torch/ao/quantization/experimental
+ # Files exist: adaround_optimization.py, linear.py, observer.py, etc.
+ # Missing: __init__.py (required for package import)
+ # Impact: 6 test files fail with "No module named 'torch.ao.quantization.experimental'"
+ - quantization/core/experimental/test_adaround_eager
+ - quantization/core/experimental/test_fake_quantize
+ - quantization/core/experimental/test_linear
+ - quantization/core/experimental/test_nonuniform_observer
+ - quantization/core/experimental/test_quantized_tensor
+ - quantization/core/experimental/test_quantizer
+
+ # numpy internal private module dependency
+ # Issue: Test imports private function from numpy's internal module structure
+ # Code: `from numpy.linalg.linalg import _multi_dot_matrix_chain_order`
+ # Problem: numpy 2.0+ restructured internal modules, `numpy.linalg.linalg` no longer exists
+ # - Private APIs (_ prefix) have no stability guarantee, can be moved/removed anytime
+ # - numpy 2.0 moved _multi_dot_matrix_chain_order to numpy.linalg._linalg or removed it
+ # - PyTorch test code uses private API instead of stable public API
+ # Impact: ModuleNotFoundError: No module named 'numpy.linalg.linalg'
+ - torch_np/numpy_tests/linalg/test_linalg
+
+ # PyTorch test directory structure issue (test/ has no __init__.py)
+ # Issue: Test imports using package path `from test.jit.fixtures_srcs.generate_models import ...`
+ # Problem: PyTorch test/ directory is NOT a Python package (missing __init__.py files)
+ # - test/__init__.py does not exist
+ # - test/jit/__init__.py does not exist
+ # - Without __init__.py, Python cannot recognize `test.jit` as a valid package path
+ # - PyTorch CI works via special PYTHONPATH setup that our collection script doesn't replicate
+ # Impact: ModuleNotFoundError: No module named 'test.jit'
+ - jit/fixtures_srcs/test_upgrader_models_generation
+
+ # ==============================================================================
+ # Tests with torch_npu compatibility issues
+ # ==============================================================================
+  # RPC tests fail because of the 'object' type annotation in torch_npu's
+  # serialization.py (TorchScript-incompatible); they are already skipped via
+  # BACKEND=hccl. Uncomment the entries below to blacklist them explicitly:
+ # - distributed/rpc/test_tensorpipe_agent
+ # - distributed/rpc/test_faulty_agent
+ # - distributed/rpc/cuda/test_tensorpipe_agent
+
+ # ==============================================================================
+ # Tests with environment variable requirements (handled by BACKEND=hccl)
+ # ==============================================================================
+ # These tests require BACKEND=gloo/nccl environment variable
+ # Setting BACKEND=hccl causes them to define no test classes
+ # - distributed/algorithms/quantization/test_quantization
+ # - distributed/test_distributed_spawn
+
+# Whitelist is empty - collect all tests except blacklist
+whitelist: []
\ No newline at end of file