Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions cudax/include/cuda/experimental/__group/group.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class group

// Mapping-result type exposed by the parent group through its nested __mapping_result_type.
using _ParentMappingResult = typename _ParentGroup::__mapping_result_type;
// Mapping-result type of this group: whatever _Mapping::map returns when called on a const mapping
// with (unit, parent group, initial mapping result), where the initial mapping result is the value
// returned by __get_initial_mapping_result for the parent group. Only used in an unevaluated
// context (decltype + declval), so no objects are created here.
using _MappingResult = decltype(::cuda::std::declval<const _Mapping&>().map(
::cuda::std::declval<const _Unit&>(),
::cuda::std::declval<const _ParentGroup&>(),
__get_initial_mapping_result(::cuda::std::declval<const _ParentGroup&>())));
using _SynchronizerInstance =
Expand All @@ -91,9 +92,9 @@ class group
_SynchronizerInstance __synchronizer_instance_;

[[nodiscard]] _CCCL_DEVICE_API static _MappingResult
__do_mapping(const _Mapping& __mapping, const _ParentGroup& __parent) noexcept
__do_mapping(const _Unit& __unit, const _Mapping& __mapping, const _ParentGroup& __parent) noexcept
{
const auto __mapping_result = __mapping.map(__parent, __get_initial_mapping_result(__parent));
const auto __mapping_result = __mapping.map(__unit, __parent, __get_initial_mapping_result(__parent));
if (__mapping_result.is_valid())
{
_CCCL_ASSERT(__mapping_result.group_rank() < __mapping_result.group_count(), "invalid group rank");
Expand All @@ -108,6 +109,7 @@ class group
}

[[nodiscard]] _CCCL_DEVICE_API static _SynchronizerInstance __make_synchronizer_instance(
const _Unit& __unit,
const _Synchronizer& __synchronizer,
const _ParentGroup& __parent,
const _Mapping& __mapping,
Expand All @@ -123,7 +125,7 @@ class group
return _SynchronizerInstance::invalid();
}
}
return __synchronizer.make_instance(_Unit{}, __parent, __mapping, __mapping_result);
return __synchronizer.make_instance(__unit, __parent, __mapping, __mapping_result);
}

public:
Expand All @@ -141,9 +143,10 @@ public:
const _Synchronizer& __synchronizer) noexcept
: __hier_{__parent.hierarchy()}
, __mapping_{__mapping}
, __mapping_result_{__do_mapping(__mapping_, __parent)}
, __mapping_result_{__do_mapping(__unit, __mapping_, __parent)}
, __synchronizer_{__synchronizer}
, __synchronizer_instance_{__make_synchronizer_instance(__synchronizer_, __parent, __mapping_, __mapping_result_)}
, __synchronizer_instance_{
__make_synchronizer_instance(__unit, __synchronizer_, __parent, __mapping_, __mapping_result_)}
{}

[[nodiscard]] _CCCL_DEVICE_API const hierarchy_type& hierarchy() const noexcept
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,12 @@ public:
: __fn_(::cuda::std::move(__fn))
{}

