Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions c/parallel/include/cccl/c/radix_sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ typedef struct cccl_device_radix_sort_build_result_t
CUkernel alt_downsweep_kernel;
CUkernel histogram_kernel;
CUkernel exclusive_sum_kernel;
CUkernel init_bins_and_counters_kernel;
CUkernel init_lookback_kernel;
CUkernel onesweep_kernel;
cccl_sort_order_t order;
void* runtime_policy;
Expand Down
39 changes: 39 additions & 0 deletions c/parallel/src/radix_sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,22 @@ std::string get_exclusive_sum_kernel_name(std::string_view chained_policy_t, std
return std::format("cub::detail::radix_sort::DeviceRadixSortExclusiveSumKernel<{0}, {1}>", chained_policy_t, offset_t);
}

// Builds the name expression for one instantiation of
// cub::detail::radix_sort::DeviceRadixSortInitKernel.
//
// chained_policy_t          - spelling of the first template argument
// trigger_at_start          - rendered as the literal "true"/"false" (second template argument)
// fence_before_end_trigger  - rendered as the literal "true"/"false" (third template argument)
// init_t0, init_t1          - spellings of the two element types (fourth and fifth template arguments)
//
// Returns the template-id as a string, with the arguments in the order listed above.
std::string get_init_kernel_name(
  std::string_view chained_policy_t,
  bool trigger_at_start,
  bool fence_before_end_trigger,
  std::string_view init_t0,
  std::string_view init_t1)
{
  // Spell a bool the way it has to appear as a non-type template argument.
  const auto as_template_arg = [](bool flag) -> std::string_view {
    return flag ? "true" : "false";
  };

  const std::string_view trigger_arg = as_template_arg(trigger_at_start);
  const std::string_view fence_arg   = as_template_arg(fence_before_end_trigger);

  return std::format(
    "cub::detail::radix_sort::DeviceRadixSortInitKernel<{0}, {1}, {2}, {3}, {4}>",
    chained_policy_t,
    trigger_arg,
    fence_arg,
    init_t0,
    init_t1);
}

std::string get_onesweep_kernel_name(
std::string_view chained_policy_t,
cccl_sort_order_t sort_order,
Expand Down Expand Up @@ -166,6 +182,16 @@ struct radix_sort_kernel_source
return build.exclusive_sum_kernel;
}

// Hands back the kernel handle stored in `build.init_bins_and_counters_kernel`.
CUkernel RadixSortInitBinsAndCountersKernel() const
{
  const CUkernel bins_and_counters_init = build.init_bins_and_counters_kernel;
  return bins_and_counters_init;
}

// Hands back the kernel handle stored in `build.init_lookback_kernel`.
CUkernel RadixSortInitLookbackKernel() const
{
  const CUkernel lookback_init = build.init_lookback_kernel;
  return lookback_init;
}

