Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/xsf/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
#include <cuda/std/cstddef>
#include <cuda/std/cstdint>
#include <cuda/std/limits>
#include <cuda/std/numeric>
#include <cuda/std/tuple>
#include <cuda/std/type_traits>
#include <cuda/std/utility>
Expand Down Expand Up @@ -158,6 +159,9 @@ XSF_HOST_DEVICE constexpr T clamp(T &v, T &lo, T &hi) {
// Alias template: numeric_limits<T> names cuda::std::numeric_limits<T>.
template <typename T>
using numeric_limits = cuda::std::numeric_limits<T>;

using cuda::std::gcd;
using cuda::std::lcm;

Comment on lines +162 to +164

@steppi steppi Apr 9, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't worry about the CUDA side. I still have to figure out cupy/cupy#9839 before we can even try these in CuPy. I'm pretty sure that just adding `using cuda::std::gcd` won't work in all cases, and that we actually need wrappers for these stdlib functions, like the other ones in this file. I recall suggesting `using` declarations like this when Irwin first set this up, but there was a reason he did things the way he did.

// Must use thrust for complex types in order to support CuPy.
// Alias template: complex<T> names thrust::complex<T>.
template <typename T>
using complex = thrust::complex<T>;
Expand Down Expand Up @@ -251,6 +255,7 @@ using cuda::std::uint64_t;
#include <iterator>
#include <limits>
#include <math.h>
#include <numeric>
#include <tuple>
#include <type_traits>
#include <utility>
Expand Down
147 changes: 147 additions & 0 deletions include/xsf/stats.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "xsf/bessel.h"
#include "xsf/binom.h"
#include "xsf/cephes/bdtr.h"
#include "xsf/cephes/chdtr.h"
#include "xsf/cephes/const.h"
Expand Down Expand Up @@ -486,6 +487,152 @@ inline double pdtrc(double k, double m) { return cephes::pdtrc(k, m); }

// Forwards both arguments unchanged to cephes::pdtri and returns its result.
inline double pdtri(int k, double y) {
    return cephes::pdtri(k, y);
}

template <typename InputMat>
XSF_HOST_DEVICE inline typename InputMat::value_type take_from_discrete_sf(InputMat pmf, long long int k) {
    // Inclusive survival mass of a discrete distribution: sum_{i >= k} pmf(i).
    //
    // `pmf` is any 1D view exposing `value_type`, `extent(0)` and `operator()(i)`.
    // A negative `k` sums the whole table; a `k` at or past the end yields zero.
    using T = typename InputMat::value_type;

    auto n_bins = pmf.extent(0);
    using index_type = decltype(n_bins);

    T tail = T(0.0);
    if (k < static_cast<long long int>(n_bins)) {
        // Clamp the starting index at the first bin.
        index_type idx = (k < 0) ? index_type(0) : static_cast<index_type>(k);
        while (idx < n_bins) {
            tail += pmf(idx);
            ++idx;
        }
    }
    return tail;
}

namespace detail {

// Adapter that presents one fixed row of a 2D frequency table through the
// 1D interface (`value_type`, `extent`, `operator()`) used by
// take_from_discrete_sf.
template <typename FreqTable2D>
struct cvm_freq_table_row {
    using value_type = typename FreqTable2D::value_type;

    FreqTable2D freq_table; // underlying 2D table view
    int64_t row;            // index of the row being exposed

    // Length of the exposed row, i.e. the table's second extent.
    // The axis argument is ignored.
    XSF_HOST_DEVICE auto extent(int /*axis*/) const { return freq_table.extent(1); }

    // Value stored at column `k` of the selected row.
    XSF_HOST_DEVICE value_type operator()(int64_t k) const {
        return freq_table(row, k);
    }
};

template <typename FreqTable2D>
XSF_HOST_DEVICE inline void
cvm_freq_table_all(int64_t m, int64_t n, int64_t a, int64_t b, FreqTable2D gs, FreqTable2D next_gs) {
    // Fill the exact Cramér-von Mises two-sample frequency table.
    //
    // `gs` and `next_gs` are two equally shaped 2D views used as ping-pong
    // buffers: rows 0..m are used, and the column count is taken from `gs`.
    // On return the final table is readable through the caller's `gs` view.
    using T = typename FreqTable2D::value_type;
    const int64_t n_cols = static_cast<int64_t>(gs.extent(1));

    // Start from an all-zero table with the single base entry gs(0, 0) = 1.
    for (int64_t row = 0; row <= m; ++row) {
        for (int64_t col = 0; col < n_cols; ++col) {
            gs(row, col) = T(0);
        }
    }
    gs(0, 0) = T(1);

    for (int64_t u = 0; u <= n; ++u) {
        // Row 0 has no predecessor row: it is gs(0, .) shifted right by (b*u)^2.
        const int64_t offset0 = -b * u;
        const int64_t shift0 = offset0 * offset0;
        for (int64_t col = 0; col < n_cols; ++col) {
            if (col < shift0) {
                next_gs(0, col) = T(0);
            } else {
                next_gs(0, col) = gs(0, col - shift0);
            }
        }

        // Rows v > 0 combine the freshly written row v-1 with the previous
        // step's row v, both shifted right by (a*v - b*u)^2.
        for (int64_t v = 1; v <= m; ++v) {
            const int64_t offset = a * v - b * u;
            const int64_t shift = offset * offset;
            for (int64_t col = 0; col < n_cols; ++col) {
                if (col < shift) {
                    next_gs(v, col) = T(0);
                } else {
                    next_gs(v, col) = next_gs(v - 1, col - shift) + gs(v, col - shift);
                }
            }
        }

        // Exchange the roles of the two buffers for the next u-step.
        FreqTable2D previous = gs;
        gs = next_gs;
        next_gs = previous;
    }

    // The views are swapped once per u-step (n + 1 times in total). For odd n
    // the local `gs` again refers to the caller's `gs` buffer, so nothing is
    // left to do.
    if (n % 2 != 0) {
        return;
    }
    // For even n the result sits in the caller's `next_gs` buffer (now the
    // local `gs`); copy it into the caller's `gs` buffer (now the local `next_gs`).
    for (int64_t row = 0; row <= m; ++row) {
        for (int64_t col = 0; col < n_cols; ++col) {
            next_gs(row, col) = gs(row, col);
        }
    }
}

} // namespace detail

template <typename FreqTable2D>
XSF_HOST_DEVICE inline void cvm_2samp_freq_table(int64_t m, int64_t n, FreqTable2D freq_table, FreqTable2D workspace) {
    /*
     * Generate the exact Cramér-von Mises two-sample frequency table for
     * sample sizes m and n. The table does not depend on the value of the
     * test statistic.
     *
     * `freq_table` receives the result; `workspace` is a second view of the
     * same type used as scratch space by detail::cvm_freq_table_all.
     * If m or n is not positive, a domain error is reported and neither
     * view is touched.
     */
    const bool sizes_are_positive = (m > 0) && (n > 0);
    if (!sizes_are_positive) {
        set_error("cvm_2samp_freq_table", SF_ERROR_DOMAIN, "m and n must be positive");
        return;
    }

    // Least common multiple of the sample sizes, [1, p. 3].
    const int64_t common_multiple = std::lcm(m, n);
    // Integer scale factors a = lcm / m and b = lcm / n, [1, p. 4], below eq. 3.
    const int64_t scale_m = common_multiple / m;
    const int64_t scale_n = common_multiple / n;

    detail::cvm_freq_table_all(m, n, scale_m, scale_n, freq_table, workspace);
}

template <typename FreqTable2D>
XSF_HOST_DEVICE inline double pval_cvm_2samp_exact(double s, int64_t m, int64_t n, FreqTable2D freq_table) {
    /*
     * Compute the exact p-value of the Cramér-von Mises two-sample test
     * for a given value s of the test statistic, where m and n are the sizes
     * of the samples, from a precomputed frequency table.
     *
     * Parameters
     *   s          : value of the test statistic ($U$ in [2]).
     *   m, n       : sample sizes; both must be positive.
     *   freq_table : 2D frequency table; row m is read, and its second
     *                extent gives the number of tabulated values.
     *
     * Returns
     *   (sum of row-m frequencies at indices >= zeta) / binom(m + n, m).
     *   NaN is returned when s is NaN, and when m or n is not positive
     *   (a domain error is also reported in the latter case).
     *
     * [1] Y. Xiao, A. Gordon, and A. Yakovlev, "A C++ Program for
     *     the Cramér-Von Mises Two-Sample Test", J. Stat. Soft.,
     *     vol. 17, no. 8, pp. 1-15, Dec. 2006.
     * [2] T. W. Anderson "On the Distribution of the Two-Sample Cramér-von Mises
     *     Criterion," The Annals of Mathematical Statistics, Ann. Math. Statist.
     *     33(3), 1148-1159, (September, 1962)
     */
    if (m <= 0 || n <= 0) {
        set_error("pval_cvm_2samp_exact", SF_ERROR_DOMAIN, "m and n must be positive");
        return std::numeric_limits<double>::quiet_NaN();
    }
    // A NaN statistic has no meaningful threshold; propagate it instead of
    // converting NaN to an integer below (which would be undefined behaviour).
    if (std::isnan(s)) {
        return std::numeric_limits<double>::quiet_NaN();
    }
    // [1, p. 3]
    int64_t lcm = std::lcm(m, n);
    // Combine Eq. 9 in [2] with Eq. 2 in [1] and solve for $\zeta$
    // Hint: `s` is $U$ in [2], and $T_2$ in [1] is $T$ in [2]
    int64_t mn = m * n;

    // Uses double floor division since s is double
    const double zeta_real =
        std::floor((lcm * lcm * (m + n) * (6.0 * s - mn * (4.0 * mn - 1))) / (6.0 * mn * mn));

    // Clamp to [0, number of tabulated values] before converting to an integer.
    // Converting an out-of-range double (huge or infinite s) to int64_t is
    // undefined behaviour. The clamp does not change the result: every
    // threshold <= 0 sums the whole row, and every threshold at or past the
    // end of the row sums nothing.
    const auto num_values = freq_table.extent(1);
    int64_t zeta;
    if (zeta_real <= 0.0) {
        zeta = 0;
    } else if (zeta_real >= static_cast<double>(num_values)) {
        zeta = static_cast<int64_t>(num_values);
    } else {
        zeta = static_cast<int64_t>(zeta_real);
    }

    // Row m of the table, viewed as a 1D distribution.
    detail::cvm_freq_table_row<FreqTable2D> freq_table_row{freq_table, m};
    auto sum_freq = take_from_discrete_sf(freq_table_row, zeta);

    // Normalise by binom(m + n, m).
    double combinations = xsf::binom(static_cast<double>(m + n), static_cast<double>(m));
    return sum_freq / combinations;
}

// Forwards both arguments unchanged to cephes::smirnov and returns its result.
inline double smirnov(int n, double x) {
    return cephes::smirnov(n, x);
}

// Forwards both arguments unchanged to cephes::smirnovc and returns its result.
inline double smirnovc(int n, double x) {
    return cephes::smirnovc(n, x);
}
Expand Down
83 changes: 83 additions & 0 deletions tests/xsf_tests/test_pval_cvm_2samp_exact.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#include "../testing_utils.h"
#define MDSPAN_USE_PAREN_OPERATOR 1
#include <xsf/stats.h>
#include <xsf/third_party/kokkos/mdspan.hpp>

/*
// Reference values computed with scipy.stats._hypotests._pval_cvm_2samp_exact

import numpy as np
from scipy import stats

rng = np.random.default_rng(seed=42)

list_m = rng.integers(3, 30, size=5)
list_n = rng.integers(3, 30, size=5)
rtol = 1e-10

for m, n in zip(list_m, list_n):
x = rng.standard_normal(m)
y = rng.standard_normal(n)
res = stats.cramervonmises_2samp(x, y, method="exact")
T = res.statistic
# Convert normalized statistic T to the unnormalized U
U = m * n * (m + n) * T + m * n * (4 * m * n - 1) / 6
p_value = stats._hypotests._pval_cvm_2samp_exact(U, m, n)
assert np.isclose(res.pvalue, p_value, rtol=rtol), "The p-values do not match!"
print(f"U={U}, m={m}, n={n}, p-value={p_value}")
*/
TEST_CASE("take_from_discrete_sf test", "[take_from_discrete_sf][xsf_tests]") {
    // Four-point probability mass function whose entries sum to exactly 1.
    std::vector<double> masses = {0.125, 0.375, 0.375, 0.125};
    auto masses_view = std::mdspan(masses.data(), masses.size());

    // Thresholds at or below the first index return the full mass.
    REQUIRE(xsf::take_from_discrete_sf(masses_view, -1) == 1.0);
    REQUIRE(xsf::take_from_discrete_sf(masses_view, 0) == 1.0);
    // Interior thresholds return the inclusive tail sum.
    REQUIRE(xsf::take_from_discrete_sf(masses_view, 1) == 0.875);
    REQUIRE(xsf::take_from_discrete_sf(masses_view, 3) == 0.125);
    // A threshold past the last index returns zero.
    REQUIRE(xsf::take_from_discrete_sf(masses_view, 4) == 0.0);
}

TEST_CASE("pval_cvm_2samp_exact test", "[pval_cvm_2samp_exact][xsf_tests]") {
    // Each case: statistic s, sample sizes m and n, reference p-value, relative tolerance.
    using test_case = std::tuple<double, int, int, double, double>;
    auto [s, m, n, pval_expected, rtol] = GENERATE(
        test_case{12559.0, 5, 26, 0.11812654860485784, 1e-10},
        test_case{8901.0, 23, 5, 0.9907610907610908, 1e-10},
        test_case{119376.0, 20, 21, 0.5716351061359124, 1e-10},
        test_case{8862.0, 14, 8, 0.2679738562091503, 1e-10},
        test_case{3491.0000000000005, 14, 5, 0.34657722738218094, 1e-10}
    );

    // Table shape: m + 1 rows and K = (m + n) * lcm^2 + 1 columns.
    const int64_t sample_lcm = std::lcm(m, n);
    const int64_t K = (m + n) * sample_lcm * sample_lcm + 1;
    const size_t n_rows = static_cast<size_t>(m + 1);
    const size_t n_cols = static_cast<size_t>(K);

    // Backing storage for the result table and the scratch table.
    std::vector<int64_t> table_storage((m + 1) * K, 0);
    std::vector<int64_t> scratch_storage((m + 1) * K, 0);

    using table_view = std::mdspan<int64_t, std::dextents<size_t, 2>>;
    table_view table(table_storage.data(), n_rows, n_cols);
    table_view scratch(scratch_storage.data(), n_rows, n_cols);

    xsf::cvm_2samp_freq_table(m, n, table, scratch);
    auto pval = xsf::pval_cvm_2samp_exact(s, m, n, table);

    const double rel_error = xsf::extended_relative_error(pval, pval_expected);
    CAPTURE(s, m, n, K, pval, pval_expected, rel_error);
    REQUIRE(rel_error <= rtol);
}

TEST_CASE("pval_cvm_2samp_exact edge cases", "[pval_cvm_2samp_exact][xsf_tests]") {
    // Each case: statistic s, sample sizes m and n, and the exact expected p-value.
    // s = 0 is expected to give p = 1, and a very large s is expected to give p = 0.
    using test_case = std::tuple<double, int, int, double>;
    auto [s, m, n, pval_expected] = GENERATE(
        test_case{0.0, 3, 3, 1.0},
        test_case{1e6, 3, 3, 0.0}
    );

    // Table shape: m + 1 rows and K = (m + n) * lcm^2 + 1 columns.
    const int64_t sample_lcm = std::lcm(m, n);
    const int64_t K = (m + n) * sample_lcm * sample_lcm + 1;
    const size_t n_rows = static_cast<size_t>(m + 1);
    const size_t n_cols = static_cast<size_t>(K);

    // Backing storage for the result table and the scratch table.
    std::vector<int64_t> table_storage((m + 1) * K, 0);
    std::vector<int64_t> scratch_storage((m + 1) * K, 0);

    using table_view = std::mdspan<int64_t, std::dextents<size_t, 2>>;
    table_view table(table_storage.data(), n_rows, n_cols);
    table_view scratch(scratch_storage.data(), n_rows, n_cols);

    xsf::cvm_2samp_freq_table(m, n, table, scratch);
    auto pval = xsf::pval_cvm_2samp_exact(s, m, n, table);

    CAPTURE(s, m, n, K, pval, pval_expected);
    REQUIRE(pval == pval_expected);
}
Loading