template <class _ParentGroup, class _PrevMappingResult>
template <class _Unit, class _ParentGroup, class _PrevMappingResult>
[[nodiscard]] _CCCL_DEVICE_API auto
map(const _ParentGroup& __parent, const _PrevMappingResult& __prev_mapping_result) noexcept(
map(const _Unit&, const _ParentGroup& __parent, const _PrevMappingResult& __prev_mapping_result) noexcept(
::cuda::std::is_nothrow_invocable_v<_Fn, const _PrevMappingResult&>)
{
static_assert(::cuda::std::is_same_v<_Unit, thread_level>, "binary_partition can only group threads");
static_assert(::cuda::std::is_same_v<typename _ParentGroup::level_type, warp_level>,
"binary_partition can be only used within warp_level");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ class composite_mapping
{
::cuda::std::tuple<_Mappings...> __mappings_;

template <::cuda::std::size_t _Ip = 0, class _ParentGroup, class _PrevMappingResult>
[[nodiscard]] _CCCL_DEVICE_API auto
__map_impl(const _ParentGroup& __parent, const _PrevMappingResult& __prev_mapping_result) const noexcept
template <::cuda::std::size_t _Ip = 0, class _Unit, class _ParentGroup, class _PrevMappingResult>
[[nodiscard]] _CCCL_DEVICE_API auto __map_impl(
const _Unit& __unit, const _ParentGroup& __parent, const _PrevMappingResult& __prev_mapping_result) const noexcept
{
const auto __result = ::cuda::std::get<_Ip>(__mappings_).map(__parent, __prev_mapping_result);
const auto __result = ::cuda::std::get<_Ip>(__mappings_).map(__unit, __parent, __prev_mapping_result);
if constexpr (_Ip + 1 < sizeof...(_Mappings))
{
return __map_impl<_Ip + 1>(__parent, __result);
return __map_impl<_Ip + 1>(__unit, __parent, __result);
}
else
{
Expand All @@ -68,11 +68,11 @@ public:
return __mappings_;
}

// Public entry point of the composite mapping: forwards its arguments to __map_impl without
// specifying the index template argument, so __map_impl's default index is used, and returns
// whatever __map_impl returns.
// NOTE(review): two variants of the template header, signature and return statement appear back to
// back below (one without and one with the leading `_Unit` parameter) - presumably the pre-change
// and post-change sides of a diff rendered without +/- markers; confirm which one is live.
template <class _ParentGroup, class _PrevMappingResult>
template <class _Unit, class _ParentGroup, class _PrevMappingResult>
[[nodiscard]] _CCCL_DEVICE_API auto
map(const _ParentGroup& __parent, const _PrevMappingResult& __prev_mapping_result) const noexcept
map(const _Unit& __unit, const _ParentGroup& __parent, const _PrevMappingResult& __prev_mapping_result) const noexcept
{
return __map_impl(__parent, __prev_mapping_result);
return __map_impl(__unit, __parent, __prev_mapping_result);
}
};

Expand Down
20 changes: 12 additions & 8 deletions cudax/include/cuda/experimental/__group/mapping/group_as.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ public:
return static_cast<unsigned>(static_count(__i));
}

template <class _ParentGroup, class _PrevMappingResult>
template <class _Unit, class _ParentGroup, class _PrevMappingResult>
[[nodiscard]] _CCCL_DEVICE_API auto
map(const _ParentGroup&, const _PrevMappingResult& __prev_mapping_result) const noexcept
map(const _Unit&, const _ParentGroup&, const _PrevMappingResult& __prev_mapping_result) const noexcept
{
constexpr auto __static_prev_ngroups = _PrevMappingResult::static_group_count();
constexpr auto __static_prev_nunits = _PrevMappingResult::static_count();
Expand Down Expand Up @@ -163,8 +163,10 @@ public:
const auto __n = __i_count;
const auto __rank = __prev_unit_rank - __sum;
const auto __lane_mask =
::cuda::experimental::__make_lane_mask_for_n<_PrevMappingResult::is_always_contiguous()>(
__prev_mapping_result.lane_mask(), __n, __rank);
(::cuda::std::is_same_v<_Unit, thread_level>)
? ::cuda::experimental::__make_lane_mask_for_n<_PrevMappingResult::is_always_contiguous()>(
__prev_mapping_result.lane_mask(), __n, __rank)
: __prev_mapping_result.lane_mask();
return _MappingResult{__ngroups, __group_rank, __n, __rank, __lane_mask};
}
__sum += __i_count;
Expand Down Expand Up @@ -237,9 +239,9 @@ public:
return __counts_[__i];
}

template <class _ParentGroup, class _PrevMappingResult>
template <class _Unit, class _ParentGroup, class _PrevMappingResult>
[[nodiscard]] _CCCL_DEVICE_API auto
map(const _ParentGroup&, const _PrevMappingResult& __prev_mapping_result) const noexcept
map(const _Unit&, const _ParentGroup&, const _PrevMappingResult& __prev_mapping_result) const noexcept
{
constexpr auto __static_prev_ngroups = _PrevMappingResult::static_group_count();
constexpr auto __static_prev_nunits = _PrevMappingResult::static_count();
Expand Down Expand Up @@ -287,8 +289,10 @@ public:
const auto __n = __i_count;
const auto __rank = __prev_unit_rank - __sum;
const auto __lane_mask =
::cuda::experimental::__make_lane_mask_for_n<_PrevMappingResult::is_always_contiguous()>(
__prev_mapping_result.lane_mask(), __n, __rank);
(::cuda::std::is_same_v<_Unit, thread_level>)
? ::cuda::experimental::__make_lane_mask_for_n<_PrevMappingResult::is_always_contiguous()>(
__prev_mapping_result.lane_mask(), __n, __rank)
: __prev_mapping_result.lane_mask();
return _MappingResult{__ngroups, __group_rank, __n, __rank, __lane_mask};
}
__sum += __i_count;
Expand Down
22 changes: 14 additions & 8 deletions cudax/include/cuda/experimental/__group/mapping/group_by.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ public:
return static_cast<unsigned>(_Count);
}

