Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions cub/cub/detail/warpspeed/look_ahead.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <cuda/__memory/is_aligned.h>
#include <cuda/__ptx/instructions/get_sreg.h>
#include <cuda/__type_traits/is_trivially_copyable.h>
#include <cuda/std/__algorithm/min.h>
#include <cuda/std/__bit/popcount.h>
#include <cuda/std/__type_traits/underlying_type.h>

Expand Down Expand Up @@ -261,6 +262,88 @@ template <int numTileStatesPerThread, typename AccumT, typename ScanOpT>
return aggrExclusiveCtaCur; // must only be valid in lane_0
}

// Deterministic version of warpIncrementalLookahead that returns the same aggrExclusiveCta. The difference is that it
// always starts the lookahead from a tile index that is a multiple of 32: it shifts the left pointer (idxTilePrev) down
// to the nearest multiple of 32 and reduces from there. Because every reduction begins at the same fixed tiles, no
// matter which tiles happened to finish first, the order in which values are summed is always the same and the result
// is identical on every run. idxTilePrev/aggrExclusiveCtaPrev are updated by reference to the last multiple of 32.
template <int numTileStatesPerThread, typename AccumT, typename ScanOpT>
[[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE AccumT warpIncrementalLookaheadStable(
SpecialRegisters specialRegisters,
tile_state_t<AccumT>* ptrTileStates,
int& idxTilePrev,
AccumT& aggrExclusiveCtaPrev,
const int idxTileNext,
ScanOpT& scan_op)
{
const int laneIdx = specialRegisters.laneIdx;
const ::cuda::std::uint32_t lanemaskEq = ::cuda::ptx::get_sreg_lanemask_eq();

// Adjust the left pointer down to the nearest 32-multiple so we do batched sums
int idxTileCur = (idxTilePrev / 32) * 32;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: Use cuda::round_down.

AccumT aggrExclusiveCtaCur = aggrExclusiveCtaPrev;

using warp_reduce_t = WarpReduce<AccumT>;
static_assert(sizeof(typename warp_reduce_t::TempStorage) <= 4,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why 4? I assume this is sizeof(uint32_t)? If so, best to say sizeof(uint32_t) instead (or better yet, refer to an actual type/value so that when that size is changed, the check automatically is as well).

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the TempStorage is a struct with further nested types that have no value, but because there are data members it has a size of 1. For some reason @elstehle chose 4 here, but the check is basically that no temporary storage is required. Btw, is_empty also does not work here.

"WarpReduce with non-trivial temporary storage is not supported yet in this kernel.");
[[maybe_unused]] typename warp_reduce_t::TempStorage temp_storage;

using warp_reduce_or_t = WarpReduce<::cuda::std::uint32_t>;
typename warp_reduce_or_t::TempStorage temp_storage_or;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: typename is not needed here I think. WarpReduce<uint32_t> is not dependent on any of your template params.

warp_reduce_or_t warp_reduce_or{temp_storage_or};
constexpr ::cuda::std::bit_or<::cuda::std::uint32_t> or_op{};

while (idxTileCur < idxTileNext)
{
tile_state_t<AccumT> regTmpStates[numTileStatesPerThread];
warpLoadLookahead(laneIdx, regTmpStates, ptrTileStates, idxTileCur, idxTileNext);

for (int idx = 0; idx < numTileStatesPerThread; ++idx)
{
// Bitmask with a 1 bit in the position of the current lane if current lane has a tile aggregate
const ::cuda::std::uint32_t lane_has_aggregate =
lanemaskEq * (regTmpStates[idx].state == scan_state::tile_aggregate);

// Bitmask with 1 bits indicating which lane has a tile aggregate
const ::cuda::std::uint32_t warp_has_aggregate_mask = warp_reduce_or.Reduce(lane_has_aggregate, or_op);

// Bitmask with 1 bits for all rightmost lanes having a tile aggregate
const ::cuda::std::uint32_t warp_right_aggregates_mask = warp_has_aggregate_mask & (~warp_has_aggregate_mask - 1);

const ::cuda::std::uint32_t warp_right_aggregates_count = ::cuda::std::popcount(warp_right_aggregates_mask);

// Only reduce once a fixed number of contiguous tile aggregates are available, so the reduction order is fixed.
const ::cuda::std::uint32_t expected_count =
static_cast<::cuda::std::uint32_t>(::cuda::std::min(32, idxTileNext - idxTileCur));
if (warp_right_aggregates_count < expected_count)
{
break;
}

const bool use_value = lanemaskEq & warp_right_aggregates_mask;
const AccumT value = use_value ? regTmpStates[idx].value : cuda::identity_element<ScanOpT, AccumT>();
const AccumT local_aggr = warp_reduce_t{temp_storage}.Reduce(value, scan_op);

if (expected_count == 32)
{
aggrExclusiveCtaCur = idxTileCur == 0 ? local_aggr : scan_op(aggrExclusiveCtaCur, local_aggr);
idxTileCur += 32;
}
else
{
const AccumT full_aggr = idxTileCur == 0 ? local_aggr : scan_op(aggrExclusiveCtaCur, local_aggr);
idxTilePrev = idxTileCur;
aggrExclusiveCtaPrev = aggrExclusiveCtaCur;
return full_aggr;
}
}
}

idxTilePrev = idxTileNext;
aggrExclusiveCtaPrev = aggrExclusiveCtaCur;
return aggrExclusiveCtaCur; // must only be valid in lane_0
}

#endif // __cccl_ptx_isa >= 860
} // namespace detail::warpspeed

Expand Down
3 changes: 2 additions & 1 deletion cub/cub/device/dispatch/kernels/kernel_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ __launch_bounds__(device_scan_launch_bounds<PolicySelector>, 1) _CCCL_KERNEL_ATT
NV_PROVIDES_SM_100, ({
auto scan_params = scanKernelParams<it_value_t<InputIteratorT>, it_value_t<OutputIteratorT>, AccumT>{
d_in, d_out, tile_state.warpspeed, num_items, num_stages};
device_scan_warpspeed_body<PolicySelector, ForceInclusive, RealInitValueT>(scan_params, scan_op, init_value);
device_scan_warpspeed_body<PolicySelector, ForceInclusive, RealInitValueT, StableReductionOrder>(
scan_params, scan_op, init_value);
}));
#else
static_assert(sizeof(d_in) == 0,
Expand Down
40 changes: 31 additions & 9 deletions cub/cub/device/dispatch/kernels/kernel_scan_warpspeed.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ template <typename PolicySelector,
typename AccumT,
typename ScanOpT,
typename RealInitValueT,
bool ForceInclusive>
bool ForceInclusive,
bool StableReductionOrder = false>
struct warpspeed_scan_closure
{
static constexpr scan_warpspeed_policy policy = current_policy<PolicySelector>().warpspeed;
Expand Down Expand Up @@ -327,14 +328,27 @@ struct warpspeed_scan_closure

if (!is_first_tile)
{
AccumT regAggrExclusiveCta = warpspeed::warpIncrementalLookahead<look_ahead_items_per_thread>(
specialRegisters, params.ptrTileStates, idxTilePrev, AggrExclusiveCtaPrev, idxTile, scan_op);
if (squad.isLeaderThread())
if constexpr (StableReductionOrder)
{
// The stable-order version updates idxTilePrev/AggrExclusiveCtaPrev itself
AccumT regAggrExclusiveCta = warpspeed::warpIncrementalLookaheadStable<look_ahead_items_per_thread>(
specialRegisters, params.ptrTileStates, idxTilePrev, AggrExclusiveCtaPrev, idxTile, scan_op);
if (squad.isLeaderThread())
{
refAggrExclusiveCtaW.data() = regAggrExclusiveCta;
}
}
else
{
refAggrExclusiveCtaW.data() = regAggrExclusiveCta;
AccumT regAggrExclusiveCta = warpspeed::warpIncrementalLookahead<look_ahead_items_per_thread>(
specialRegisters, params.ptrTileStates, idxTilePrev, AggrExclusiveCtaPrev, idxTile, scan_op);
if (squad.isLeaderThread())
{
refAggrExclusiveCtaW.data() = regAggrExclusiveCta;
}
AggrExclusiveCtaPrev = regAggrExclusiveCta;
idxTilePrev = idxTile;
}
AggrExclusiveCtaPrev = regAggrExclusiveCta;
idxTilePrev = idxTile;
}
}

Expand Down Expand Up @@ -825,6 +839,7 @@ struct warpspeed_scan_closure
template <typename PolicySelector,
bool ForceInclusive,
typename RealInitValueT,
bool StableReductionOrder,
typename InputT,
typename OutputT,
typename AccumT,
Expand All @@ -849,8 +864,15 @@ _CCCL_DEVICE_API _CCCL_FORCEINLINE void device_scan_warpspeed_body(
}();

// Dispatch each warp to its respective squad
using closure_t =
warpspeed_scan_closure<PolicySelector, InputT, OutputT, AccumT, ScanOpT, RealInitValueT, ForceInclusive>;
using closure_t = warpspeed_scan_closure<
PolicySelector,
InputT,
OutputT,
AccumT,
ScanOpT,
RealInitValueT,
ForceInclusive,
StableReductionOrder>;
warpspeed::squadDispatch(
specialRegisters, closure_t::scanSquads, [&](warpspeed::Squad squad) _CCCL_FORCEINLINE_LAMBDA {
// we load the initial value after the squad dispatch, so only the squads needing it emit an LDG
Expand Down
4 changes: 3 additions & 1 deletion cub/cub/device/dispatch/tuning/tuning_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,9 @@ struct policy_selector
[[nodiscard]] _CCCL_HOST_DEVICE_API constexpr auto operator()(::cuda::compute_capability cc) const -> scan_policy
{
// we first try to get the valid warpspeed implementation. if we can't run it, fall back to the old scan impl.
if (!require_stable_reduction_order)
// For stable reduction order (fp + plus), warpspeed can only be used on sm_100+, Older arches fall back to classic
// lookback stable reduction order verison below.
if (!require_stable_reduction_order || cc >= ::cuda::compute_capability{10, 0})
{
const auto warpspeed_policy_opt = get_warpspeed_policy(cc);
if (warpspeed_policy_opt && can_use_warpspeed(cc, *warpspeed_policy_opt))
Expand Down
Loading