diff --git a/c/parallel/include/cccl/c/radix_sort.h b/c/parallel/include/cccl/c/radix_sort.h index ec095a515ae..560bf6229f7 100644 --- a/c/parallel/include/cccl/c/radix_sort.h +++ b/c/parallel/include/cccl/c/radix_sort.h @@ -40,6 +40,8 @@ typedef struct cccl_device_radix_sort_build_result_t CUkernel alt_downsweep_kernel; CUkernel histogram_kernel; CUkernel exclusive_sum_kernel; + CUkernel init_bins_and_counters_kernel; + CUkernel init_lookback_kernel; CUkernel onesweep_kernel; cccl_sort_order_t order; void* runtime_policy; diff --git a/c/parallel/src/radix_sort.cu b/c/parallel/src/radix_sort.cu index 2e4ae462581..b17937fb66e 100644 --- a/c/parallel/src/radix_sort.cu +++ b/c/parallel/src/radix_sort.cu @@ -105,6 +105,22 @@ std::string get_exclusive_sum_kernel_name(std::string_view chained_policy_t, std return std::format("cub::detail::radix_sort::DeviceRadixSortExclusiveSumKernel<{0}, {1}>", chained_policy_t, offset_t); } +std::string get_init_kernel_name( + std::string_view chained_policy_t, + bool trigger_at_start, + bool fence_before_end_trigger, + std::string_view init_t0, + std::string_view init_t1) +{ + return std::format( + "cub::detail::radix_sort::DeviceRadixSortInitKernel<{0}, {1}, {2}, {3}, {4}>", + chained_policy_t, + trigger_at_start ? "true" : "false", + fence_before_end_trigger ? "true" : "false", + init_t0, + init_t1); +} + std::string get_onesweep_kernel_name( std::string_view chained_policy_t, cccl_sort_order_t sort_order, @@ -166,6 +182,16 @@ struct radix_sort_kernel_source return build.exclusive_sum_kernel; } + CUkernel RadixSortInitBinsAndCountersKernel() const + { + return build.init_bins_and_counters_kernel; + } + + CUkernel RadixSortInitLookbackKernel() const + { + return build.init_lookback_kernel; + } + CUkernel RadixSortOnesweepKernel() const { return build.onesweep_kernel; @@ -289,6 +315,9 @@ static_assert(device_radix_sort_policy()(current_tuning_cc()) == {6}, "Host gene std::string histogram_kernel_name = radix_sort::get_histogram_kernel_name(chained_policy_t, sort_order, key_cpp, offset_t); std::string exclusive_sum_kernel_name = radix_sort::get_exclusive_sum_kernel_name(chained_policy_t, offset_t); + std::string init_bins_and_counters_kernel_name = + radix_sort::get_init_kernel_name(chained_policy_t, true, false, "int", offset_t); + std::string init_lookback_kernel_name = radix_sort::get_init_kernel_name(chained_policy_t, false, true, "int", "int"); std::string onesweep_kernel_name = radix_sort::get_onesweep_kernel_name(chained_policy_t, sort_order, key_cpp, value_cpp, offset_t); std::string single_tile_kernel_lowered_name; @@ -299,6 +328,8 @@ static_assert(device_radix_sort_policy()(current_tuning_cc()) == {6}, "Host gene std::string alt_downsweep_kernel_lowered_name; std::string histogram_kernel_lowered_name; std::string exclusive_sum_kernel_lowered_name; + std::string init_bins_and_counters_kernel_lowered_name; + std::string init_lookback_kernel_lowered_name; std::string onesweep_kernel_lowered_name; const std::string arch = std::format("-arch=sm_{0}{1}", cc_major, cc_minor); @@ -336,6 +367,8 @@ static_assert(device_radix_sort_policy()(current_tuning_cc()) == {6}, "Host gene ->add_expression({alt_downsweep_kernel_name}) ->add_expression({histogram_kernel_name}) ->add_expression({exclusive_sum_kernel_name}) + ->add_expression({init_bins_and_counters_kernel_name}) + ->add_expression({init_lookback_kernel_name}) ->add_expression({onesweep_kernel_name}) ->compile_program({args.data(), args.size()}) ->get_name({single_tile_kernel_name, single_tile_kernel_lowered_name}) @@ -346,6 +379,8 @@ static_assert(device_radix_sort_policy()(current_tuning_cc()) == {6}, "Host gene ->get_name({alt_downsweep_kernel_name, alt_downsweep_kernel_lowered_name}) ->get_name({histogram_kernel_name, histogram_kernel_lowered_name}) ->get_name({exclusive_sum_kernel_name, exclusive_sum_kernel_lowered_name}) + ->get_name({init_bins_and_counters_kernel_name, init_bins_and_counters_kernel_lowered_name}) + ->get_name({init_lookback_kernel_name, init_lookback_kernel_lowered_name}) ->get_name({onesweep_kernel_name, onesweep_kernel_lowered_name}) ->link_program() ->add_link_list(linkable_list) @@ -364,6 +399,10 @@ static_assert(device_radix_sort_policy()(current_tuning_cc()) == {6}, "Host gene check(cuLibraryGetKernel(&build_ptr->histogram_kernel, build_ptr->library, histogram_kernel_lowered_name.c_str())); check(cuLibraryGetKernel( &build_ptr->exclusive_sum_kernel, build_ptr->library, exclusive_sum_kernel_lowered_name.c_str())); + check(cuLibraryGetKernel( + &build_ptr->init_bins_and_counters_kernel, build_ptr->library, init_bins_and_counters_kernel_lowered_name.c_str())); + check(cuLibraryGetKernel( + &build_ptr->init_lookback_kernel, build_ptr->library, init_lookback_kernel_lowered_name.c_str())); check(cuLibraryGetKernel(&build_ptr->onesweep_kernel, build_ptr->library, onesweep_kernel_lowered_name.c_str())); build_ptr->cc = cc_major * 10 + cc_minor; diff --git a/cub/cub/agent/agent_radix_sort_histogram.cuh b/cub/cub/agent/agent_radix_sort_histogram.cuh index 7e54e31c036..08a10850105 100644 --- a/cub/cub/agent/agent_radix_sort_histogram.cuh +++ b/cub/cub/agent/agent_radix_sort_histogram.cuh @@ -257,6 +257,7 @@ struct AgentRadixSortHistogram __syncthreads(); // Accumulate the result in global memory. + _CCCL_PDL_GRID_DEPENDENCY_SYNC(); AccumulateGlobalHistograms(); __syncthreads(); } diff --git a/cub/cub/agent/agent_radix_sort_onesweep.cuh b/cub/cub/agent/agent_radix_sort_onesweep.cuh index 334026c24b3..009e55ecebd 100644 --- a/cub/cub/agent/agent_radix_sort_onesweep.cuh +++ b/cub/cub/agent/agent_radix_sort_onesweep.cuh @@ -241,6 +241,8 @@ struct AgentRadixSortOnesweep { bins[u] = other_bins[u]; } + + _CCCL_PDL_GRID_DEPENDENCY_SYNC(); agent.LookbackPartial(bins); agent.TryShortCircuit(keys, bins); diff --git a/cub/cub/device/dispatch/dispatch_radix_sort.cuh b/cub/cub/device/dispatch/dispatch_radix_sort.cuh index c6eaf3ca33c..ba560c901ac 100644 --- a/cub/cub/device/dispatch/dispatch_radix_sort.cuh +++ b/cub/cub/device/dispatch/dispatch_radix_sort.cuh @@ -79,6 +79,12 @@ struct DeviceRadixSortKernelSource CUB_DEFINE_KERNEL_GETTER(RadixSortExclusiveSumKernel, DeviceRadixSortExclusiveSumKernel); + CUB_DEFINE_KERNEL_GETTER(RadixSortInitBinsAndCountersKernel, + DeviceRadixSortInitKernel); + + CUB_DEFINE_KERNEL_GETTER(RadixSortInitLookbackKernel, + DeviceRadixSortInitKernel); + CUB_DEFINE_KERNEL_GETTER( RadixSortOnesweepKernel, DeviceRadixSortOnesweepKernel); @@ -591,20 +597,13 @@ private: ValueT* d_values_tmp2 = (ValueT*) allocations[3]; AtomicOffsetT* d_ctrs = (AtomicOffsetT*) allocations[4]; - // initialization - if (const auto error = - CubDebug(cudaMemsetAsync(d_ctrs, 0, num_portions * num_passes * sizeof(AtomicOffsetT), stream))) - { - return error; - } + constexpr OffsetT PDL_MAX_ITEMS = static_cast(1) << 20; + const bool use_pdl = num_items <= PDL_MAX_ITEMS; - // compute num_passes histograms with RADIX_DIGITS bins each - if (const auto error = CubDebug(cudaMemsetAsync(d_bins, 0, num_passes * RADIX_DIGITS * sizeof(OffsetT), stream))) - { - return error; - } - int device = -1; - int num_sms = 0; + const size_t num_counter_items = static_cast(num_portions) * num_passes; + const size_t num_bin_items = static_cast(num_passes) * RADIX_DIGITS; + int device = -1; + int num_sms = 0; if (const auto error = CubDebug(cudaGetDevice(&device))) { @@ -638,18 +637,6 @@ private: policy.histogram.radix_bits); #endif - if (const auto error = CubDebug( - launcher_factory(histo_blocks_per_sm * num_sms, HISTO_BLOCK_THREADS, 0, stream) - .doit(histogram_kernel, d_bins, d_keys.Current(), num_items, begin_bit, end_bit, decomposer))) - { - return error; - } - - if (const auto error = CubDebug(detail::DebugSyncStream(stream))) - { - return error; - } - // exclusive sums to determine starts const int SCAN_BLOCK_THREADS = policy.exclusive_sum.threads_per_block; @@ -662,7 +649,60 @@ private: policy.exclusive_sum.radix_bits); #endif - if (const auto error = CubDebug(launcher_factory(num_passes, SCAN_BLOCK_THREADS, 0, stream) + // Initialization is intentionally adjacent to the histogram launch. For the PDL path, this avoids consuming the + // short init kernel's runtime in host-side launch setup work before the dependent histogram is submitted. + if (use_pdl) + { + constexpr int INIT_STARTUP_THREADS = 256; + const size_t num_init_items = ::cuda::std::max(num_counter_items, num_bin_items); + const int init_startup_blocks = + static_cast(::cuda::ceil_div(num_init_items, static_cast(INIT_STARTUP_THREADS))); + +#ifdef CUB_DEBUG_LOG + _CubLog("Invoking init_bins_and_counters_kernel<<<%d, %d, 0, %lld>>>()\n", + init_startup_blocks, + INIT_STARTUP_THREADS, + reinterpret_cast(stream)); +#endif + + if (const auto error = CubDebug( + launcher_factory(init_startup_blocks, INIT_STARTUP_THREADS, 0, stream, use_pdl) + .doit( + kernel_source.RadixSortInitBinsAndCountersKernel(), d_ctrs, num_counter_items, d_bins, num_bin_items))) + { + return error; + } + } + else + { + if (const auto error = CubDebug(cudaMemsetAsync(d_ctrs, 0, num_counter_items * sizeof(AtomicOffsetT), stream))) + { + return error; + } + + // compute num_passes histograms with RADIX_DIGITS bins each + if (const auto error = CubDebug(cudaMemsetAsync(d_bins, 0, num_bin_items * sizeof(OffsetT), stream))) + { + return error; + } + } + + if (const auto error = CubDebug( + launcher_factory(histo_blocks_per_sm * num_sms, HISTO_BLOCK_THREADS, 0, stream, use_pdl) + .doit(histogram_kernel, d_bins, d_keys.Current(), num_items, begin_bit, end_bit, decomposer))) + { + return error; + } + + if (!use_pdl) + { + if (const auto error = CubDebug(detail::DebugSyncStream(stream))) + { + return error; + } + } + + if (const auto error = CubDebug(launcher_factory(num_passes, SCAN_BLOCK_THREADS, 0, stream, use_pdl) .doit(kernel_source.RadixSortExclusiveSumKernel(), d_bins))) { return error; @@ -672,6 +712,7 @@ private: { return error; } + // use the other buffer if no overwrite is allowed KeyT* d_keys_tmp = d_keys.Alternate(); ValueT* d_values_tmp = d_values.Alternate(); @@ -689,12 +730,30 @@ private: PortionOffsetT portion_num_items = static_cast( ::cuda::std::min(num_items - portion * PORTION_SIZE, static_cast(PORTION_SIZE))); - PortionOffsetT num_blocks = ::cuda::ceil_div(portion_num_items, ONESWEEP_TILE_ITEMS); + PortionOffsetT num_blocks = ::cuda::ceil_div(portion_num_items, ONESWEEP_TILE_ITEMS); + const size_t num_lookback_items = static_cast(num_blocks) * RADIX_DIGITS; - if (const auto error = - CubDebug(cudaMemsetAsync(d_lookback, 0, num_blocks * RADIX_DIGITS * sizeof(AtomicOffsetT), stream))) + if (use_pdl) { - return error; + constexpr int INIT_LOOKBACK_THREADS = 256; + const int init_lookback_blocks = + static_cast(::cuda::ceil_div(num_lookback_items, static_cast(INIT_LOOKBACK_THREADS))); + + if (const auto error = CubDebug( + launcher_factory(init_lookback_blocks, INIT_LOOKBACK_THREADS, 0, stream, use_pdl) + .doit( + kernel_source.RadixSortInitLookbackKernel(), d_lookback, num_lookback_items, d_lookback, size_t{0}))) + { + return error; + } + } + else + { + if (const auto error = + CubDebug(cudaMemsetAsync(d_lookback, 0, num_lookback_items * sizeof(AtomicOffsetT), stream))) + { + return error; + } } // log onesweep_kernel configuration @@ -714,7 +773,7 @@ private: auto onesweep_kernel = kernel_source.RadixSortOnesweepKernel(); if (const auto error = CubDebug( - launcher_factory(num_blocks, ONESWEEP_BLOCK_THREADS, 0, stream) + launcher_factory(num_blocks, ONESWEEP_BLOCK_THREADS, 0, stream, use_pdl) .doit( onesweep_kernel, d_lookback, diff --git a/cub/cub/device/dispatch/kernels/kernel_radix_sort.cuh b/cub/cub/device/dispatch/kernels/kernel_radix_sort.cuh index 5ae4231b9ea..85d9d110461 100644 --- a/cub/cub/device/dispatch/kernels/kernel_radix_sort.cuh +++ b/cub/cub/device/dispatch/kernels/kernel_radix_sort.cuh @@ -477,6 +477,46 @@ __launch_bounds__(current_policy().histogram.threads_per_block) agent.Process(); } +template +_CCCL_KERNEL_ATTRIBUTES void DeviceRadixSortInitKernel( + _CCCL_GRID_CONSTANT InitT0* const d_items0, + _CCCL_GRID_CONSTANT const size_t num_items0, + _CCCL_GRID_CONSTANT InitT1* const d_items1, + _CCCL_GRID_CONSTANT const size_t num_items1) +{ + if constexpr (TRIGGER_AT_START) + { + _CCCL_PDL_GRID_DEPENDENCY_SYNC(); + _CCCL_PDL_TRIGGER_NEXT_LAUNCH(); + } + + const size_t stride = static_cast(blockDim.x) * gridDim.x; + for (size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + idx < ::cuda::std::max(num_items0, num_items1); + idx += stride) + { + if (idx < num_items0) + { + d_items0[idx] = 0; + } + if (idx < num_items1) + { + d_items1[idx] = 0; + } + } + + if constexpr (FENCE_BEFORE_END_TRIGGER) + { + __threadfence(); + } + + if constexpr (!TRIGGER_AT_START) + { + _CCCL_PDL_GRID_DEPENDENCY_SYNC(); + _CCCL_PDL_TRIGGER_NEXT_LAUNCH(); + } +} + template ; __shared__ typename BlockScan::TempStorage temp_storage; + _CCCL_PDL_GRID_DEPENDENCY_SYNC(); + // load the bins OffsetT bins[BINS_PER_THREAD]; int bin_start = blockIdx.x * RADIX_DIGITS;