template <class _ParentGroup, class _PrevMappingResult>
template <class _Unit, class _ParentGroup, class _PrevMappingResult>
[[nodiscard]] _CCCL_DEVICE_API auto
map(const _ParentGroup& __parent, const _PrevMappingResult& __prev_mapping_result) const noexcept
map(const _Unit&, const _ParentGroup& __parent, const _PrevMappingResult& __prev_mapping_result) const noexcept
{
constexpr auto __static_prev_ngroups = _PrevMappingResult::static_group_count();
constexpr auto __static_prev_nunits = _PrevMappingResult::static_count();
Expand Down Expand Up @@ -134,8 +134,11 @@ public:
const auto __group_rank = __prev_mapping_result.group_rank() * __curr_ngroups + __curr_group_rank;
const auto __n = count();
const auto __rank = __prev_unit_rank % __n;
const auto __lane_mask = ::cuda::experimental::__make_lane_mask_for_n<_PrevMappingResult::is_always_contiguous()>(
__prev_mapping_result.lane_mask(), __n, __rank);
const auto __lane_mask =
(::cuda::std::is_same_v<_Unit, thread_level>)
? ::cuda::experimental::__make_lane_mask_for_n<_PrevMappingResult::is_always_contiguous()>(
__prev_mapping_result.lane_mask(), __n, __rank)
: __prev_mapping_result.lane_mask();
return _MappingResult{__ngroups, __group_rank, __n, __rank, __lane_mask};
}
};
Expand Down Expand Up @@ -175,9 +178,9 @@ public:
return __count_;
}

