Skip to content

[cub] Replace cub parameter framework with cuda::argument#9074

Merged
pciolkosz merged 4 commits into
NVIDIA:mainfrom
pciolkosz:replace_cub_parameter_framework
Jun 4, 2026
Merged

[cub] Replace cub parameter framework with cuda::argument#9074
pciolkosz merged 4 commits into
NVIDIA:mainfrom
pciolkosz:replace_cub_parameter_framework

Conversation

@pciolkosz
Copy link
Copy Markdown
Contributor

This PR replaces most of the functionality in segmented_params.cuh with cuda::argument wrappers from #8875. This PR contains the other one, since it's not merged yet.

There are two things that were left from the original implementation, the static dispatch over bounded set of values and get_param that either gets item from a sequence at a given index or returns a uniform value depending on the argument. Both of those things were more fitting for a cub-specific functionality, but its not set in stone

@pciolkosz pciolkosz requested review from a team as code owners May 20, 2026 04:42
@pciolkosz pciolkosz requested a review from wmaxey May 20, 2026 04:42
@github-project-automation github-project-automation Bot moved this to Todo in CCCL May 20, 2026
@pciolkosz pciolkosz requested a review from pauleonix May 20, 2026 04:42
@cccl-authenticator-app cccl-authenticator-app Bot moved this from Todo to In Review in CCCL May 20, 2026
@pciolkosz pciolkosz force-pushed the replace_cub_parameter_framework branch from 845daaf to 5dd3c87 Compare May 20, 2026 06:00
@pciolkosz pciolkosz requested a review from a team as a code owner May 20, 2026 06:00
@pciolkosz pciolkosz requested a review from shwina May 20, 2026 06:00
@pciolkosz pciolkosz force-pushed the replace_cub_parameter_framework branch from 5dd3c87 to 8a3b299 Compare May 20, 2026 06:01
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 20, 2026

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: c9bc04b4-902e-4fdc-bcb7-daa089aaf7b4

📥 Commits

Reviewing files that changed from the base of the PR and between e72e879 and 75f4500.

📒 Files selected for processing (1)
  • cub/cub/detail/segmented_params.cuh
🚧 Files skipped from review as they are similar to previous changes (1)
  • cub/cub/detail/segmented_params.cuh

Note: CodeRabbit is enabled on this repository as a convenience for maintainers
and contributors. Use your best judgment when considering its review comments and
suggestions — a suggested change may be inadequate, unnecessary, or safe to ignore.
Contributors are not expected to address every comment. Human reviews are what
ultimately matter for merging.

Overview

