Skip to content
28 changes: 14 additions & 14 deletions cub/benchmarks/bench/reduce/arg_extrema.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,37 +45,37 @@ void arg_reduce(nvbench::state& state, nvbench::type_list<T, OpT>)
// Iterator providing the values being reduced
using values_it_t = T*;

// Type used for the final result
using output_tuple_t = cub::KeyValuePair<global_offset_t, T>;

auto const init = ::cuda::std::is_same_v<OpT, cub::ArgMin>
auto const init = ::cuda::std::is_same_v<OpT, cub::detail::arg_min>
? ::cuda::std::numeric_limits<T>::max()
: ::cuda::std::numeric_limits<T>::lowest();

// Retrieve axis parameters
const auto elements = static_cast<std::size_t>(state.get_int64("Elements{io}"));
thrust::device_vector<T> in = generate(elements);
thrust::device_vector<output_tuple_t> out(1);
thrust::device_vector<global_offset_t> out_index(1);
thrust::device_vector<T> out_extremum(1);

values_it_t d_in = thrust::raw_pointer_cast(in.data());
output_tuple_t* d_out = thrust::raw_pointer_cast(out.data());
auto const num_items = static_cast<global_offset_t>(elements);
values_it_t d_in = thrust::raw_pointer_cast(in.data());
global_offset_t* d_out_index = thrust::raw_pointer_cast(out_index.data());
T* d_out_extremum = thrust::raw_pointer_cast(out_extremum.data());
auto const num_items = static_cast<global_offset_t>(elements);

// Enable throughput calculations and add "Size" column to results.
state.add_element_count(elements);
state.add_global_memory_reads<T>(elements, "Size");
state.add_global_memory_writes<output_tuple_t>(1);
state.add_global_memory_writes<global_offset_t>(1);
state.add_global_memory_writes<T>(1);

// Allocate temporary storage
std::size_t temp_size;
cub::detail::reduce::dispatch_streaming_arg_reduce<per_partition_offset_t>(
nullptr,
temp_size,
d_in,
d_out,
d_out_index,
d_out_extremum,
num_items,
OpT{},
init,
0 /* stream */
#if !TUNE_BASE
,
Expand All @@ -91,10 +91,10 @@ void arg_reduce(nvbench::state& state, nvbench::type_list<T, OpT>)
temp_storage,
temp_size,
d_in,
d_out,
d_out_index,
d_out_extremum,
num_items,
OpT{},
init,
launch.get_stream()
#if !TUNE_BASE
,
Expand All @@ -104,7 +104,7 @@ void arg_reduce(nvbench::state& state, nvbench::type_list<T, OpT>)
});
}

using op_types = nvbench::type_list<cub::ArgMin, cub::ArgMax>;
using op_types = nvbench::type_list<cub::detail::arg_min, cub::detail::arg_max>;

NVBENCH_BENCH_TYPES(arg_reduce, NVBENCH_TYPE_AXES(fundamental_types, op_types))
.set_name("base")
Expand Down
Loading
Loading