template <class _ParentGroup, class _PrevMappingResult>
template <class _Unit, class _ParentGroup, class _PrevMappingResult>
[[nodiscard]] _CCCL_DEVICE_API auto
map(const _ParentGroup& __parent, const _PrevMappingResult& __prev_mapping_result) const noexcept
map(const _Unit&, const _ParentGroup& __parent, const _PrevMappingResult& __prev_mapping_result) const noexcept
{
using _MappingResult =
__mapping_result<::cuda::std::dynamic_extent,
Expand Down Expand Up @@ -212,8 +215,11 @@ public:
const auto __group_rank = __prev_mapping_result.group_rank() * __curr_ngroups + __curr_group_rank;
const auto __n = __count_;
const auto __rank = __prev_unit_rank % __count_;
const auto __lane_mask = ::cuda::experimental::__make_lane_mask_for_n<_PrevMappingResult::is_always_contiguous()>(
__prev_mapping_result.lane_mask(), __n, __rank);
const auto __lane_mask =
(::cuda::std::is_same_v<_Unit, thread_level>)
? ::cuda::experimental::__make_lane_mask_for_n<_PrevMappingResult::is_always_contiguous()>(
__prev_mapping_result.lane_mask(), __n, __rank)
: __prev_mapping_result.lane_mask();
return _MappingResult{__ngroups, __group_rank, __n, __rank, __lane_mask};
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ class identity_mapping
public:
// Defaulted default constructor; declared explicit.
_CCCL_HIDE_FROM_ABI explicit identity_mapping() = default;

// Identity mapping: returns the previous mapping result unchanged. The other parameters are unnamed
// and never read, so the result depends only on __prev_mapping_result.
// NOTE(review): two variants of the template header and signature appear back to back below (one
// without and one with the leading `const _Unit&` parameter) - presumably the pre-change and
// post-change sides of a diff rendered without +/- markers; confirm which one is live.
template <class _ParentGroup, class _PrevMappingResult>
template <class _Unit, class _ParentGroup, class _PrevMappingResult>
[[nodiscard]] _CCCL_DEVICE_API auto
map(const _ParentGroup&, const _PrevMappingResult& __prev_mapping_result) const noexcept
map(const _Unit&, const _ParentGroup&, const _PrevMappingResult& __prev_mapping_result) const noexcept
{
return __prev_mapping_result;
}
Expand Down
6 changes: 3 additions & 3 deletions cudax/include/cuda/experimental/__group/traits.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@
namespace cuda::experimental
{
// Alias for the type returned by _Mapping::map when called with a unit and a const reference to the
// parent group. Evaluated only inside decltype, so nothing is constructed at run time.
// NOTE(review): two definitions of the alias body appear back to back below - one passing a
// value-initialized `_Unit{}` and one passing `declval<_Unit>()` - presumably the pre-change and
// post-change sides of a diff rendered without +/- markers; confirm which one is live.
// NOTE(review): both variants call `map` with two arguments, while the `map` members visible in the
// mapping headers of this change take three (unit, parent, previous mapping result), and the unit is
// spelled `declval<_Unit>()` here but `declval<const _Unit&>()` in __group_synchronizer_instance_t.
// Verify that this alias still resolves for the mappings it is used with.
template <class _Mapping, class _Unit, class _ParentGroup>
using __group_mapping_result_t =
decltype(::cuda::std::declval<_Mapping>().map(_Unit{}, ::cuda::std::declval<const _ParentGroup&>()));
using __group_mapping_result_t = decltype(::cuda::std::declval<_Mapping>().map(
::cuda::std::declval<_Unit>(), ::cuda::std::declval<const _ParentGroup&>()));

// Alias for the type returned by _Synchronizer::make_instance when called with, in order: a unit,
// the parent group, the mapping and the mapping result (the last three as const references).
// Evaluated only inside decltype, so nothing is constructed at run time.
// NOTE(review): two spellings of the first argument appear back to back below (`_Unit{},` and
// `declval<const _Unit&>(),`) - presumably the pre-change and post-change sides of a diff rendered
// without +/- markers; confirm which one is live.
template <class _Synchronizer, class _Unit, class _ParentGroup, class _Mapping, class _MappingResult>
using __group_synchronizer_instance_t = decltype(::cuda::std::declval<_Synchronizer>().make_instance(
_Unit{},
::cuda::std::declval<const _Unit&>(),
::cuda::std::declval<const _ParentGroup&>(),
::cuda::std::declval<const _Mapping&>(),
::cuda::std::declval<const _MappingResult&>()));
Expand Down
24 changes: 12 additions & 12 deletions cudax/test/group/mapping/binary_partition.cu
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@ __device__ void test_binary_partition(Config config)
const cudax::this_warp parent_group{config};
const ThreadsInWarpMappingResult prev_mapping_result;

static_assert(
cudax::__group_mapping_result<decltype(cuda::std::declval<Mapping>().map(parent_group, prev_mapping_result))>);
static_assert(!noexcept(cuda::std::declval<Mapping>().map(parent_group, prev_mapping_result)));
static_assert(cudax::__group_mapping_result<decltype(cuda::std::declval<Mapping>().map(
cuda::gpu_thread, parent_group, prev_mapping_result))>);
static_assert(!noexcept(cuda::std::declval<Mapping>().map(cuda::gpu_thread, parent_group, prev_mapping_result)));

Mapping mapping{Pred{}};
auto result = mapping.map(parent_group, prev_mapping_result);
auto result = mapping.map(cuda::gpu_thread, parent_group, prev_mapping_result);
using Result = decltype(result);

static_assert(Result::static_group_count() == 2);
Expand Down Expand Up @@ -108,12 +108,12 @@ __device__ void test_binary_partition(Config config)
const cudax::this_warp parent_group{config};
const ThreadsInWarpMappingResult prev_mapping_result;

static_assert(
cudax::__group_mapping_result<decltype(cuda::std::declval<Mapping>().map(parent_group, prev_mapping_result))>);
static_assert(noexcept(cuda::std::declval<Mapping>().map(parent_group, prev_mapping_result)));
static_assert(cudax::__group_mapping_result<decltype(cuda::std::declval<Mapping>().map(
cuda::gpu_thread, parent_group, prev_mapping_result))>);
static_assert(noexcept(cuda::std::declval<Mapping>().map(cuda::gpu_thread, parent_group, prev_mapping_result)));

Mapping mapping{Pred{}};
auto result = mapping.map(parent_group, prev_mapping_result);
auto result = mapping.map(cuda::gpu_thread, parent_group, prev_mapping_result);
using Result = decltype(result);

static_assert(Result::static_group_count() == 2);
Expand Down Expand Up @@ -148,12 +148,12 @@ __device__ void test_binary_partition(Config config)
const cudax::this_warp parent_group{config};
const ThreadsInWarpMappingResult prev_mapping_result;

static_assert(
cudax::__group_mapping_result<decltype(cuda::std::declval<Mapping>().map(parent_group, prev_mapping_result))>);
static_assert(!noexcept(cuda::std::declval<Mapping>().map(parent_group, prev_mapping_result)));
static_assert(cudax::__group_mapping_result<decltype(cuda::std::declval<Mapping>().map(
cuda::gpu_thread, parent_group, prev_mapping_result))>);
static_assert(!noexcept(cuda::std::declval<Mapping>().map(cuda::gpu_thread, parent_group, prev_mapping_result)));

Mapping mapping{Pred{}};
auto result = mapping.map(parent_group, prev_mapping_result);
auto result = mapping.map(cuda::gpu_thread, parent_group, prev_mapping_result);
using Result = decltype(result);

static_assert(Result::static_group_count() == 2);
Expand Down
5 changes: 3 additions & 2 deletions cudax/test/group/mapping/composite_mapping.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ __device__ void test_composite_mapping(const Mapping1& mapping1, const Mapping2&
const ThreadsInWarpMappingResult prev_mapping_result;
const cudax::composite_mapping mapping{mapping1, mapping2};

static_assert(cudax::__group_mapping_result<decltype(mapping.map(parent_group, prev_mapping_result))>);
static_assert(
cudax::__group_mapping_result<decltype(mapping.map(cuda::gpu_thread, parent_group, prev_mapping_result))>);

auto result = mapping.map(parent_group, prev_mapping_result);
auto result = mapping.map(cuda::gpu_thread, parent_group, prev_mapping_result);
using Result = decltype(result);

const auto rank_in_warp = cuda::gpu_thread.rank_as<unsigned>(parent_group);
Expand Down
28 changes: 16 additions & 12 deletions cudax/test/group/mapping/group_as.cu
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,12 @@ __device__ void test_group_as(Config config)
const ThreadsInWarpMappingResult prev_mapping_result;

static_assert(cudax::__group_mapping_result<decltype(cuda::std::declval<const Mapping>().map(
parent_group, prev_mapping_result))>);
static_assert(noexcept(cuda::std::declval<const Mapping>().map(parent_group, prev_mapping_result)));
cuda::gpu_thread, parent_group, prev_mapping_result))>);
static_assert(
noexcept(cuda::std::declval<const Mapping>().map(cuda::gpu_thread, parent_group, prev_mapping_result)));

