-
Notifications
You must be signed in to change notification settings - Fork 401
[cub] Replace cub parameter framework with cuda::argument #9074
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |
| #include <cub/device/dispatch/tuning/tuning_batched_topk.cuh> | ||
| #include <cub/util_type.cuh> | ||
|
|
||
| #include <cuda/__argument_> | ||
| #include <cuda/__cmath/ceil_div.h> | ||
|
|
||
| CUB_NAMESPACE_BEGIN | ||
|
|
@@ -72,8 +73,8 @@ struct agent_batched_topk_worker_per_segment | |
| using key_t = it_value_t<key_it_t>; | ||
| using value_t = it_value_t<value_it_t>; | ||
|
|
||
| using segment_size_val_t = typename SegmentSizeParameterT::value_type; | ||
| using num_segments_val_t = typename NumSegmentsParameterT::value_type; | ||
| using segment_size_val_t = typename ::cuda::__argument::__traits<SegmentSizeParameterT>::element_type; | ||
| using num_segments_val_t = typename ::cuda::__argument::__traits<NumSegmentsParameterT>::element_type; | ||
|
Comment on lines
+76
to
+77
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note to myself: I think we want to narrower the index/offset/size types to the more narrow of the two: the static upper bound or the element type. |
||
| using counters_t = batched_topk_counters<num_segments_val_t>; | ||
|
|
||
| static constexpr auto policy = PolicyGetter{}(); | ||
|
|
@@ -94,7 +95,7 @@ struct agent_batched_topk_worker_per_segment | |
| multi_worker_per_segment_policy.threads_per_block * multi_worker_per_segment_policy.items_per_thread; | ||
|
|
||
| // Check if there could be large segments present | ||
| static constexpr bool only_small_segments = params::static_max_value_v<SegmentSizeParameterT> <= tile_size; | ||
| static constexpr bool only_small_segments = ::cuda::__argument::__traits<SegmentSizeParameterT>::max <= tile_size; | ||
|
|
||
| // Check if we are dealing with keys-only or key-value pairs | ||
| static constexpr bool is_keys_only = ::cuda::std::is_same_v<value_t, cub::NullType>; | ||
|
|
@@ -190,16 +191,16 @@ struct agent_batched_topk_worker_per_segment | |
|
|
||
| // Boundary check | ||
| // TODO (elstehle): consider skipping boundary check if we can safely assume the right grid dimensions | ||
| if (segment_id >= num_segments.get_param(0)) | ||
| if (segment_id >= params::get_param(num_segments, 0)) | ||
| { | ||
| return; | ||
| } | ||
|
|
||
| constexpr bool is_full_tile = params::has_single_static_value_v<SegmentSizeParameterT> | ||
| && params::static_min_value_v<SegmentSizeParameterT> == tile_size; | ||
| constexpr bool is_full_tile = ::cuda::__argument::__traits<SegmentSizeParameterT>::is_constant | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. question: For something like an
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we could make |
||
| && ::cuda::__argument::__traits<SegmentSizeParameterT>::lowest == tile_size; | ||
|
|
||
| // Resolve Segment Parameters | ||
| const auto segment_size = segment_sizes.get_param(segment_id); | ||
| const auto segment_size = params::get_param(segment_sizes, segment_id); | ||
| if (!only_small_segments && segment_size > tile_size) | ||
| { | ||
| // Enqueue large segment | ||
|
|
@@ -215,8 +216,8 @@ struct agent_batched_topk_worker_per_segment | |
| else | ||
| { | ||
| // Process small segment | ||
| const auto k = (::cuda::std::min) (k_param.get_param(segment_id), | ||
| static_cast<decltype(k_param.get_param(segment_id))>(segment_size)); | ||
| const auto k = (::cuda::std::min) (params::get_param(k_param, segment_id), | ||
| static_cast<decltype(params::get_param(k_param, segment_id))>(segment_size)); | ||
| const auto direction = select_directions.get_param(segment_id); | ||
|
|
||
| // Determine padding key based on direction | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note to myself: I think this should become part of the guarantees API, as, in the device interface, there is no concrete argument that could be annotated.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR: #9278