
Support a custom comparison operator in DeviceReduce::ArgMin|Max #8285

Merged
bernhardmgruber merged 12 commits into NVIDIA:main from bernhardmgruber:ref_argmin
Apr 8, 2026

Conversation

@bernhardmgruber
Contributor

Fixes: #6123

@bernhardmgruber bernhardmgruber requested review from a team as code owners April 2, 2026 23:09
@github-project-automation github-project-automation bot moved this to Todo in CCCL Apr 2, 2026
@cccl-authenticator-app cccl-authenticator-app bot moved this from Todo to In Review in CCCL Apr 2, 2026
@bernhardmgruber bernhardmgruber changed the title from "Support a custom comparison predicate in DeviceReduce::ArgMin" to "Support a custom comparison operator in DeviceReduce::ArgMin" Apr 2, 2026
Comment on lines +1041 to +1043
// TODO(bgruber): this constraint is not accurate, since the implementation will compare the value types of
// ExtremumOutIteratorT, which is wrong IMO
::cuda::std::enable_if_t<::cuda::std::indirectly_comparable<InputIteratorT, InputIteratorT, CompareOpT>, int> = 0>
Contributor Author

Instead of InputIteratorT we should use non_void_value_t<ExtremumOutIteratorT, it_value_t<InputIteratorT>>, but that just "feels" wrong here. But this is what the implementation does. What do the reviewers think?

I think the implementation should actually be changed to compare the input values, not the converted ones.

Contributor

I do not follow why that constraint is wrong. We want to ensure that the input sequence is comparable with the passed operator. Why should we compare the value type of ExtremumOutIteratorT?

Contributor Author

The reduction implementation does not call compare_op(d_in[i], d_in[j]), it calls something like:

  using input_value_t = it_value_t<InputIteratorT>;
  using accum_t       = non_void_value_t<ExtremumOutIteratorT, input_value_t>;
  accum_t a = d_in[i];
  accum_t b = d_in[j];
  compare_op(a, b);

So it performs a conversion of the input value to the output iterator's value_type before comparing. That can be a totally different type.

I think this is a bug itself, but outside the scope of this PR.
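
To make the concern concrete, here is a minimal host-side sketch, not the CUB implementation. The function name `arg_min_converted` and its `AccumT` parameter are hypothetical; `AccumT` stands in for `non_void_value_t<ExtremumOutIteratorT, input_value_t>`. It suggests how converting both operands before the comparison can change which index is selected.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical host-side stand-in for the reduction, NOT the CUB implementation.
// AccumT plays the role of non_void_value_t<ExtremumOutIteratorT, input_value_t>.
// Assumes a non-empty input.
template <typename AccumT, typename InputT, typename CompareOp>
std::size_t arg_min_converted(const std::vector<InputT>& in, CompareOp compare_op)
{
  std::size_t best = 0;
  for (std::size_t i = 1; i < in.size(); ++i)
  {
    // Both operands are converted to AccumT before the comparison runs.
    const AccumT a = static_cast<AccumT>(in[i]);
    const AccumT b = static_cast<AccumT>(in[best]);
    if (compare_op(a, b))
    {
      best = i;
    }
  }
  return best;
}
```

With the input `{1.7, 1.2, 3.0}` and `std::less<>`, comparing as `double` selects index 1, while comparing as `int` truncates 1.7 and 1.2 to the same value and leaves index 0 selected.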


Contributor

@NaderAlAwar NaderAlAwar left a comment

Suggestion: the issue being closed mentions ArgMax as well in the title, but this PR only appears to add public custom-comparator overloads and test coverage for ArgMin. The internal refactor is more general, but DeviceReduce::ArgMax still seems to expose only the old no-comparator API. I would either create a separate issue for ArgMax or expose the custom comparator overload as well.

@bernhardmgruber
Contributor Author

Suggestion: the issue being closed mentions ArgMax as well in the title, but this PR only appears to add public custom-comparator overloads and test coverage for ArgMin.

So I temporarily added a new overload for ArgMax as well, but then I noticed that the implementation is actually identical to ArgMin(..., std::not_fn(compare_op), ...) and then I wondered whether a simple negation of the predicate deserves another public API overload.

Also, while ArgMin without a comparison operator defaults to std::less, what should ArgMax default to? std::max_element defaults to std::less to find the maximum, but maybe that's confusing to some users at first glance. I thought not adding ArgMax would just avoid the confusion.

Finally, I considered naming the new overload not ArgMin but something like ArgReduce, but that doesn't make any sense either, since the user does not specify the reduction, but the comparison predicate. Maybe we should just call the overload ArgExtremum to distinguish it from ArgMin. It generalizes both ArgMin and ArgMax. @NaderAlAwar let me know what you think!

I would either create a separate issue for ArgMax or expose the custom comparator overload as well.

As my last paragraph points out, the new overload actually generalizes over both, ArgMin and ArgMax, so no more work should be necessary. An ArgMax with a custom comparison is essentially calling ArgMin with that operator.
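
The relationship can be illustrated on the host with the standard algorithms. This is only a sketch: `arg_min_host` and `arg_max_host` are hypothetical helper names, not CUB APIs. It shows that `std::max_element` takes a less-than style comparator by default, and that an arg-max can be expressed as an arg-min with the opposite ordering.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical host-side helpers; the names are illustrative and not part of CUB.
template <typename T, typename Compare = std::less<>>
std::size_t arg_min_host(const std::vector<T>& v, Compare comp = Compare{})
{
  // Index of the first element that no other element precedes under comp.
  return static_cast<std::size_t>(std::min_element(v.begin(), v.end(), comp) - v.begin());
}

template <typename T, typename Compare = std::less<>>
std::size_t arg_max_host(const std::vector<T>& v, Compare comp = Compare{})
{
  // std::max_element also takes a less-than style comparator, std::less by default.
  return static_cast<std::size_t>(std::max_element(v.begin(), v.end(), comp) - v.begin());
}
```

For `{3, 9, 2, 9, 1}`, `arg_max_host(v)` and `arg_min_host(v, std::greater<>{})` both yield index 1, the first occurrence of the maximum.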


@NaderAlAwar
Contributor

@bernhardmgruber those are good points, I hadn't considered that. Looking into this some more, since the standard library and Thrust already expose comparator overloads for both min_element and max_element, I think matching that symmetry here would be less surprising to users than only exposing a comparator overload on ArgMin. Since the public CUB API already has both ArgMin and ArgMax, I would expect custom-comparator support to be available on both as well.

My worry about ArgExtremum is that its name may be less familiar to users, which could lead them to avoid it.

std::max_element defaults to std::less to find the maximum, but maybe that's confusing to some users at first glance.

I do agree that this is a little confusing. I don't feel too strongly about this either way, since I have not used this extensively, but my intuition would be to stay consistent with existing standards unless we believe they are broken in some way.


Comment on lines +232 to +247
// Initial value for empty problems, according to documented contract
const auto empty_problem_extremum = static_cast<output_extremum_t>([] {
  if constexpr (::cuda::std::is_same_v<ReductionOpT, arg_min>)
  {
    return ::cuda::std::numeric_limits<input_value_t>::max();
  }
  else if constexpr (::cuda::std::is_same_v<ReductionOpT, arg_max>)
  {
    return ::cuda::std::numeric_limits<input_value_t>::lowest();
  }
  else
  {
    return input_value_t{};
  }
}());
auto initial_value = empty_problem_init_t<per_partition_accum_t>{{PerPartitionOffsetT{1}, empty_problem_extremum}};
Contributor

I am really unhappy that we actually need an initial value

Contributor Author

It's only needed for the case where the user passes num_items == 0.

Contributor

I would love for us to change the implementation so that the legacy API without a comparison operator keeps returning the value, while the new API only returns indices, which in that case can just be 0.
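
For readers unfamiliar with the empty-problem contract discussed here, the following host-side sketch illustrates the idea for the arg-min case. The function name `arg_min_or_sentinel` is hypothetical, and the sentinel pair (index 1, largest representable value) is taken from the snippet under review; an arg-max variant would use `lowest()` instead of `max()`.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

// Hypothetical host-side sketch; the reviewed code picks the sentinel at compile time from ReductionOpT.
template <typename T>
std::pair<std::size_t, T> arg_min_or_sentinel(const std::vector<T>& in)
{
  if (in.empty())
  {
    // Result for num_items == 0: index 1 and the largest representable value.
    return {std::size_t{1}, std::numeric_limits<T>::max()};
  }
  std::size_t best = 0;
  for (std::size_t i = 1; i < in.size(); ++i)
  {
    // Strict comparison keeps the first occurrence on ties.
    if (in[i] < in[best])
    {
      best = i;
    }
  }
  return {best, in[best]};
}
```

The sentinel is only ever observed when the input is empty, which matches the remark that the initial value is needed solely for `num_items == 0`.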


cudaStream_t stream = 0)
{
return ArgMax(
d_temp_storage, temp_storage_bytes, d_in, d_max_out, d_index_out, num_items, ::cuda::std::less{}, stream);
Contributor

Ditto: use a typed less.

Comment on lines +109 to 124
// Less-than comparator for an index/value pair that compares values first, and indices when the values are equal
template <typename ValueLessThen = ::cuda::std::less<>>
struct arg_less : ValueLessThen
{
/// Boolean max operator, preferring the item having the smaller offset in
/// case of ties
template <typename T, typename OffsetT>
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE ::cuda::std::pair<OffsetT, T>
operator()(const ::cuda::std::pair<OffsetT, T>& a, const ::cuda::std::pair<OffsetT, T>& b) const
{
- if ((b.second > a.second) || ((a.second == b.second) && (b.first < a.first)))
+ const auto& less = static_cast<const ValueLessThen&>(*this);
+ if (less(b.second, a.second) || (!less(a.second, b.second) && b.first < a.first))
{
return b;
}

return a;
}
Contributor

Important: Inheritance is almost always worse than making it a member.

Contributor Author

But with a member, we are not getting EBCO. But maybe that's not so important here.

Contributor Author

@bernhardmgruber bernhardmgruber Apr 7, 2026

Hmm, I think if I move to a data member, aggregate init would no longer work with the deduction guide in C++17. This can be worked around of course. Do you insist on this change, or can I save myself 43s of typing?

Contributor

we can keep it as is
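
The trade-off in this thread can be sketched on the host as follows. This is a simplified, hypothetical variant (`arg_less_sketch`), not the code that was merged: it drops the CCCL macros and uses `std::` types. It shows the tie-breaking rule expressed only through the less-than comparator, the C++17 deduction guide that relies on aggregate initialization of the base class, and that the inherited version stays an empty type.

```cpp
#include <cassert>
#include <functional>
#include <type_traits>
#include <utility>

// Simplified, hypothetical host-only variant of the functor under review.
template <typename ValueLess = std::less<>>
struct arg_less_sketch : ValueLess
{
  template <typename OffsetT, typename T>
  std::pair<OffsetT, T> operator()(const std::pair<OffsetT, T>& a, const std::pair<OffsetT, T>& b) const
  {
    const auto& less = static_cast<const ValueLess&>(*this);
    // Pick b if its value orders strictly before a's, or on a tie if its index is smaller.
    if (less(b.second, a.second) || (!less(a.second, b.second) && b.first < a.first))
    {
      return b;
    }
    return a;
  }
};

// Deduction guide: arg_less_sketch{std::greater<>{}} aggregate-initializes the base in C++17.
template <typename ValueLess>
arg_less_sketch(ValueLess) -> arg_less_sketch<ValueLess>;

// Member-based alternative, shown only to compare emptiness.
template <typename ValueLess>
struct arg_less_member_sketch
{
  ValueLess less;
};

// Inheriting from an empty comparator keeps the functor empty; holding it as a member does not.
static_assert(std::is_empty_v<arg_less_sketch<>>, "base class adds no data members");
static_assert(!std::is_empty_v<arg_less_member_sketch<std::less<>>>, "a data member makes the type non-empty");
```

A member-based design would need a constructor or an adjusted guide to keep the `arg_less_sketch{comp}` spelling, which is the workaround alluded to above.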


//! @brief Binary functor swapping the arguments to ``operator()`` before forwarding to an inner functor
template <typename Predicate>
struct swap_args : Predicate
Contributor

Question: Why aren't we just using not_fn?

Contributor Author

@bernhardmgruber bernhardmgruber Apr 7, 2026

Because not_fn(less{}) is not the same as greater{}, it's greater_equal{}. It should actually not matter, since we are returning the first element that matches the predicate. But I felt swapping the arguments is more accurate.
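
The difference shows up on equal arguments. The sketch below uses a hypothetical, simplified `swap_args_sketch` (host-only, `std::` types, not the helper from the PR) next to `std::not_fn` to illustrate it.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Hypothetical simplified version of the swap_args helper discussed above.
template <typename Predicate>
struct swap_args_sketch : Predicate
{
  template <typename A, typename B>
  constexpr bool operator()(A&& a, B&& b) const
  {
    const auto& pred = static_cast<const Predicate&>(*this);
    // Forward the arguments to the inner predicate in reversed order.
    return pred(std::forward<B>(b), std::forward<A>(a));
  }
};

// Deduction guide so that swap_args_sketch{std::less<>{}} works in C++17.
template <typename Predicate>
swap_args_sketch(Predicate) -> swap_args_sketch<Predicate>;
```

Swapping the arguments of `std::less<>` behaves like `greater` and returns false for `(1, 1)`, whereas `std::not_fn(std::less<>{})` behaves like `greater_equal` and returns true for `(1, 1)`.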

@bernhardmgruber bernhardmgruber changed the title from "Support a custom comparison operator in DeviceReduce::ArgMin" to "Support a custom comparison operator in DeviceReduce::ArgMin|Max" Apr 7, 2026
@bernhardmgruber bernhardmgruber enabled auto-merge (squash) April 7, 2026 11:36
@bernhardmgruber
Contributor Author

          Start  65: cub.test.device.reduce.lid_0.types_0
   59/177 Test  #65: cub.test.device.reduce.lid_0.types_0 ...........................   Passed  3285.96 sec

Seems a bit excessive: https://github.com/NVIDIA/cccl/actions/runs/24074951429/job/70228571613?pr=8285



@github-actions
Contributor

github-actions bot commented Apr 8, 2026

🥳 CI Workflow Results

🟩 Finished in 1h 42m: Pass: 100%/269 | Total: 9d 10h | Max: 1h 37m | Hits: 62%/176731

See results here.

@bernhardmgruber bernhardmgruber merged commit a003464 into NVIDIA:main Apr 8, 2026
287 of 289 checks passed
@github-project-automation github-project-automation bot moved this from In Review to Done in CCCL Apr 8, 2026
@bernhardmgruber bernhardmgruber deleted the ref_argmin branch April 8, 2026 13:22
gonidelis pushed a commit to gonidelis/cccl that referenced this pull request Apr 8, 2026


Development

Successfully merging this pull request may close these issues.

Support a custom comparison predicate in cub::DeviceReduce::ArgMin|Max.

3 participants