const Mapping mapping;
auto result = mapping.map(parent_group, prev_mapping_result);
auto result = mapping.map(cuda::gpu_thread, parent_group, prev_mapping_result);
using Result = decltype(result);

const auto rank_in_warp = cuda::gpu_thread.rank(parent_group);
Expand Down Expand Up @@ -206,11 +207,12 @@ __device__ void test_group_as(Config config)
const ThreadsInWarpMappingResult prev_mapping_result;

static_assert(cudax::__group_mapping_result<decltype(cuda::std::declval<const Mapping>().map(
parent_group, prev_mapping_result))>);
static_assert(noexcept(cuda::std::declval<const Mapping>().map(parent_group, prev_mapping_result)));
cuda::gpu_thread, parent_group, prev_mapping_result))>);
static_assert(
noexcept(cuda::std::declval<const Mapping>().map(cuda::gpu_thread, parent_group, prev_mapping_result)));

const Mapping mapping{ns};
auto result = mapping.map(parent_group, prev_mapping_result);
auto result = mapping.map(cuda::gpu_thread, parent_group, prev_mapping_result);
using Result = decltype(result);

const auto rank_in_warp = cuda::gpu_thread.rank_as<unsigned>(parent_group);
Expand Down Expand Up @@ -329,11 +331,12 @@ __device__ void test_group_as_non_exhaustive(Config config)
const ThreadsInWarpMappingResult prev_mapping_result;

