diff --git a/include/xsf/bessel.h b/include/xsf/bessel.h
index 566664cca6..5bc9954913 100644
--- a/include/xsf/bessel.h
+++ b/include/xsf/bessel.h
@@ -1164,6 +1164,37 @@ inline std::complex cyl_hankel_1(float v, std::complex z) {
     return static_cast>(cyl_hankel_1(static_cast(v), static_cast>(z)));
 }
 
+template // NOTE(review): template parameter list is missing in this copy (angle-bracket text looks stripped); presumably <typename OutputVec> - confirm against the upstream patch
+inline void cyl_hankel_1_all(double v, std::complex z, OutputVec cy) { // fills all cy.extent(0) elements of cy from a single amos::besh call
+    int kode = 1; // forwarded unchanged to amos::besh (presumably the unscaled variant - confirm against AMOS docs)
+    int m = 1; // forwarded unchanged to amos::besh (presumably selects the first-kind Hankel function - confirm)
+    int nz, ierr;
+    int sign = 1; // flipped to -1 below when the caller passes a negative order
+
+    int n = cy.extent(0); // number of output elements requested
+
+    if (std::isnan(v) || std::isnan(z.real()) || std::isnan(z.imag())) { // any NaN input: NaN-fill the whole output and skip AMOS entirely
+        for (int i = 0; i < n; ++i) {
+            cy(i).real(NAN);
+            cy(i).imag(NAN);
+        }
+        return;
+    }
+
+    if (v < 0) { // evaluate with |v|; the negative-order case is fixed up by the rotation loop below
+        v = -v;
+        sign = -1;
+    }
+
+    nz = amos::besh(z, v, kode, m, n, cy.data_handle(), &ierr); // AMOS is handed cy's underlying buffer directly
+    set_error_and_nan("hankel1_all:", ierr_to_sferr(nz, ierr), cy); // reports the mapped error; for some codes every element of cy is set to NaN
+    if (sign == -1) {
+        for (int i = 0; i < n; ++i) {
+            cy(i) = detail::rotate(cy(i), v + i); // element i is rotated by v + i, where v already holds |v|
+        }
+    }
+}
+
 inline std::complex cyl_hankel_2(double v, std::complex z) {
     int n = 1;
     int kode = 1;
diff --git a/include/xsf/error.h b/include/xsf/error.h
index 7221b5e6c4..d2763c5000 100644
--- a/include/xsf/error.h
+++ b/include/xsf/error.h
@@ -52,6 +52,20 @@ XSF_HOST_DEVICE void set_error_and_nan(const char *name, sf_error_t code, std::c
     }
 }
 
+template // NOTE(review): template parameter list is missing in this copy; confirm it is constrained enough not to collide with other set_error_and_nan overloads
+XSF_HOST_DEVICE void set_error_and_nan(const char *name, sf_error_t code, OutputVec &cy) { // container variant: reports `code` and may NaN-fill every element of cy
+    if (code != SF_ERROR_OK) {
+        set_error(name, code, nullptr);
+
+        if (code == SF_ERROR_DOMAIN || code == SF_ERROR_OVERFLOW || code == SF_ERROR_NO_RESULT) { // only these codes invalidate the output; other non-OK codes leave cy untouched
+            for (int i = 0; i < cy.extent(0); ++i) {
+                cy(i).real(std::numeric_limits::quiet_NaN()); // NOTE(review): numeric_limits type argument is missing in this copy - restore from upstream
+                cy(i).imag(std::numeric_limits::quiet_NaN());
+            }
+        }
+    }
+}
+
 } // namespace xsf
 
 #endif
