Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 13 additions & 38 deletions cub/cub/detail/segmented_params.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -99,48 +99,23 @@ struct supported_options
static constexpr ::cuda::std::size_t count = sizeof...(Options);
};

//! @brief Uniform discrete parameter: one runtime value, shared by every segment, together with the compile-time
//! list of options that value may take.
//!
//! @tparam T       Type of the stored value and of each supported option.
//! @tparam Options Values collected into @c supported_options_t.
//!
//! The constructor stores the value as given; it does not check that the value is one of @c Options.
template <typename T, T... Options>
struct uniform_discrete_param
{
  using supported_options_t = supported_options<T, Options...>;
  using value_type          = T;

  //! The value returned by @c get_param for any segment index.
  T value;

  uniform_discrete_param() = default;

  //! Stores @p initial_value as the value reported for all segments.
  _CCCL_HOST_DEVICE constexpr uniform_discrete_param(T initial_value)
      : value(initial_value)
  {}

  //! @brief Returns the stored value. The segment index is accepted for interface uniformity but is not read.
  template <typename SegmentIndexT>
  _CCCL_HOST_DEVICE constexpr auto get_param(SegmentIndexT /*segment_id*/) const
  {
    return value;
  }
};

//! @brief Per-segment discrete parameter — per-segment values with a known set of supported options.
template <typename IteratorT, typename T, T... Options>
struct per_segment_discrete_param
// NOTE(review): no body follows the struct header above; the text runs straight into static_discrete_param.
// This looks like two revisions of the same region interleaved — confirm against the actual file.
//! @brief Static discrete parameter — a single compile-time value that is also its only supported option.
//!
//! Holds no runtime value, so it cannot be put into a state that disagrees with its supported option, and
//! @c dispatch_impl therefore always matches it. This is the safe representation for a compile-time-fixed discrete
//! parameter (e.g. a statically known top-k selection direction): modeling such a parameter with a runtime value
//! instead would risk that value silently disagreeing with the supported option (a no-op dispatch unless
//! @c CCCL_ENABLE_ASSERTIONS is set).
template <typename T, T Value>
struct static_discrete_param
{
// NOTE(review): IteratorT and Options are template parameters of per_segment_discrete_param, not of
// static_discrete_param<T, Value> — presumably the iterator members below belong to the per-segment struct; confirm.
using iterator_type = IteratorT;
using value_type = T;
using supported_options_t = supported_options<T, Options...>;

IteratorT iterator;

_CCCL_HOST_DEVICE constexpr per_segment_discrete_param(IteratorT iter)
: iterator(iter)
{}

per_segment_discrete_param() = default;
using supported_options_t = supported_options<T, Value>;

// NOTE(review): two get_param signatures and two return statements follow. The first pair reads as the per-segment
// accessor (returns iterator[segment_id]), the second as the static accessor (returns Value, ignoring segment_id) —
// confirm which lines belong to which struct.
template <typename SegmentIndexT>
_CCCL_HOST_DEVICE constexpr auto get_param(SegmentIndexT segment_id) const
[[nodiscard]] _CCCL_HOST_DEVICE constexpr T get_param([[maybe_unused]] SegmentIndexT segment_id) const noexcept
{
return iterator[segment_id];
return Value;
}
};

Expand All @@ -164,7 +139,7 @@ dispatch_impl(T val, [[maybe_unused]] supported_options<T, Opts...> __supported_
return match_found;
}

