diff --git a/cub/cub/detail/segmented_params.cuh b/cub/cub/detail/segmented_params.cuh index fe5cc5c9162..543f55b5036 100644 --- a/cub/cub/detail/segmented_params.cuh +++ b/cub/cub/detail/segmented_params.cuh @@ -99,48 +99,23 @@ struct supported_options static constexpr ::cuda::std::size_t count = sizeof...(Options); }; -//! @brief Uniform discrete parameter — a single runtime value with a known set of supported options. -template -struct uniform_discrete_param -{ - using value_type = T; - using supported_options_t = supported_options; - - T value; - - _CCCL_HOST_DEVICE constexpr uniform_discrete_param(T v) - : value(v) - {} - - uniform_discrete_param() = default; - - template - _CCCL_HOST_DEVICE constexpr auto get_param([[maybe_unused]] SegmentIndexT segment_id) const - { - return value; - } -}; - -//! @brief Per-segment discrete parameter — per-segment values with a known set of supported options. -template -struct per_segment_discrete_param +//! @brief Static discrete parameter — a single compile-time value that is also its only supported option. +//! +//! Holds no runtime value, so it cannot be put into a state that disagrees with its supported option, and +//! @c dispatch_impl therefore always matches it. This is the safe representation for a compile-time-fixed discrete +//! parameter (e.g. a statically known top-k selection direction): modeling such a parameter with a runtime value +//! instead would risk that value silently disagreeing with the supported option (a no-op dispatch unless +//! @c CCCL_ENABLE_ASSERTIONS is set). +template +struct static_discrete_param { - using iterator_type = IteratorT; using value_type = T; - using supported_options_t = supported_options; - - IteratorT iterator; - - _CCCL_HOST_DEVICE constexpr per_segment_discrete_param(IteratorT iter) - : iterator(iter) - {} - - per_segment_discrete_param() = default; + using supported_options_t = supported_options; template - _CCCL_HOST_DEVICE constexpr auto get_param(SegmentIndexT segment_id) const + [[nodiscard]] _CCCL_HOST_DEVICE constexpr T get_param(SegmentIndexT) const noexcept { - return iterator[segment_id]; + return Value; } }; @@ -164,7 +139,7 @@ dispatch_impl(T val, [[maybe_unused]] supported_options __supported_ return match_found; } -//! @brief Dispatcher that resolves a per-segment discrete parameter to a compile-time constant +//! @brief Dispatcher that resolves a discrete parameter to a compile-time constant //! and invokes a functor with the matched option. //! //! @param[in] param Discrete parameter to resolve. diff --git a/cub/cub/device/dispatch/dispatch_batched_topk.cuh b/cub/cub/device/dispatch/dispatch_batched_topk.cuh index d0f2d4eed0a..56d12268dc9 100644 --- a/cub/cub/device/dispatch/dispatch_batched_topk.cuh +++ b/cub/cub/device/dispatch/dispatch_batched_topk.cuh @@ -36,6 +36,7 @@ #include #include #include +#include #include #include #include @@ -46,31 +47,35 @@ CUB_NAMESPACE_BEGIN namespace detail::batched_topk { // ----------------------------------------------------------------------------- -// Internal: wrap user-facing select direction into discrete param for dispatch +// Internal: wrap the compile-time select direction into a discrete param for dispatch // ----------------------------------------------------------------------------- -// Uniform (compile-time): __constant -> single-option uniform_discrete_param. +// The selection direction is compile-time only: callers pass `::cuda::__argument::__constant`, which maps to a +// value-less static_discrete_param. Because the direction is fixed at compile time and carries no runtime value, it +// can never disagree with its only supported option, so dispatch can never silently degrade to a no-op. template [[nodiscard]] _CCCL_HOST_DEVICE auto wrap_select_direction(::cuda::__argument::__constant) { - return params::uniform_discrete_param{Dir}; + return params::static_discrete_param{}; } -// Uniform: single enum value → uniform_discrete_param -[[nodiscard]] _CCCL_HOST_DEVICE inline auto wrap_select_direction(detail::topk::select dir) +// The selection direction is intentionally a compile-time constant: only `::cuda::__argument::__constant` is +// accepted (the overload above maps it to a value-less static_discrete_param). This catch-all documents that +// deliberate limitation and rejects anything else (e.g. a runtime `detail::topk::select` or a per-segment iterator of +// directions) with a clear diagnostic. It is an intent/documentation guard rather than a user-facing one: callers +// reach the algorithm through the min/max device entry points (DeviceBatchedTopK::{Max,Min}{Keys,Pairs}), which +// construct the matching `__constant` internally, so `dispatch` is only ever invoked with a direction we create. +template +[[nodiscard]] _CCCL_HOST_DEVICE auto wrap_select_direction(SelectDirectionT) { - return params::uniform_discrete_param{ - dir}; -} - -// Per-segment: iterator of enums → per_segment_discrete_param -_CCCL_TEMPLATE(typename IteratorT) -_CCCL_REQUIRES((!::cuda::std::is_same_v<::cuda::std::remove_cv_t, detail::topk::select>) ) -[[nodiscard]] _CCCL_HOST_DEVICE auto wrap_select_direction(IteratorT iter) -{ - return params:: - per_segment_discrete_param{ - iter}; + static_assert(::cuda::std::__always_false_v, + "DeviceBatchedTopK currently supports only compile-time selection directions: the min/max entry " + "points (DeviceBatchedTopK::{Max,Min}{Keys,Pairs}) dispatch with a " + "::cuda::__argument::__constant; runtime or per-segment directions are " + "intentionally not supported"); + // Unreachable (the static_assert above always fires); keeps the return type well-formed so the only diagnostic is + // the message above. + return params::static_discrete_param{}; } // ----------------------------------------------------------------------------- diff --git a/cub/test/catch2_test_device_segmented_topk_keys.cu b/cub/test/catch2_test_device_segmented_topk_keys.cu index 3d00c1119cc..3ef76bc2743 100644 --- a/cub/test/catch2_test_device_segmented_topk_keys.cu +++ b/cub/test/catch2_test_device_segmented_topk_keys.cu @@ -85,11 +85,16 @@ using key_types = >; // clang-format on +// Selection direction is a compile-time option; cover both as a static test axis. +using select_direction_list = + c2h::enum_type_list; + C2H_TEST("DeviceBatchedTopK::{Min,Max}Keys work with small fixed-size segments", "[keys][segmented][topk][device]", key_types, max_segment_size_list, - max_num_k_list) + max_num_k_list, + select_direction_list) { using segment_size_t = cuda::std::int64_t; using segment_index_t = cuda::std::int64_t; @@ -100,8 +105,8 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Keys work with small fixed-size segments", constexpr segment_size_t static_max_segment_size = c2h::get<1, TestType>::value; constexpr segment_size_t static_max_k = c2h::get<2, TestType>::value; - // Test both directions (as runtime value) - const auto direction = GENERATE_COPY(cub::detail::topk::select::min, cub::detail::topk::select::max); + // Selection direction comes from the compile-time test axis. + constexpr auto direction = c2h::get<3, TestType>::value; // Generate segment size constexpr segment_size_t min_segment_size = 1; @@ -153,7 +158,7 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Keys work with small fixed-size segments", d_keys_out, ::cuda::__argument::__immediate{segment_size, ::cuda::__argument::__bounds()}, ::cuda::__argument::__immediate{k, ::cuda::__argument::__bounds()}, - direction, + ::cuda::__argument::__constant{}, ::cuda::__argument::__immediate{num_segments}, ::cuda::__argument::__immediate{num_segments * segment_size}); // Prepare expected results @@ -170,7 +175,8 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Keys work with small variable-size segment "[keys][segmented][topk][device]", key_types, max_segment_size_list, - max_num_k_list) + max_num_k_list, + select_direction_list) { using segment_size_t = cuda::std::int64_t; using segment_index_t = cuda::std::int64_t; @@ -181,8 +187,8 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Keys work with small variable-size segment constexpr segment_size_t static_max_segment_size = c2h::get<1, TestType>::value; constexpr segment_size_t static_max_k = c2h::get<2, TestType>::value; - // Test both directions (as runtime value) - const auto direction = GENERATE_COPY(cub::detail::topk::select::min, cub::detail::topk::select::max); + // Selection direction comes from the compile-time test axis. + constexpr auto direction = c2h::get<3, TestType>::value; constexpr segment_size_t min_items = 1; constexpr segment_size_t max_items = 1'000'000; @@ -251,7 +257,7 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Keys work with small variable-size segment ::cuda::__argument::__immediate_sequence{ segment_size_it, ::cuda::__argument::__bounds()}, ::cuda::__argument::__immediate{k, ::cuda::__argument::__bounds()}, - direction, + ::cuda::__argument::__constant{}, ::cuda::__argument::__immediate{num_segments}, ::cuda::__argument::__immediate{num_items}); @@ -289,7 +295,7 @@ C2H_TEST("DeviceBatchedTopK::MinKeys preserves -0.0f in output", "[keys][segment ::cuda::__argument::__immediate{ segment_size, ::cuda::__argument::__bounds()}, ::cuda::__argument::__immediate{k, ::cuda::__argument::__bounds()}, - cub::detail::topk::select::min, + ::cuda::__argument::__constant{}, ::cuda::__argument::__immediate{num_segments}, ::cuda::__argument::__immediate{num_segments * segment_size}); diff --git a/cub/test/catch2_test_device_segmented_topk_pairs.cu b/cub/test/catch2_test_device_segmented_topk_pairs.cu index cc34ceba3c6..b16a97b2472 100644 --- a/cub/test/catch2_test_device_segmented_topk_pairs.cu +++ b/cub/test/catch2_test_device_segmented_topk_pairs.cu @@ -79,6 +79,10 @@ using key_types = c2h::type_list; // Unsigned integer types used for the radix-pass boundary distribution test using uint_key_types = c2h::type_list; +// Selection direction is a compile-time option; cover both as a static test axis. +using select_direction_list = + c2h::enum_type_list; + // Consistency check: ensures values remain associated with their corresponding keys template bool verify_pairs_consistency(const c2h::device_vector& keys_in, @@ -148,7 +152,8 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Pairs work with small fixed-size segments" "[pairs][segmented][topk][device]", key_types, max_segment_size_list, - max_num_k_list) + max_num_k_list, + select_direction_list) { using segment_size_t = cuda::std::int64_t; using segment_index_t = cuda::std::int64_t; @@ -160,8 +165,8 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Pairs work with small fixed-size segments" constexpr segment_size_t static_max_segment_size = c2h::get<1, TestType>::value; constexpr segment_size_t static_max_k = c2h::get<2, TestType>::value; - // Test both directions (as runtime value) - const auto direction = GENERATE_COPY(cub::detail::topk::select::min, cub::detail::topk::select::max); + // Selection direction comes from the compile-time test axis. + constexpr auto direction = c2h::get<3, TestType>::value; // Generate segment size constexpr segment_size_t min_segment_size = 1; @@ -222,7 +227,7 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Pairs work with small fixed-size segments" d_values_out, ::cuda::__argument::__immediate{segment_size, ::cuda::__argument::__bounds()}, ::cuda::__argument::__immediate{k, ::cuda::__argument::__bounds()}, - direction, + ::cuda::__argument::__constant{}, ::cuda::__argument::__immediate{num_segments}, ::cuda::__argument::__immediate{num_segments * segment_size}); @@ -250,7 +255,8 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Pairs work with small variable-size segmen "[pairs][segmented][topk][device]", key_types, max_segment_size_list, - max_num_k_list) + max_num_k_list, + select_direction_list) { using segment_size_t = cuda::std::int64_t; using segment_index_t = cuda::std::int64_t; @@ -262,8 +268,8 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Pairs work with small variable-size segmen constexpr segment_size_t static_max_segment_size = c2h::get<1, TestType>::value; constexpr segment_size_t static_max_k = c2h::get<2, TestType>::value; - // Test both directions (as runtime value) - const auto direction = GENERATE_COPY(cub::detail::topk::select::min, cub::detail::topk::select::max); + // Selection direction comes from the compile-time test axis. + constexpr auto direction = c2h::get<3, TestType>::value; constexpr segment_size_t min_items = 1; constexpr segment_size_t max_items = 1'000'000; @@ -343,7 +349,7 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Pairs work with small variable-size segmen ::cuda::__argument::__immediate_sequence{ segment_size_it, ::cuda::__argument::__bounds()}, ::cuda::__argument::__immediate{k, ::cuda::__argument::__bounds()}, - direction, + ::cuda::__argument::__constant{}, ::cuda::__argument::__immediate{num_segments}, ::cuda::__argument::__immediate{num_items});