diff --git a/include/xsf/numpy.h b/include/xsf/numpy.h
index 465d4feee3..d38ee8fdf3 100644
--- a/include/xsf/numpy.h
+++ b/include/xsf/numpy.h
@@ -204,6 +204,8 @@ namespace numpy {
     using lD_D = cdouble (*)(long int, cdouble);
     using Dd_D = cdouble (*)(cdouble, double);
     using Ff_F = cfloat (*)(cfloat, float);
+    using dD_D1 = void (*)(double, cdouble, cdouble_1d); // double and cdouble inputs, cdouble_1d output argument, no return value
+    using fF_F1 = void (*)(float, cfloat, cfloat_1d); // single-precision counterpart of dD_D1
 
     // autodiff, 2 inputs, 1 output
     using autodiff0_if_f = autodiff0_float (*)(int, autodiff0_float);
diff --git a/tests/testing_utils.h b/tests/testing_utils.h
index 7d62f37e29..a121a9dded 100644
--- a/tests/testing_utils.h
+++ b/tests/testing_utils.h
@@ -16,6 +16,7 @@
 #include
 #include
 #include
+#include // NOTE(review): header name is missing in this copy - restore from the upstream patch
 
 #include
 
diff --git a/tests/xsf_tests/test_cyl_bessel_all.cpp b/tests/xsf_tests/test_cyl_bessel_all.cpp
new file mode 100644
index 0000000000..8f818d5e53
--- /dev/null
+++ b/tests/xsf_tests/test_cyl_bessel_all.cpp
@@ -0,0 +1,110 @@
+#include "../testing_utils.h"
+#include // NOTE(review): this and the remaining include targets are missing in this copy - restore from the upstream patch
+#include
+
+#include
+#include
+#include
+
+// parameter lists swept by the test case at the bottom of this file
+namespace bessel_test_params {
+
+// real and imaginary parts for z
+const std::vector Z_PARTS = {-100.0, -10.0, -1.0, -0.1, -1e-6, 0.0, 1e-6, 0.1, 1.0, 10.0, 100.0};
+
+const std::vector NU_VALUES = {-0.5, -0.25, 0.0, 0.25, 0.5}; // starting orders nu
+
+const std::vector N_VALUES = {10, 100}; // output lengths n
+
+} // namespace bessel_test_params
+
+static bool is_nan(std::complex const &z) { return std::isnan(z.real()) || std::isnan(z.imag()); } // true when either component is NaN
+
+// match exact (z, nu, n, i) combinations
+using skip_entry_t = std::tuple, double, int, int>;
+
+static bool should_skip(std::complex z, double nu, int n, int i, std::vector const &skip_list) { // true when (z, nu, n, i) appears in skip_list
+    for (auto const &[skip_z, skip_nu, skip_n, skip_i] : skip_list) {
+        if (z == skip_z && nu == skip_nu && n == skip_n && i == skip_i) { // exact equality on all four fields, no tolerance
+            return true;
+        }
+    }
+    return false;
+}
+
+// helper: compare any vectorized Bessel "_all" function against scalar calls
+template // NOTE(review): template parameter list is missing in this copy; presumably declares VecFunc and ScalarFunc - confirm
+static void compare_vectorized_with_scalar(
+    std::complex z, double nu, int n, double rtol, VecFunc &&vec_func, ScalarFunc &&scalar_func,
+    std::vector const &skip_list = {}
+) {
+    // compute all scalar references
+    std::vector> refs(n);
+    bool any_nan = false;
+    for (int i = 0; i < n; ++i) {
+        refs[i] = scalar_func(z, nu + std::copysign(i, nu)); // reference i uses order nu shifted by i in the direction of nu's sign
+        if (is_nan(refs[i])) {
+            any_nan = true;
+        }
+    }
+
+    // call the vectorized routine
+    std::vector> cy_vec(n);
+    std::mdspan cy_span(cy_vec.data(), cy_vec.size()); // view over cy_vec's storage, passed to vec_func as the output argument
+    vec_func(z, nu, cy_span);
+
+    CAPTURE(z, nu, n, rtol);
+
+    // If any scalar reference is NaN, the "_all" routine is expected to have NaN'd the whole array.
+    //
+    // The underlying AMOS routines might set ierr = 1, 2, 4, 5.
+    // For scalar wrappers (e.g. cyl_hankel_1) only that single element is NaN'd.
+    // For "_all" wrappers the entire output array is NaN'd when ierr = 1, 2, 4, or 5.
+    // Therefore, if any of the scalar reference values is NaN, expect the whole
+    // vectorized result to be NaN.
+    if (any_nan) {
+        for (int i = 0; i < n; ++i) {
+            CAPTURE(i, cy_vec[i]);
+            REQUIRE(is_nan(cy_vec[i]));
+        }
+        return;
+    }
+
+    // compare element-wise
+    for (int i = 0; i < n; ++i) {
+        if (should_skip(z, nu, n, i, skip_list)) { // listed mismatches are excluded from the comparison
+            continue;
+        }
+        const auto rel_error = xsf::extended_relative_error(cy_vec[i], refs[i]);
+        CAPTURE(i, cy_vec[i], refs[i], rel_error);
+        REQUIRE(rel_error <= rtol);
+    }
+}
+
+namespace {
+// known mismatches between scalar and vectorized cyl_hankel_1
+// {z, nu, n, i} where i is the element index to skip
+const std::vector HANKEL1_SKIP_LIST = {
+    {std::complex{-1e-6, -1.0}, 0.5, 10, 1},  {std::complex{-1e-6, -1.0}, 0.5, 100, 1},
+    {std::complex{-1e-6, -1.0}, -0.5, 10, 1}, {std::complex{-1e-6, -1.0}, -0.5, 100, 1},
+    {std::complex{1e-6, -1.0}, 0.5, 10, 1},   {std::complex{1e-6, -1.0}, 0.5, 100, 1},
+    {std::complex{1e-6, -1.0}, -0.5, 10, 1},  {std::complex{1e-6, -1.0}, -0.5, 100, 1},
+    {std::complex{0.0, -1.0}, 0.5, 10, 1},    {std::complex{0.0, -1.0}, 0.5, 100, 1},
+    {std::complex{0.0, -1.0}, -0.5, 10, 1},   {std::complex{0.0, -1.0}, -0.5, 100, 1},
+};
+} // namespace
+
+TEST_CASE("cyl_hankel_1_all vectorized vs scalar", "[cyl_hankel_1_all][xsf_tests]") {
+    const double zr = GENERATE(from_range(bessel_test_params::Z_PARTS)); // the four GENERATE calls sweep every (zr, zi, nu, n) combination
+    const double zi = GENERATE(from_range(bessel_test_params::Z_PARTS));
+    const double nu = GENERATE(from_range(bessel_test_params::NU_VALUES));
+    const int n = GENERATE(from_range(bessel_test_params::N_VALUES));
+
+    std::complex z(zr, zi);
+    double rtol = 1e-12; // bound applied to xsf::extended_relative_error for every element
+
+    compare_vectorized_with_scalar(
+        z, nu, n, rtol, [](std::complex z, double nu, auto cy) { xsf::cyl_hankel_1_all(nu, z, cy); }, // adapters swap argument order to (nu, z)
+        [](std::complex z, double nu) { return xsf::cyl_hankel_1(nu, z); }, HANKEL1_SKIP_LIST
+    );
+}