//! @brief Dispatcher that resolves a per-segment discrete parameter to a compile-time constant
//! @brief Dispatcher that resolves a discrete parameter to a compile-time constant
//! and invokes a functor with the matched option.
//!
//! @param[in] param Discrete parameter to resolve.
Expand Down
26 changes: 6 additions & 20 deletions cub/cub/device/dispatch/dispatch_batched_topk.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,31 +46,17 @@ CUB_NAMESPACE_BEGIN
namespace detail::batched_topk
{
// -----------------------------------------------------------------------------
// Internal: wrap user-facing select direction into discrete param for dispatch
// Internal: wrap the compile-time select direction into a discrete param for dispatch
// -----------------------------------------------------------------------------

// Uniform (compile-time): __constant<Dir> -> single-option uniform_discrete_param.
// The selection direction is compile-time only: callers pass `::cuda::__argument::__constant<Dir>`, which maps to a
// value-less static_discrete_param. Because the direction is fixed at compile time and carries no runtime value, it
// can never disagree with its only supported option, so dispatch can never silently degrade to a no-op. Anything other
// than a `__constant<Dir>` is rejected at compile time (no matching overload).
template <detail::topk::select Dir>
[[nodiscard]] _CCCL_HOST_DEVICE auto wrap_select_direction(::cuda::__argument::__constant<Dir>)
{
  // Single-option discrete param: Dir is both the only supported option and the value stored in the result.
  using direction_param_t = params::uniform_discrete_param<detail::topk::select, Dir>;
  return direction_param_t{Dir};
}

// Uniform runtime direction: wraps one enum value in a uniform_discrete_param that lists both `max` and `min`
// as its supported options.
[[nodiscard]] _CCCL_HOST_DEVICE inline auto wrap_select_direction(detail::topk::select dir)
{
  using select_t          = detail::topk::select;
  using direction_param_t = params::uniform_discrete_param<select_t, select_t::max, select_t::min>;
  return direction_param_t{dir};
}

// Per-segment: iterator of enums → per_segment_discrete_param
// The constraint excludes IteratorT == detail::topk::select (after removing cv-qualifiers), so a plain enum argument
// does not select this overload.
_CCCL_TEMPLATE(typename IteratorT)
_CCCL_REQUIRES((!::cuda::std::is_same_v<::cuda::std::remove_cv_t<IteratorT>, detail::topk::select>) )
[[nodiscard]] _CCCL_HOST_DEVICE auto wrap_select_direction(IteratorT iter)
{
// Both `max` and `min` are listed as supported options for the values read through the iterator.
return params::
per_segment_discrete_param<IteratorT, detail::topk::select, detail::topk::select::max, detail::topk::select::min>{
iter};
// NOTE(review): this second return follows an unconditional return, and `Dir` is not declared by this template —
// presumably a line from another revision interleaved here; confirm against the actual file.
return params::static_discrete_param<detail::topk::select, Dir>{};
}

// -----------------------------------------------------------------------------
Expand Down
24 changes: 15 additions & 9 deletions cub/test/catch2_test_device_segmented_topk_keys.cu
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,16 @@ using key_types =
>;
// clang-format on

// Selection direction is a compile-time option; cover both as a static test axis.
// Tests that list this axis read the direction back via c2h::get<3, TestType>::value and forward it to the
// algorithm as ::cuda::__argument::__constant<direction>{}.
using select_direction_list =
c2h::enum_type_list<cub::detail::topk::select, cub::detail::topk::select::min, cub::detail::topk::select::max>;

C2H_TEST("DeviceBatchedTopK::{Min,Max}Keys work with small fixed-size segments",
"[keys][segmented][topk][device]",
key_types,
max_segment_size_list,
max_num_k_list)
max_num_k_list,
select_direction_list)
{
using segment_size_t = cuda::std::int64_t;
using segment_index_t = cuda::std::int64_t;
Expand All @@ -100,8 +105,8 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Keys work with small fixed-size segments",
constexpr segment_size_t static_max_segment_size = c2h::get<1, TestType>::value;
constexpr segment_size_t static_max_k = c2h::get<2, TestType>::value;

// Test both directions (as runtime value)
const auto direction = GENERATE_COPY(cub::detail::topk::select::min, cub::detail::topk::select::max);
// Selection direction comes from the compile-time test axis.
constexpr auto direction = c2h::get<3, TestType>::value;

// Generate segment size
constexpr segment_size_t min_segment_size = 1;
Expand Down Expand Up @@ -153,7 +158,7 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Keys work with small fixed-size segments",
d_keys_out,
::cuda::__argument::__immediate{segment_size, ::cuda::__argument::__bounds<segment_size_t{1}, max_segment_size>()},
::cuda::__argument::__immediate{k, ::cuda::__argument::__bounds<segment_size_t{1}, static_max_k>()},
direction,
::cuda::__argument::__constant<direction>{},
::cuda::__argument::__immediate{num_segments},
::cuda::__argument::__immediate{num_segments * segment_size});
// Prepare expected results
Expand All @@ -170,7 +175,8 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Keys work with small variable-size segment
"[keys][segmented][topk][device]",
key_types,
max_segment_size_list,
max_num_k_list)
max_num_k_list,
select_direction_list)
{
using segment_size_t = cuda::std::int64_t;
using segment_index_t = cuda::std::int64_t;
Expand All @@ -181,8 +187,8 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Keys work with small variable-size segment
constexpr segment_size_t static_max_segment_size = c2h::get<1, TestType>::value;
constexpr segment_size_t static_max_k = c2h::get<2, TestType>::value;

// Test both directions (as runtime value)
const auto direction = GENERATE_COPY(cub::detail::topk::select::min, cub::detail::topk::select::max);
// Selection direction comes from the compile-time test axis.
constexpr auto direction = c2h::get<3, TestType>::value;

constexpr segment_size_t min_items = 1;
constexpr segment_size_t max_items = 1'000'000;
Expand Down Expand Up @@ -251,7 +257,7 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Keys work with small variable-size segment
::cuda::__argument::__immediate_sequence{
segment_size_it, ::cuda::__argument::__bounds<segment_size_t{1}, static_max_segment_size>()},
::cuda::__argument::__immediate{k, ::cuda::__argument::__bounds<segment_size_t{1}, static_max_k>()},
direction,
::cuda::__argument::__constant<direction>{},
::cuda::__argument::__immediate{num_segments},
::cuda::__argument::__immediate{num_items});