static_assert(cudax::__group_mapping_result<decltype(cuda::std::declval<const Mapping>().map(
parent_group, prev_mapping_result))>);
static_assert(noexcept(cuda::std::declval<const Mapping>().map(parent_group, prev_mapping_result)));
cuda::gpu_thread, parent_group, prev_mapping_result))>);
static_assert(
noexcept(cuda::std::declval<const Mapping>().map(cuda::gpu_thread, parent_group, prev_mapping_result)));

const Mapping mapping;
auto result = mapping.map(parent_group, prev_mapping_result);
auto result = mapping.map(cuda::gpu_thread, parent_group, prev_mapping_result);
using Result = decltype(result);

const auto rank_in_warp = cuda::gpu_thread.rank(parent_group);
Expand Down Expand Up @@ -435,11 +438,12 @@ __device__ void test_group_as_non_exhaustive(Config config)
const ThreadsInWarpMappingResult prev_mapping_result;

static_assert(cudax::__group_mapping_result<decltype(cuda::std::declval<const Mapping>().map(
parent_group, prev_mapping_result))>);
static_assert(noexcept(cuda::std::declval<const Mapping>().map(parent_group, prev_mapping_result)));
cuda::gpu_thread, parent_group, prev_mapping_result))>);
static_assert(
noexcept(cuda::std::declval<const Mapping>().map(cuda::gpu_thread, parent_group, prev_mapping_result)));

const Mapping mapping{ns, cudax::non_exhaustive};
auto result = mapping.map(parent_group, prev_mapping_result);
auto result = mapping.map(cuda::gpu_thread, parent_group, prev_mapping_result);
using Result = decltype(result);

const auto rank_in_warp = cuda::gpu_thread.rank(parent_group);
Expand Down
Loading
Loading