CUkernel RadixSortOnesweepKernel() const
{
return build.onesweep_kernel;
Expand Down Expand Up @@ -289,6 +315,9 @@ static_assert(device_radix_sort_policy()(current_tuning_cc()) == {6}, "Host gene
std::string histogram_kernel_name =
radix_sort::get_histogram_kernel_name(chained_policy_t, sort_order, key_cpp, offset_t);
std::string exclusive_sum_kernel_name = radix_sort::get_exclusive_sum_kernel_name(chained_policy_t, offset_t);
std::string init_bins_and_counters_kernel_name =
radix_sort::get_init_kernel_name(chained_policy_t, true, false, "int", offset_t);
std::string init_lookback_kernel_name = radix_sort::get_init_kernel_name(chained_policy_t, false, true, "int", "int");
std::string onesweep_kernel_name =
radix_sort::get_onesweep_kernel_name(chained_policy_t, sort_order, key_cpp, value_cpp, offset_t);
std::string single_tile_kernel_lowered_name;
Expand All @@ -299,6 +328,8 @@ static_assert(device_radix_sort_policy()(current_tuning_cc()) == {6}, "Host gene
std::string alt_downsweep_kernel_lowered_name;
std::string histogram_kernel_lowered_name;
std::string exclusive_sum_kernel_lowered_name;
std::string init_bins_and_counters_kernel_lowered_name;
std::string init_lookback_kernel_lowered_name;
std::string onesweep_kernel_lowered_name;

const std::string arch = std::format("-arch=sm_{0}{1}", cc_major, cc_minor);
Expand Down Expand Up @@ -336,6 +367,8 @@ static_assert(device_radix_sort_policy()(current_tuning_cc()) == {6}, "Host gene
->add_expression({alt_downsweep_kernel_name})
->add_expression({histogram_kernel_name})
->add_expression({exclusive_sum_kernel_name})
->add_expression({init_bins_and_counters_kernel_name})
->add_expression({init_lookback_kernel_name})
->add_expression({onesweep_kernel_name})
->compile_program({args.data(), args.size()})
->get_name({single_tile_kernel_name, single_tile_kernel_lowered_name})
Expand All @@ -346,6 +379,8 @@ static_assert(device_radix_sort_policy()(current_tuning_cc()) == {6}, "Host gene
->get_name({alt_downsweep_kernel_name, alt_downsweep_kernel_lowered_name})
->get_name({histogram_kernel_name, histogram_kernel_lowered_name})
->get_name({exclusive_sum_kernel_name, exclusive_sum_kernel_lowered_name})
->get_name({init_bins_and_counters_kernel_name, init_bins_and_counters_kernel_lowered_name})
->get_name({init_lookback_kernel_name, init_lookback_kernel_lowered_name})
->get_name({onesweep_kernel_name, onesweep_kernel_lowered_name})
->link_program()
->add_link_list(linkable_list)
Expand All @@ -364,6 +399,10 @@ static_assert(device_radix_sort_policy()(current_tuning_cc()) == {6}, "Host gene
check(cuLibraryGetKernel(&build_ptr->histogram_kernel, build_ptr->library, histogram_kernel_lowered_name.c_str()));
check(cuLibraryGetKernel(
&build_ptr->exclusive_sum_kernel, build_ptr->library, exclusive_sum_kernel_lowered_name.c_str()));
check(cuLibraryGetKernel(
&build_ptr->init_bins_and_counters_kernel, build_ptr->library, init_bins_and_counters_kernel_lowered_name.c_str()));
check(cuLibraryGetKernel(
&build_ptr->init_lookback_kernel, build_ptr->library, init_lookback_kernel_lowered_name.c_str()));
check(cuLibraryGetKernel(&build_ptr->onesweep_kernel, build_ptr->library, onesweep_kernel_lowered_name.c_str()));

build_ptr->cc = cc_major * 10 + cc_minor;
Expand Down
1 change: 1 addition & 0 deletions cub/cub/agent/agent_radix_sort_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ struct AgentRadixSortHistogram
__syncthreads();

// Accumulate the result in global memory.
_CCCL_PDL_GRID_DEPENDENCY_SYNC();
AccumulateGlobalHistograms();
__syncthreads();
}
Expand Down
2 changes: 2 additions & 0 deletions cub/cub/agent/agent_radix_sort_onesweep.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ struct AgentRadixSortOnesweep
{
bins[u] = other_bins[u];
}

_CCCL_PDL_GRID_DEPENDENCY_SYNC();
agent.LookbackPartial(bins);

agent.TryShortCircuit(keys, bins);
Expand Down
121 changes: 90 additions & 31 deletions cub/cub/device/dispatch/dispatch_radix_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ struct DeviceRadixSortKernelSource

CUB_DEFINE_KERNEL_GETTER(RadixSortExclusiveSumKernel, DeviceRadixSortExclusiveSumKernel<PolicySelector, OffsetT>);

// Getter for the DeviceRadixSortInitKernel instantiation with TRIGGER_AT_START = true,
// FENCE_BEFORE_END_TRIGGER = false, InitT0 = int and InitT1 = OffsetT. The onesweep dispatch launches it with
// (d_ctrs, num_counter_items, d_bins, num_bin_items) when use_pdl is set.
// NOTE(review): InitT0 = int presumes AtomicOffsetT is int -- confirm against the dispatch's typedef.
CUB_DEFINE_KERNEL_GETTER(RadixSortInitBinsAndCountersKernel,
DeviceRadixSortInitKernel<PolicySelector, true, false, int, OffsetT>);

// Getter for the DeviceRadixSortInitKernel instantiation with TRIGGER_AT_START = false,
// FENCE_BEFORE_END_TRIGGER = true and both element types int. The onesweep dispatch launches it with
// (d_lookback, num_lookback_items, d_lookback, size_t{0}) when use_pdl is set, i.e. the second array is empty.
CUB_DEFINE_KERNEL_GETTER(RadixSortInitLookbackKernel,
DeviceRadixSortInitKernel<PolicySelector, false, true, int, int>);

CUB_DEFINE_KERNEL_GETTER(
RadixSortOnesweepKernel,
DeviceRadixSortOnesweepKernel<PolicySelector, Order, KeyT, ValueT, OffsetT, int, int, DecomposerT>);
Expand Down Expand Up @@ -591,20 +597,13 @@ private:
ValueT* d_values_tmp2 = (ValueT*) allocations[3];
AtomicOffsetT* d_ctrs = (AtomicOffsetT*) allocations[4];

// initialization
if (const auto error =
CubDebug(cudaMemsetAsync(d_ctrs, 0, num_portions * num_passes * sizeof(AtomicOffsetT), stream)))
{
return error;
}
constexpr OffsetT PDL_MAX_ITEMS = static_cast<OffsetT>(1) << 20;
const bool use_pdl = num_items <= PDL_MAX_ITEMS;

// compute num_passes histograms with RADIX_DIGITS bins each
if (const auto error = CubDebug(cudaMemsetAsync(d_bins, 0, num_passes * RADIX_DIGITS * sizeof(OffsetT), stream)))
{
return error;
}
int device = -1;
int num_sms = 0;
const size_t num_counter_items = static_cast<size_t>(num_portions) * num_passes;
const size_t num_bin_items = static_cast<size_t>(num_passes) * RADIX_DIGITS;
int device = -1;
int num_sms = 0;

if (const auto error = CubDebug(cudaGetDevice(&device)))
{
Expand Down Expand Up @@ -638,18 +637,6 @@ private:
policy.histogram.radix_bits);
#endif

if (const auto error = CubDebug(
launcher_factory(histo_blocks_per_sm * num_sms, HISTO_BLOCK_THREADS, 0, stream)
.doit(histogram_kernel, d_bins, d_keys.Current(), num_items, begin_bit, end_bit, decomposer)))
{
return error;
}

if (const auto error = CubDebug(detail::DebugSyncStream(stream)))
{
return error;
}

// exclusive sums to determine starts
const int SCAN_BLOCK_THREADS = policy.exclusive_sum.threads_per_block;

Expand All @@ -662,7 +649,60 @@ private:
policy.exclusive_sum.radix_bits);
#endif

if (const auto error = CubDebug(launcher_factory(num_passes, SCAN_BLOCK_THREADS, 0, stream)
// Initialization is intentionally adjacent to the histogram launch. For the PDL path, this avoids consuming the
// short init kernel's runtime in host-side launch setup work before the dependent histogram is submitted.
if (use_pdl)
{
constexpr int INIT_STARTUP_THREADS = 256;
const size_t num_init_items = ::cuda::std::max(num_counter_items, num_bin_items);
const int init_startup_blocks =
static_cast<int>(::cuda::ceil_div(num_init_items, static_cast<size_t>(INIT_STARTUP_THREADS)));

#ifdef CUB_DEBUG_LOG
_CubLog("Invoking init_bins_and_counters_kernel<<<%d, %d, 0, %lld>>>()\n",
init_startup_blocks,
INIT_STARTUP_THREADS,
reinterpret_cast<long long>(stream));
#endif

if (const auto error = CubDebug(
launcher_factory(init_startup_blocks, INIT_STARTUP_THREADS, 0, stream, use_pdl)
.doit(
kernel_source.RadixSortInitBinsAndCountersKernel(), d_ctrs, num_counter_items, d_bins, num_bin_items)))
{
return error;
}
}
else
{
if (const auto error = CubDebug(cudaMemsetAsync(d_ctrs, 0, num_counter_items * sizeof(AtomicOffsetT), stream)))
{
return error;
}

// compute num_passes histograms with RADIX_DIGITS bins each
if (const auto error = CubDebug(cudaMemsetAsync(d_bins, 0, num_bin_items * sizeof(OffsetT), stream)))
{
return error;
}
}

if (const auto error = CubDebug(
launcher_factory(histo_blocks_per_sm * num_sms, HISTO_BLOCK_THREADS, 0, stream, use_pdl)
.doit(histogram_kernel, d_bins, d_keys.Current(), num_items, begin_bit, end_bit, decomposer)))
{
return error;
}

if (!use_pdl)
{
if (const auto error = CubDebug(detail::DebugSyncStream(stream)))
{
return error;
}
}

if (const auto error = CubDebug(launcher_factory(num_passes, SCAN_BLOCK_THREADS, 0, stream, use_pdl)
.doit(kernel_source.RadixSortExclusiveSumKernel(), d_bins)))
{
return error;
Expand All @@ -672,6 +712,7 @@ private:
{
return error;
}

// use the other buffer if no overwrite is allowed
KeyT* d_keys_tmp = d_keys.Alternate();
ValueT* d_values_tmp = d_values.Alternate();
Expand All @@ -689,12 +730,30 @@ private:
PortionOffsetT portion_num_items = static_cast<PortionOffsetT>(
::cuda::std::min(num_items - portion * PORTION_SIZE, static_cast<OffsetT>(PORTION_SIZE)));

PortionOffsetT num_blocks = ::cuda::ceil_div(portion_num_items, ONESWEEP_TILE_ITEMS);
PortionOffsetT num_blocks = ::cuda::ceil_div(portion_num_items, ONESWEEP_TILE_ITEMS);
const size_t num_lookback_items = static_cast<size_t>(num_blocks) * RADIX_DIGITS;

if (const auto error =
CubDebug(cudaMemsetAsync(d_lookback, 0, num_blocks * RADIX_DIGITS * sizeof(AtomicOffsetT), stream)))
if (use_pdl)
{
return error;
constexpr int INIT_LOOKBACK_THREADS = 256;
const int init_lookback_blocks =
static_cast<int>(::cuda::ceil_div(num_lookback_items, static_cast<size_t>(INIT_LOOKBACK_THREADS)));

if (const auto error = CubDebug(
launcher_factory(init_lookback_blocks, INIT_LOOKBACK_THREADS, 0, stream, use_pdl)
.doit(
kernel_source.RadixSortInitLookbackKernel(), d_lookback, num_lookback_items, d_lookback, size_t{0})))
{
return error;
}
}
else
{
if (const auto error =
CubDebug(cudaMemsetAsync(d_lookback, 0, num_lookback_items * sizeof(AtomicOffsetT), stream)))
{
return error;
}
}

// log onesweep_kernel configuration
Expand All @@ -714,7 +773,7 @@ private:
auto onesweep_kernel = kernel_source.RadixSortOnesweepKernel();

if (const auto error = CubDebug(
launcher_factory(num_blocks, ONESWEEP_BLOCK_THREADS, 0, stream)
launcher_factory(num_blocks, ONESWEEP_BLOCK_THREADS, 0, stream, use_pdl)
.doit(
onesweep_kernel,
d_lookback,
Expand Down
42 changes: 42 additions & 0 deletions cub/cub/device/dispatch/kernels/kernel_radix_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,46 @@ __launch_bounds__(current_policy<PolicySelector>().histogram.threads_per_block)
agent.Process();
}

// Zero-fills up to two arrays in one launch.
//
// Every thread starts at its flat global index and advances by the total number of launched threads, writing 0 to
// d_items0[idx] while idx < num_items0 and to d_items1[idx] while idx < num_items1. Either count may be 0, in which
// case the corresponding pointer is never dereferenced.
//
// Template parameters:
//   PolicySelector            - not referenced by the body; part of the kernel's template-id.
//   TRIGGER_AT_START          - when true, the dependent launch is released right after the dependency sync;
//                               when false, no explicit trigger is issued.
//   FENCE_BEFORE_END_TRIGGER  - no longer has any effect. It is retained only so that existing instantiations and
//                               name expressions that spell five template arguments remain valid.
//   InitT0, InitT1            - element types of the two arrays; both must be assignable from 0.
//
// Parameters:
//   d_items0 / num_items0     - first array and its element count.
//   d_items1 / num_items1     - second array and its element count.
//
// Changes relative to the previous revision, following the review feedback on this kernel:
//   * The grid dependency sync now runs before the first global write for every instantiation. Previously the
//     TRIGGER_AT_START == false variant zeroed its arrays first and only then waited on the prerequisite grid.
//     NOTE(review): this assumes a prerequisite grid may still access (or have un-flushed writes to) the arrays
//     being zeroed -- confirm against the launch sequence in the dispatch.
//   * The dependency sync and trigger at the very end of the kernel were removed: neither is useful there.
//     NOTE(review): this relies on dependents being released once all blocks have exited -- confirm.
//   * The __threadfence() ahead of that end-of-kernel trigger was removed together with it.
//     NOTE(review): this assumes every dependent kernel performs its own grid dependency sync before reading
//     these arrays -- confirm for all consumers.
template <typename PolicySelector, bool TRIGGER_AT_START, bool FENCE_BEFORE_END_TRIGGER, typename InitT0, typename InitT1>
_CCCL_KERNEL_ATTRIBUTES void DeviceRadixSortInitKernel(
  _CCCL_GRID_CONSTANT InitT0* const d_items0,
  _CCCL_GRID_CONSTANT const size_t num_items0,
  _CCCL_GRID_CONSTANT InitT1* const d_items1,
  _CCCL_GRID_CONSTANT const size_t num_items1)
{
  // Do not touch global memory until the prerequisite grid has been synchronized with.
  _CCCL_PDL_GRID_DEPENDENCY_SYNC();

  if constexpr (TRIGGER_AT_START)
  {
    _CCCL_PDL_TRIGGER_NEXT_LAUNCH();
  }

  // Loop-invariant upper bound covering both arrays.
  const size_t num_items = ::cuda::std::max(num_items0, num_items1);
  const size_t stride    = static_cast<size_t>(blockDim.x) * gridDim.x;

  for (size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; idx < num_items; idx += stride)
  {
    if (idx < num_items0)
    {
      d_items0[idx] = 0;
    }
    if (idx < num_items1)
    {
      d_items1[idx] = 0;
    }
  }
}

template <typename PolicySelector,
SortOrder Order,
typename KeyT,
Expand Down Expand Up @@ -553,6 +593,8 @@ _CCCL_KERNEL_ATTRIBUTES void DeviceRadixSortExclusiveSumKernel(_CCCL_GRID_CONSTA
using BlockScan = cub::BlockScan<OffsetT, BLOCK_THREADS>;
__shared__ typename BlockScan::TempStorage temp_storage;

_CCCL_PDL_GRID_DEPENDENCY_SYNC();

// load the bins
OffsetT bins[BINS_PER_THREAD];
int bin_start = blockIdx.x * RADIX_DIGITS;
Expand Down