Expand Down Expand Up @@ -289,7 +295,7 @@ C2H_TEST("DeviceBatchedTopK::MinKeys preserves -0.0f in output", "[keys][segment
::cuda::__argument::__immediate{
segment_size, ::cuda::__argument::__bounds<cuda::std::int64_t{1}, max_segment_size>()},
::cuda::__argument::__immediate{k, ::cuda::__argument::__bounds<cuda::std::int64_t{1}, k>()},
cub::detail::topk::select::min,
::cuda::__argument::__constant<cub::detail::topk::select::min>{},
::cuda::__argument::__immediate{num_segments},
::cuda::__argument::__immediate{num_segments * segment_size});

Expand Down
22 changes: 14 additions & 8 deletions cub/test/catch2_test_device_segmented_topk_pairs.cu
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ using key_types = c2h::type_list<cuda::std::uint64_t>;
// Unsigned integer types used for the radix-pass boundary distribution test
using uint_key_types = c2h::type_list<cuda::std::uint8_t, cuda::std::uint16_t, cuda::std::uint64_t>;

// Selection direction is a compile-time option; cover both as a static test axis.
// Tests that list this axis read the direction back via c2h::get<3, TestType>::value and forward it to the
// algorithm as ::cuda::__argument::__constant<direction>{}.
using select_direction_list =
c2h::enum_type_list<cub::detail::topk::select, cub::detail::topk::select::min, cub::detail::topk::select::max>;

// Consistency check: ensures values remain associated with their corresponding keys
template <typename KeyT, typename ValueT>
bool verify_pairs_consistency(const c2h::device_vector<KeyT>& keys_in,
Expand Down Expand Up @@ -148,7 +152,8 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Pairs work with small fixed-size segments"
"[pairs][segmented][topk][device]",
key_types,
max_segment_size_list,
max_num_k_list)
max_num_k_list,
select_direction_list)
{
using segment_size_t = cuda::std::int64_t;
using segment_index_t = cuda::std::int64_t;
Expand All @@ -160,8 +165,8 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Pairs work with small fixed-size segments"
constexpr segment_size_t static_max_segment_size = c2h::get<1, TestType>::value;
constexpr segment_size_t static_max_k = c2h::get<2, TestType>::value;

// Test both directions (as runtime value)
const auto direction = GENERATE_COPY(cub::detail::topk::select::min, cub::detail::topk::select::max);
// Selection direction comes from the compile-time test axis.
constexpr auto direction = c2h::get<3, TestType>::value;

// Generate segment size
constexpr segment_size_t min_segment_size = 1;
Expand Down Expand Up @@ -222,7 +227,7 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Pairs work with small fixed-size segments"
d_values_out,
::cuda::__argument::__immediate{segment_size, ::cuda::__argument::__bounds<segment_size_t{1}, max_segment_size>()},
::cuda::__argument::__immediate{k, ::cuda::__argument::__bounds<segment_size_t{1}, static_max_k>()},
direction,
::cuda::__argument::__constant<direction>{},
::cuda::__argument::__immediate{num_segments},
::cuda::__argument::__immediate{num_segments * segment_size});

Expand Down Expand Up @@ -250,7 +255,8 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Pairs work with small variable-size segmen
"[pairs][segmented][topk][device]",
key_types,
max_segment_size_list,
max_num_k_list)
max_num_k_list,
select_direction_list)
{
using segment_size_t = cuda::std::int64_t;
using segment_index_t = cuda::std::int64_t;
Expand All @@ -262,8 +268,8 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Pairs work with small variable-size segmen
constexpr segment_size_t static_max_segment_size = c2h::get<1, TestType>::value;
constexpr segment_size_t static_max_k = c2h::get<2, TestType>::value;

// Test both directions (as runtime value)
const auto direction = GENERATE_COPY(cub::detail::topk::select::min, cub::detail::topk::select::max);
// Selection direction comes from the compile-time test axis.
constexpr auto direction = c2h::get<3, TestType>::value;

constexpr segment_size_t min_items = 1;
constexpr segment_size_t max_items = 1'000'000;
Expand Down Expand Up @@ -343,7 +349,7 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Pairs work with small variable-size segmen
::cuda::__argument::__immediate_sequence{
segment_size_it, ::cuda::__argument::__bounds<segment_size_t{1}, static_max_segment_size>()},
::cuda::__argument::__immediate{k, ::cuda::__argument::__bounds<segment_size_t{1}, static_max_k>()},
direction,
::cuda::__argument::__constant<direction>{},
::cuda::__argument::__immediate{num_segments},
::cuda::__argument::__immediate{num_items});

Expand Down
Loading