This PR replaces the CUB parameter framework in cub/cub/detail/segmented_params.cuh with CUDA ::cuda::__argument wrappers (from PR #8875, included here). The refactor keeps two CUB-specific features for now: (1) static dispatch over a bounded set of values, and (2) params::get_param which returns either a per-segment indexed value or a uniform value depending on the argument.

Key Changes

Core Framework Refactoring

cub/cub/detail/segmented_params.cuh

  • Removed parameter-type mixins and helpers: static_bounds_mixin, static_constant_param, uniform_param, per_segment_param and associated type-trait helpers (is_static_param_v, is_uniform_param_v, is_per_segment_param_v, static_max_value_v, static_min_value_v, has_single_static_value_v).
  • Added unified detail::params::get_param API with overloads for ::cuda::__argument wrapper types (constant, constant_sequence, immediate, immediate_sequence, deferred, deferred_sequence) to support both single-value and indexed-per-segment reads.
  • Added discrete-parameter support via detail::params::supported_options<T, Options...> and compile-time dispatch that matches runtime discrete values against allowed options and invokes a functor with the matched option as an integral_constant.
  • Doxygen brief for per-segment discrete parameter added.

Dispatch and Kernel Updates

cub/cub/device/dispatch/dispatch_batched_topk.cuh

  • Removed many segmented-topk parameter type aliases (select_direction_, segment_size_, k_, num_segments_) and total_num_items_guarantee.
  • Added wrap_select_direction helpers to convert raw detail::topk::select enum or iterators into discrete internal dispatch parameters.
  • dispatch(...) signature changed: now accepts a raw SelectDirectionT select_direction (wrapped internally where needed) rather than a SelectDirectionParameterT wrapper.
  • Uses ::cuda::__argument::__traits for bounds/element-type queries and params::get_param for runtime reads.
  • Tightened constraint: num_segments must be a single-value parameter (no per-segment segments).
  • Default env-dispatch K bound now uses ::cuda::__argument::__traits< KParameterT >::max.

cub/cub/device/dispatch/kernels/kernel_batched_topk.cuh

  • Replaced segmented_params-based bounds/value queries with ::cuda::__argument::__traits usage (e.g., max).
  • Kernel template pointer element types (d_counters, d_large_segments_ids) now derive element_type via ::cuda::__argument::__traits::element_type.
  • Static_asserts and policy selection use __traits where static bounds were previously queried.

Agent and Worker Updates

cub/cub/agent/agent_batched_topk.cuh

  • Uses ::cuda::__argument::__traits to derive element_type and compile-time properties (max, lowest, is_constant).
  • Refactored compile-time checks (only_small_segments, is_full_tile) to use traits-based methods.
  • Process() uses params::get_param for runtime accesses (num_segments, segment_sizes, k) and computes per-segment k via cuda::std::min(params::get_param(k_param, segment_id), segment_size).

Benchmarks and Tests

Benchmarks (cub/benchmarks/bench/segmented_topk/fixed/keys.cu and variable/keys.cu)

  • Switched from CUB wrapper helpers to CUDA ::cuda::__argument wrapper types (::__immediate, ::__immediate_sequence, ::__bounds, ::__constant).
  • Removed construction/usage of total_num_items_guarantee; total counts are represented via immediate arguments.

Tests (cub/test/device_segmented_topk.cu)

  • Updated test dispatches to pass select direction directly and to use ::cuda::__argument forms for segment sizes and k (::__immediate / ::__bounds / ::__immediate_sequence).
  • Updated dispatch_batched_topk_keys signature/template to accept SelectDirectionT/select_direction instead of a SelectDirectionParamT wrapper.
  • Adjusted a constexpr type in one test (max_segment_size -> cuda::std::int64_t) to match traits/argument types.

Public API / Notable Signature Changes

  • Added: detail::params::get_param overloads for ::cuda::__argument wrapper types and detail::params::supported_options.
  • Added: cub::detail::batched_topk::wrap_select_direction overloads (for constant, enum, and iterator forms).
  • Removed: parameter-type aliases and total_num_items_guarantee used by the old segmented_params framework.
  • Changed: cub::detail::batched_topk::dispatch parameter from SelectDirectionParameterT select_directions → SelectDirectionT select_direction (internal wrapping added).
  • Changed: device_segmented_topk_kernel template parameter types for d_counters and d_large_segments_ids now use ::cuda::__argument::__traits<...>::element_type.

Files Touched (high-level)

  • Major: cub/cub/detail/segmented_params.cuh (large refactor)
  • Dispatch/kernels/agent: cub/cub/device/dispatch/dispatch_batched_topk.cuh, cub/cub/device/dispatch/kernels/kernel_batched_topk.cuh, cub/cub/agent/agent_batched_topk.cuh
  • Benchmarks/tests: cub/benchmarks/bench/segmented_topk/{fixed,variable}/keys.cu, cub/test/catch2_test_device_segmented_topk_{keys,pairs}.cu
  • Minor per-file changes across batched_topk-related call sites to adopt ::cuda::__argument types.

Metrics

  • Lines changed (sum of file-level diffs in this PR): +196 / -347
  • Estimated code review effort: High — core framework changes span multiple tightly-coupled dispatch/kernel/agent files and tests.

Notes for reviewers

  • The PR intentionally leaves two CUB-specific behaviors: static dispatch over bounded option sets and the get_param indexed/uniform retrieval API. Both are localized in segmented_params.cuh and may be candidates for future harmonization with ::cuda::__argument features.
  • Pay attention to places where the new ::cuda::__argument::__traits-derived element_type / max / is_constant replace previous type-trait helpers — correctness depends on trait values matching prior semantics.
  • Discrete-value runtime-to-compile-time dispatch: ensure supported_options and the matching/folding logic correctly assert and dispatch for all expected dynamic values in existing call sites.

important:

Walkthrough

Refactors batched top-K parameter handling to use cuda::__argument wrapper types and ::cuda::__argument::__traits with a unified detail::params::get_param API; dispatch now accepts unwrapped select_direction and num_segments, kernels and agent code read parameters via get_param, and tests/benchmarks updated accordingly.

Changes

CUB Batched Top-K Integration

Layer / File(s) Summary
Unified parameter access API
cub/cub/detail/segmented_params.cuh
Adds detail::params::get_param overloads for cuda::__argument wrapper types and supported_options/discrete-dispatch utilities; removes older static/uniform/per-segment param mixins and type-trait helpers.
Dispatch layer refactoring
cub/cub/device/dispatch/dispatch_batched_topk.cuh
Changes dispatch signature to accept raw select_direction and num_segments, adds wrap_select_direction overloads, switches bounds/element-type queries to ::cuda::__argument::__traits, and routes runtime reads through params::get_param; tightens num_segments compile-time constraint.
Agent worker parameter extraction
cub/cub/agent/agent_batched_topk.cuh
Derives element types and compile-time properties via ::cuda::__argument::__traits, computes is_full_tile from traits, and resolves num_segments, segment_sizes, and per-segment k with params::get_param in Process().
Kernel integration
cub/cub/device/dispatch/kernels/kernel_batched_topk.cuh
Updates includes to use cuda/_argument, derives max segment size and pointer element types from ::cuda::__argument::__traits, and updates static_asserts to use trait-derived bounds.
Test and benchmark updates
cub/benchmarks/bench/segmented_topk/{fixed,variable}/keys.cu, cub/test/catch2_test_device_segmented_topk_{keys,pairs}.cu
Converts dispatch argument construction to ::cuda::__argument wrappers (__immediate, __immediate_sequence, __bounds, __constant), removes total_num_items_guarantee usage, and passes select_direction enums directly.

Possibly related PRs

  • NVIDIA/cccl#8875: Introduces the cuda::__argument wrappers, traits, and bounds framework used by these changes.

Suggested reviewers

  • shwina
  • pauleonix
  • ericniebler
  • miscco

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (3)
cub/cub/detail/segmented_params.cuh (1)

31-43: 💤 Low value

suggestion: Missing [[nodiscard]] on get_param overloads. Per coding guidelines, most functions with non-void return should have this attribute.

 _CCCL_TEMPLATE(class _Tp)
 _CCCL_REQUIRES((!::cuda::argument::__is_wrapper_v<::cuda::std::remove_cv_t<::cuda::std::remove_reference_t<_Tp>>>) )
-_CCCL_HOST_DEVICE constexpr auto get_param(_Tp&& __arg, [[maybe_unused]] size_t __index) noexcept
+[[nodiscard]] _CCCL_HOST_DEVICE constexpr auto get_param(_Tp&& __arg, [[maybe_unused]] size_t __index) noexcept

Same applies to the other get_param overloads on lines 46-47, 53-54, 67-68, 74-75. As per coding guidelines, most functions with a non-void return type should use [[nodiscard]].

cub/cub/device/dispatch/dispatch_batched_topk.cuh (1)

51-66: 💤 Low value

suggestion: Both wrap_select_direction overloads return non-void and should have [[nodiscard]].

-_CCCL_HOST_DEVICE inline auto wrap_select_direction(detail::topk::select dir)
+[[nodiscard]] _CCCL_HOST_DEVICE inline auto wrap_select_direction(detail::topk::select dir)
-_CCCL_HOST_DEVICE auto wrap_select_direction(IteratorT iter)
+[[nodiscard]] _CCCL_HOST_DEVICE auto wrap_select_direction(IteratorT iter)
libcudacxx/include/cuda/__argument/argument_bounds.h (1)

103-113: ⚡ Quick win

suggestion: Complete Doxygen tags for the documented __bounds overloads. The documented non-void factory functions currently only provide //! @brief; add `//! `@param for each parameter and //! @return`` for both overloads to satisfy header documentation requirements.

As per coding guidelines: "When a function is documented with Doxygen, it must include: //! @brief, `//! `@param`[in/out/in,out]` for every parameter, and `//! `@return for non-void functions."


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: a53ca9f6-f66d-4f20-a942-2e8bd23c2c84

📥 Commits

Reviewing files that changed from the base of the PR and between 459e81a and 8a3b299.

📒 Files selected for processing (20)
  • cub/benchmarks/bench/segmented_topk/fixed/keys.cu
  • cub/benchmarks/bench/segmented_topk/variable/keys.cu
  • cub/cub/agent/agent_batched_topk.cuh
  • cub/cub/detail/segmented_params.cuh
  • cub/cub/device/dispatch/dispatch_batched_topk.cuh
  • cub/cub/device/dispatch/kernels/kernel_batched_topk.cuh
  • cub/test/catch2_test_device_segmented_topk_keys.cu
  • cub/test/catch2_test_device_segmented_topk_pairs.cu
  • libcudacxx/include/cuda/__argument/argument.h
  • libcudacxx/include/cuda/__argument/argument_bounds.h
  • libcudacxx/include/cuda/argument
  • libcudacxx/include/cuda/std/__internal/namespaces.h
  • libcudacxx/test/libcudacxx/cuda/argument/argument_bounds.pass.cpp
  • libcudacxx/test/libcudacxx/cuda/argument/argument_traits.pass.cpp
  • libcudacxx/test/libcudacxx/cuda/argument/deferred_argument.pass.cpp
  • libcudacxx/test/libcudacxx/cuda/argument/dynamic_argument.pass.cpp
  • libcudacxx/test/libcudacxx/cuda/argument/static_argument.pass.cpp
  • libcudacxx/test/libcudacxx/cuda/argument/static_bounds_conversion.fail.cpp
  • libcudacxx/test/libcudacxx/cuda/argument/usage_example.pass.cpp
  • libcudacxx/test/support/test_macros.h

Comment on lines +274 to +280
template <auto _Lowest, auto _Max>
_CCCL_API constexpr __immediate(_Arg __arg, __static_bounds<_Lowest, _Max>) noexcept
: arg{::cuda::std::move(__arg)}
{
__validate_bounds();
__validate_value();
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

important: These constructors accept a __static_bounds<_Lowest, _Max> argument but validate against the class template parameter _StaticBounds, so explicitly-instantiated types can silently ignore the bounds token passed at construction. Add a compile-time constraint (for example, static_assert(::cuda::std::is_same_v<_StaticBounds, __static_bounds<_Lowest, _Max>>) or a requires clause) so construction fails when the token and _StaticBounds disagree.

Also applies to: 294-302

Comment on lines +340 to +345
template <auto _Lowest, auto _Max>
_CCCL_API constexpr __deferred_base(_Arg __arg, __static_bounds<_Lowest, _Max>) noexcept
: arg{::cuda::std::move(__arg)}
{
__validate_bounds_intersection<__element_type, _StaticBounds>(__runtime_bounds_);
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

important: __deferred_base has the same bounds-token mismatch risk: constructor parameters carry __static_bounds<_Lowest, _Max> but all checks use _StaticBounds. This can make user-provided static bounds inert for explicitly-typed wrappers. Enforce _StaticBounds == __static_bounds<_Lowest, _Max> at compile time (or remove the redundant bounds parameter in favor of _StaticBounds-typed overloads).

Also applies to: 357-366

@github-actions

This comment has been minimized.

@pciolkosz pciolkosz force-pushed the replace_cub_parameter_framework branch from 8a3b299 to 06977ac Compare June 3, 2026 01:23
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (3)
cub/cub/detail/segmented_params.cuh (3)

16-19: ⚡ Quick win

suggestion: Use a direct cstddef include and a qualified size type here. These new signatures and supported_options::count introduce bare size_t, but this header does not include the defining header directly. Switching to ::cuda::std::size_t and adding the precise cstddef header removes the transitive-include dependency and matches CCCL header rules. As per coding guidelines, "Type names must be fully qualified" and "Files must include all headers related to the symbols that they are using. No transitive header inclusions are allowed."

Also applies to: 34-70, 83-83


152-152: ⚡ Quick win

suggestion: Qualify the free-function call at Line 152. dispatch_impl(...) is invoked unqualified, but CCCL headers require free-function calls to start from the global namespace even inside the same namespace. As per coding guidelines, "All calls to free functions must be fully qualified starting from the global namespace."


29-31: ⚡ Quick win

suggestion: Complete the Doxygen tags on the documented functions. The new //! blocks for get_param, dispatch_impl, and dispatch_discrete omit @param and @return, which this repo requires whenever a function is documented. As per coding guidelines, "When a function is documented with Doxygen, it must include: //! @brief, `//! `@param`[in/out/in,out]` for every parameter, and `//! `@return for non-void functions."

Also applies to: 135-146


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 7627989d-73f6-4559-875e-d2820ef16b49

📥 Commits

Reviewing files that changed from the base of the PR and between 8a3b299 and 06977ac.

📒 Files selected for processing (8)
  • cub/benchmarks/bench/segmented_topk/fixed/keys.cu
  • cub/benchmarks/bench/segmented_topk/variable/keys.cu
  • cub/cub/agent/agent_batched_topk.cuh
  • cub/cub/detail/segmented_params.cuh
  • cub/cub/device/dispatch/dispatch_batched_topk.cuh
  • cub/cub/device/dispatch/kernels/kernel_batched_topk.cuh
  • cub/test/catch2_test_device_segmented_topk_keys.cu
  • cub/test/catch2_test_device_segmented_topk_pairs.cu
💤 Files with no reviewable changes (2)
  • cub/test/catch2_test_device_segmented_topk_pairs.cu
  • cub/test/catch2_test_device_segmented_topk_keys.cu
🚧 Files skipped from review as they are similar to previous changes (5)
  • cub/cub/agent/agent_batched_topk.cuh
  • cub/benchmarks/bench/segmented_topk/fixed/keys.cu
  • cub/benchmarks/bench/segmented_topk/variable/keys.cu
  • cub/cub/device/dispatch/kernels/kernel_batched_topk.cuh
  • cub/cub/device/dispatch/dispatch_batched_topk.cuh

Comment thread cub/cub/detail/segmented_params.cuh Outdated
@github-actions

This comment has been minimized.

Copy link
Copy Markdown
Contributor

@elstehle elstehle left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this looks much cleaner with the argument annotation framework replacements. Thanks for working on this!

Comment thread cub/benchmarks/bench/segmented_topk/fixed/keys.cu Outdated
const auto num_segments = ::cuda::std::max<std::size_t>(1, (max_elements / segment_size));
const auto elements = num_segments * segment_size;
const auto total_num_items = total_num_items_guarantee_t{static_cast<::cuda::std::int64_t>(elements)};
const auto total_num_items = ::cuda::__argument::__immediate{static_cast<::cuda::std::int64_t>(elements)};
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note to myself: I think this should become part of the guarantees API, as, in the device interface, there is no concrete argument that could be annotated.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR: #9278

Comment on lines +76 to +77
using segment_size_val_t = typename ::cuda::__argument::__traits<SegmentSizeParameterT>::element_type;
using num_segments_val_t = typename ::cuda::__argument::__traits<NumSegmentsParameterT>::element_type;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to myself: I think we want to narrower the index/offset/size types to the more narrow of the two: the static upper bound or the element type.


constexpr bool is_full_tile = params::has_single_static_value_v<SegmentSizeParameterT>
&& params::static_min_value_v<SegmentSizeParameterT> == tile_size;
constexpr bool is_full_tile = ::cuda::__argument::__traits<SegmentSizeParameterT>::is_constant
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: For something like an __immediate where we have sharp bounds, i.e., lowest=max, we wouldn't hit the full_tile branch, right? Do you think this is something we should somehow cover or do you think we can expect users to be always be using __constant? I guess there could be a scenario where users themselves do have something like bounds template parameters and would not check for narrow bounds before instantiating something with __immediate?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could make __immediate smarter and report is_constant when it has static bounds where lower == higher. I will take a look into it in a follow-up PR.

Comment thread cub/cub/device/dispatch/dispatch_batched_topk.cuh
Comment thread cub/benchmarks/bench/segmented_topk/variable/keys.cu Outdated
Comment thread cub/cub/detail/segmented_params.cuh Outdated
@github-actions

This comment has been minimized.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 3, 2026

🥳 CI Workflow Results

🟩 Finished in 1h 12m: Pass: 100%/284 | Total: 2d 20h | Max: 42m 30s | Hits: 89%/216258

See results here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

Archived in project

Development

Successfully merging this pull request may close these issues.

2 participants