-
Notifications
You must be signed in to change notification settings - Fork 401
run to run scan warpspeed impl sm100+ #9263
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,7 @@ | |
| #include <cuda/__memory/is_aligned.h> | ||
| #include <cuda/__ptx/instructions/get_sreg.h> | ||
| #include <cuda/__type_traits/is_trivially_copyable.h> | ||
| #include <cuda/std/__algorithm/min.h> | ||
| #include <cuda/std/__bit/popcount.h> | ||
| #include <cuda/std/__type_traits/underlying_type.h> | ||
|
|
||
|
|
@@ -261,6 +262,88 @@ template <int numTileStatesPerThread, typename AccumT, typename ScanOpT> | |
| return aggrExclusiveCtaCur; // must only be valid in lane_0 | ||
| } | ||
|
|
||
// Deterministic variant of warpIncrementalLookahead that yields the same aggrExclusiveCta.
//
// The lookahead always starts at a tile index that is a multiple of 32: the left pointer (idxTilePrev) is rounded down
// to the nearest multiple of 32 and tile aggregates are reduced from there in batches of up to 32 tiles (one tile per
// lane). A batch is only reduced once all of its tile aggregates are available, so every reduction begins at the same
// fixed tiles no matter which tiles happened to finish first. The order in which values are combined is therefore
// always the same and the result is identical on every run.
//
// Parameters:
//   specialRegisters     - supplies laneIdx of the calling thread
//   ptrTileStates        - tile state storage, forwarded to warpLoadLookahead
//   idxTilePrev          - [in/out] on entry, rounded down to a multiple of 32 to obtain the start tile; on exit, the
//                          start of the last (partial) batch if one was consumed, otherwise idxTileNext
//   aggrExclusiveCtaPrev - [in/out] on entry, used as the aggregate of all tiles before the rounded-down start tile;
//                          on exit, the aggregate of all complete batches of 32 consumed so far
//   idxTileNext          - exclusive upper bound of the tiles to aggregate
//   scan_op              - binary operator used to combine aggregates
//
// Returns the exclusive aggregate for idxTileNext. As with warpIncrementalLookahead, the value must only be relied on
// in lane_0.
template <int numTileStatesPerThread, typename AccumT, typename ScanOpT>
[[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE AccumT warpIncrementalLookaheadStable(
  SpecialRegisters specialRegisters,
  tile_state_t<AccumT>* ptrTileStates,
  int& idxTilePrev,
  AccumT& aggrExclusiveCtaPrev,
  const int idxTileNext,
  ScanOpT& scan_op)
{
  // Number of tile states consumed per reduction step: one per lane of the warp.
  constexpr int batchSize = 32;

  const int laneIdx                      = specialRegisters.laneIdx;
  const ::cuda::std::uint32_t lanemaskEq = ::cuda::ptx::get_sreg_lanemask_eq();

  // Adjust the left pointer down to the nearest multiple of batchSize so we do batched sums
  int idxTileCur             = (idxTilePrev / batchSize) * batchSize;
  AccumT aggrExclusiveCtaCur = aggrExclusiveCtaPrev;

  using warp_reduce_t = WarpReduce<AccumT>;
  // NOTE(review): the 4-byte bound is carried over unchanged; presumably it only admits a placeholder TempStorage
  // that needs no real shared memory (temp_storage below is a plain local) - TODO confirm.
  static_assert(sizeof(typename warp_reduce_t::TempStorage) <= 4,
                "WarpReduce with non-trivial temporary storage is not supported yet in this kernel.");
  [[maybe_unused]] typename warp_reduce_t::TempStorage temp_storage;

  using warp_reduce_or_t = WarpReduce<::cuda::std::uint32_t>;
  typename warp_reduce_or_t::TempStorage temp_storage_or;
  warp_reduce_or_t warp_reduce_or{temp_storage_or};
  constexpr ::cuda::std::bit_or<::cuda::std::uint32_t> or_op{};

  // Keep reloading tile states until every tile in [idxTileCur, idxTileNext) has been consumed. When a batch is not
  // complete yet, the inner loop is left early and the states are loaded again.
  while (idxTileCur < idxTileNext)
  {
    tile_state_t<AccumT> regTmpStates[numTileStatesPerThread];
    warpLoadLookahead(laneIdx, regTmpStates, ptrTileStates, idxTileCur, idxTileNext);

    // NOTE(review): regTmpStates[idx] is treated as the idx-th consecutive batch of 32 tiles starting at the
    // idxTileCur passed to warpLoadLookahead - confirm against warpLoadLookahead.
    for (int idx = 0; idx < numTileStatesPerThread; ++idx)
    {
      // A preceding full batch may have advanced idxTileCur all the way to idxTileNext. Stop here instead of
      // processing an empty batch: that would cost a needless warp reduction and fold one more operand into the
      // result. idxTileCur and idxTileNext are the same in all lanes, so this exit is warp-uniform.
      if (idxTileCur >= idxTileNext)
      {
        break;
      }

      // Bitmask with a 1 bit in the position of the current lane if current lane has a tile aggregate
      const ::cuda::std::uint32_t lane_has_aggregate =
        lanemaskEq * (regTmpStates[idx].state == scan_state::tile_aggregate);

      // Bitmask with 1 bits indicating which lane has a tile aggregate
      // NOTE(review): the mask is consumed by every lane below, so this relies on the OR reduction delivering its
      // result to all lanes - confirm for the targeted architectures.
      const ::cuda::std::uint32_t warp_has_aggregate_mask = warp_reduce_or.Reduce(lane_has_aggregate, or_op);

      // Bitmask with 1 bits for the contiguous run of lanes, starting at lane 0, that have a tile aggregate
      const ::cuda::std::uint32_t warp_right_aggregates_mask = warp_has_aggregate_mask & (~warp_has_aggregate_mask - 1);

      const ::cuda::std::uint32_t warp_right_aggregates_count = ::cuda::std::popcount(warp_right_aggregates_mask);

      // Only reduce once a fixed number of contiguous tile aggregates are available, so the reduction order is fixed.
      const ::cuda::std::uint32_t expected_count =
        static_cast<::cuda::std::uint32_t>(::cuda::std::min(batchSize, idxTileNext - idxTileCur));
      if (warp_right_aggregates_count < expected_count)
      {
        break;
      }

      // Lanes outside the contiguous run contribute the identity element so they do not affect the reduction.
      const bool use_value    = lanemaskEq & warp_right_aggregates_mask;
      const AccumT value      = use_value ? regTmpStates[idx].value : cuda::identity_element<ScanOpT, AccumT>();
      const AccumT local_aggr = warp_reduce_t{temp_storage}.Reduce(value, scan_op);

      // The very first batch has no predecessor, so its aggregate is taken as-is instead of being combined.
      const AccumT full_aggr = idxTileCur == 0 ? local_aggr : scan_op(aggrExclusiveCtaCur, local_aggr);

      const bool is_full_batch = (idxTileNext - idxTileCur) >= batchSize;
      if (!is_full_batch)
      {
        // Partial batch: it is the last one before idxTileNext. Hand the result to the caller but persist only the
        // state at the batch start (a multiple of batchSize), so the next call re-reduces this batch from the same
        // fixed starting tile.
        idxTilePrev          = idxTileCur;
        aggrExclusiveCtaPrev = aggrExclusiveCtaCur;
        return full_aggr;
      }

      // Full batch: fold it into the running aggregate and move on to the next batch.
      aggrExclusiveCtaCur = full_aggr;
      idxTileCur += batchSize;
    }
  }

  idxTilePrev          = idxTileNext;
  aggrExclusiveCtaPrev = aggrExclusiveCtaCur;
  return aggrExclusiveCtaCur; // must only be valid in lane_0
}
|
|
||
| #endif // __cccl_ptx_isa >= 860 | ||
| } // namespace detail::warpspeed | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggestion: Use `cuda::round_down`.