-
Notifications
You must be signed in to change notification settings - Fork 78
Add kernel based alltoallv and cuda backend for MoE dispatch and combine #5863
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
samnordmann
wants to merge
21
commits into
main
Choose a base branch
from
dispatch_combine/stub_for_kernel
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+696
−83
Open
Changes from all commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
cf77bdb
first working dispatch and combine primitive for k=1
samnordmann 66e7811
add comments and cleanup
samnordmann afd948d
review
samnordmann dda9aa7
add kernel based a2av and cuda backend for d/c
samnordmann ba6612d
minor comments
samnordmann 4693c53
minor review
samnordmann 8041c46
Merge branch 'main' of github.com:NVIDIA/Fuser into dispatch_combine/…
samnordmann f1ce74c
Merge branch 'main' of github.com:NVIDIA/Fuser into dispatch_combine/…
samnordmann a81a514
renaming
samnordmann a0de605
add back topk_weights
samnordmann 74d18d1
harden tests
samnordmann 6b994ba
assume continuous expert-to-rank mapping and simplify API and impleme…
samnordmann 47d710f
simplify by enforcing 2D shapes
samnordmann f39daf2
lint
samnordmann da52220
remove combined_topk_weights
samnordmann c089049
minor simplification
samnordmann 490200f
remove (in|out|send)_src_rank
samnordmann f148137
Merge branch 'main' of github.com:NVIDIA/Fuser into dispatch_combine/…
samnordmann 6f56706
Merge branch 'dispatch_combine/stub' into dispatch_combine/stub_for_k…
samnordmann ea5ad45
Merge branch 'main' of github.com:NVIDIA/Fuser into dispatch_combine/…
samnordmann 3828247
lint
samnordmann File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| // clang-format off | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. | ||
| * All rights reserved. | ||
| * SPDX-License-Identifier: BSD-3-Clause | ||
| */ | ||
| // clang-format on | ||
|
|
||
| extern "C" __global__ void alltoallv_kernel( | ||
| const unsigned char* send, | ||
| const unsigned long long* recv_ptrs, | ||
| const long long* send_offsets, | ||
| const long long* send_sizes, | ||
| const long long* recv_offsets, | ||
| long long world_size, | ||
| long long elem_size, | ||
| long long max_send_bytes) { | ||
| const long long peer = static_cast<long long>(blockIdx.y); | ||
| if (peer >= world_size) { | ||
| return; | ||
| } | ||
| const long long bytes = send_sizes[peer] * elem_size; | ||
| if (bytes == 0) { | ||
| return; | ||
| } | ||
| const long long idx = | ||
| static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x; | ||
| if (idx >= bytes) { | ||
| return; | ||
| } | ||
| const long long send_byte_offset = send_offsets[peer] * elem_size + idx; | ||
| const long long recv_byte_offset = recv_offsets[peer] * elem_size + idx; | ||
| auto* dst = reinterpret_cast<unsigned char*>( | ||
| static_cast<unsigned long long>(recv_ptrs[peer])); | ||
| dst[recv_byte_offset] = send[send_byte_offset]; | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -6,6 +6,7 @@ | |||||||||||||
| */ | ||||||||||||||
| // clang-format on | ||||||||||||||
| #include "multidevice/cuda_p2p.h" | ||||||||||||||
| #include "nvfuser_resources/alltoallv.h" | ||||||||||||||
| #include "nvfuser_resources/multicast.h" | ||||||||||||||
|
|
||||||||||||||
| #include "cuda_utils.h" | ||||||||||||||
|
|
@@ -34,6 +35,143 @@ P2pProtocol getP2pProtocol() { | |||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| namespace { | ||||||||||||||
| void launchAlltoallvKernel( | ||||||||||||||
| const void* send, | ||||||||||||||
| const uint64_t* recv_ptrs, | ||||||||||||||
| const int64_t* send_offsets, | ||||||||||||||
| const int64_t* send_sizes, | ||||||||||||||
| const int64_t* recv_offsets, | ||||||||||||||
| int64_t world_size, | ||||||||||||||
| int64_t elem_size, | ||||||||||||||
| int64_t max_send_bytes, | ||||||||||||||
| CUstream stream) { | ||||||||||||||
| static CUmodule module = nullptr; | ||||||||||||||
| static CUfunction kernel = nullptr; | ||||||||||||||
|
|
||||||||||||||
| if (module == nullptr) { | ||||||||||||||
| nvrtcProgram prog; | ||||||||||||||
| NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram( | ||||||||||||||
| &prog, | ||||||||||||||
| nvfuser_resources::alltoallv_cu, | ||||||||||||||
| "alltoallv.cu", | ||||||||||||||
| 0, | ||||||||||||||
| nullptr, | ||||||||||||||
| nullptr)); | ||||||||||||||
|
|
||||||||||||||
| int major = 0; | ||||||||||||||
| int minor = 0; | ||||||||||||||
| int device = 0; | ||||||||||||||
| NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDevice(&device)); | ||||||||||||||
| cudaDeviceProp prop; | ||||||||||||||
| NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDeviceProperties(&prop, device)); | ||||||||||||||
| major = prop.major; | ||||||||||||||
| minor = prop.minor; | ||||||||||||||
|
|
||||||||||||||
| std::string arch_arg = "--gpu-architecture=compute_" + | ||||||||||||||
| std::to_string(major) + std::to_string(minor); | ||||||||||||||
| std::vector<const char*> opts = {arch_arg.c_str(), "--std=c++17"}; | ||||||||||||||
| // NVRTC needs CUDA headers to compile alltoallv.cu. | ||||||||||||||
| opts.push_back("-I/usr/local/cuda/include"); | ||||||||||||||
| opts.push_back("-I/usr/local/cuda/include/cccl"); | ||||||||||||||
|
Comment on lines
+74
to
+75
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hardcoded CUDA include paths may break on non-standard installations
Suggested change
|
||||||||||||||
|
|
||||||||||||||
| nvrtcResult res = nvrtcCompileProgram(prog, (int)opts.size(), opts.data()); | ||||||||||||||
| if (res != NVRTC_SUCCESS) { | ||||||||||||||
| size_t logSize; | ||||||||||||||
| NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLogSize(prog, &logSize)); | ||||||||||||||
| std::vector<char> log(logSize); | ||||||||||||||
| NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLog(prog, log.data())); | ||||||||||||||
| NVF_ERROR(false, "Alltoallv kernel compilation failed:\n", log.data()); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| size_t ptxSize; | ||||||||||||||
| NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptxSize)); | ||||||||||||||
| std::vector<char> ptx(ptxSize); | ||||||||||||||
| NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTX(prog, ptx.data())); | ||||||||||||||
| NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog)); | ||||||||||||||
|
|
||||||||||||||
| CUresult load_result = cuModuleLoadData(&module, ptx.data()); | ||||||||||||||
| if (load_result != CUDA_SUCCESS) { | ||||||||||||||
| constexpr size_t kLogSize = 8192; | ||||||||||||||
| char error_log[kLogSize]; | ||||||||||||||
| char info_log[kLogSize]; | ||||||||||||||
| CUjit_option options[] = { | ||||||||||||||
| CU_JIT_ERROR_LOG_BUFFER, | ||||||||||||||
| CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, | ||||||||||||||
| CU_JIT_INFO_LOG_BUFFER, | ||||||||||||||
| CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, | ||||||||||||||
| CU_JIT_LOG_VERBOSE}; | ||||||||||||||
| void* option_values[] = { | ||||||||||||||
| (void*)error_log, | ||||||||||||||
| (void*)kLogSize, | ||||||||||||||
| (void*)info_log, | ||||||||||||||
| (void*)kLogSize, | ||||||||||||||
| (void*)1}; | ||||||||||||||
| cuModuleLoadDataEx(&module, ptx.data(), 5, options, option_values); | ||||||||||||||
| NVF_ERROR( | ||||||||||||||
| false, | ||||||||||||||
| "Alltoallv kernel module load failed with error: ", | ||||||||||||||
| load_result, | ||||||||||||||
| "\nInfo Log:\n", | ||||||||||||||
| info_log, | ||||||||||||||
| "\nError Log:\n", | ||||||||||||||
| error_log); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| NVFUSER_CUDA_SAFE_CALL( | ||||||||||||||
| cuModuleGetFunction(&kernel, module, "alltoallv_kernel")); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| if (max_send_bytes == 0) { | ||||||||||||||
| return; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| constexpr int kThreads = 256; | ||||||||||||||
| const int64_t blocks_x = (max_send_bytes + kThreads - 1) / kThreads; | ||||||||||||||
| void* args_kernel[] = { | ||||||||||||||
| const_cast<void*>(static_cast<const void*>(&send)), | ||||||||||||||
| const_cast<void*>(static_cast<const void*>(&recv_ptrs)), | ||||||||||||||
| const_cast<void*>(static_cast<const void*>(&send_offsets)), | ||||||||||||||
| const_cast<void*>(static_cast<const void*>(&send_sizes)), | ||||||||||||||
| const_cast<void*>(static_cast<const void*>(&recv_offsets)), | ||||||||||||||
| &world_size, | ||||||||||||||
| &elem_size, | ||||||||||||||
| &max_send_bytes}; | ||||||||||||||
| NVFUSER_CUDA_SAFE_CALL(cuLaunchKernel( | ||||||||||||||
| kernel, | ||||||||||||||
| blocks_x, | ||||||||||||||
| static_cast<unsigned int>(world_size), | ||||||||||||||
| 1, | ||||||||||||||
| kThreads, | ||||||||||||||
| 1, | ||||||||||||||
| 1, | ||||||||||||||
| 0, | ||||||||||||||
| stream, | ||||||||||||||
| args_kernel, | ||||||||||||||
| nullptr)); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| std::vector<uint8_t> serializeInt64Vector(const std::vector<int64_t>& values) { | ||||||||||||||
| std::vector<uint8_t> bytes(values.size() * sizeof(int64_t)); | ||||||||||||||
| std::memcpy(bytes.data(), values.data(), bytes.size()); | ||||||||||||||
| return bytes; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| std::vector<int64_t> deserializeInt64Vector(const std::vector<uint8_t>& bytes) { | ||||||||||||||
| NVF_CHECK( | ||||||||||||||
| bytes.size() % sizeof(int64_t) == 0, "Invalid int64 byte buffer size."); | ||||||||||||||
| const size_t count = bytes.size() / sizeof(int64_t); | ||||||||||||||
| std::vector<int64_t> values(count); | ||||||||||||||
| std::memcpy(values.data(), bytes.data(), bytes.size()); | ||||||||||||||
| return values; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| std::string alltoallvCountsKey(const std::string& tag, int64_t rank) { | ||||||||||||||
| return "nvfuser_alltoallv_counts_" + tag + "_" + std::to_string(rank); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| std::string alltoallvBarrierKey(const std::string& tag, int64_t rank) { | ||||||||||||||
| return "nvfuser_alltoallv_barrier_" + tag + "_" + std::to_string(rank); | ||||||||||||||
| } | ||||||||||||||
|
Comment on lines
+172
to
+174
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. unused function - |
||||||||||||||
|
|
||||||||||||||
| void launchMulticastKernel( | ||||||||||||||
| void* dst, | ||||||||||||||
|
|
@@ -710,4 +848,181 @@ void waitWithCudaBackend( | |||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| AlltoallvMetadata prepareAlltoallvMetadata( | ||||||||||||||
| const at::Tensor& send_counts, | ||||||||||||||
| const std::string& tag) { | ||||||||||||||
| Communicator& comm = Communicator::getInstance(); | ||||||||||||||
| const int64_t world_size = comm.size(); | ||||||||||||||
| const int64_t my_rank = comm.deviceId(); | ||||||||||||||
| NVF_CHECK( | ||||||||||||||
| send_counts.is_cuda(), "alltoallv send_counts must be CUDA tensor."); | ||||||||||||||
| NVF_CHECK( | ||||||||||||||
| send_counts.dim() == 1 && send_counts.numel() == world_size, | ||||||||||||||
| "alltoallv send_counts must be 1D [R]."); | ||||||||||||||
|
|
||||||||||||||
| auto store = comm.getTcpStore(); | ||||||||||||||
| auto send_counts_cpu = send_counts.to(at::kCPU); | ||||||||||||||
| auto* send_ptr = send_counts_cpu.data_ptr<int64_t>(); | ||||||||||||||
| std::vector<int64_t> send_counts_vec(send_ptr, send_ptr + world_size); | ||||||||||||||
|
|
||||||||||||||
| store->set( | ||||||||||||||
| alltoallvCountsKey(tag, my_rank), serializeInt64Vector(send_counts_vec)); | ||||||||||||||
|
|
||||||||||||||
| std::vector<std::vector<int64_t>> counts_matrix(world_size); | ||||||||||||||
| for (int64_t rank = 0; rank < world_size; ++rank) { | ||||||||||||||
| auto bytes = store->get(alltoallvCountsKey(tag, rank)); | ||||||||||||||
| counts_matrix[rank] = deserializeInt64Vector(bytes); | ||||||||||||||
| NVF_CHECK( | ||||||||||||||
| (int64_t)counts_matrix[rank].size() == world_size, | ||||||||||||||
| "Invalid alltoallv counts size."); | ||||||||||||||
| } | ||||||||||||||
| comm.barrier(); | ||||||||||||||
| for (int64_t rank = 0; rank < world_size; ++rank) { | ||||||||||||||
| store->deleteKey(alltoallvCountsKey(tag, rank)); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| std::vector<int64_t> recv_counts_vec(world_size, 0); | ||||||||||||||
| for (int64_t sender = 0; sender < world_size; ++sender) { | ||||||||||||||
| recv_counts_vec[sender] = counts_matrix[sender][my_rank]; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| std::vector<int64_t> send_offsets_vec(world_size, 0); | ||||||||||||||
| int64_t prefix = 0; | ||||||||||||||
| for (int64_t rank = 0; rank < world_size; ++rank) { | ||||||||||||||
| send_offsets_vec[rank] = prefix; | ||||||||||||||
| prefix += send_counts_vec[rank]; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| std::vector<int64_t> recv_offsets_vec(world_size, 0); | ||||||||||||||
| for (int64_t peer = 0; peer < world_size; ++peer) { | ||||||||||||||
| int64_t offset = 0; | ||||||||||||||
| for (int64_t sender = 0; sender < my_rank; ++sender) { | ||||||||||||||
| offset += counts_matrix[sender][peer]; | ||||||||||||||
| } | ||||||||||||||
| recv_offsets_vec[peer] = offset; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| int64_t total_recv = 0; | ||||||||||||||
| for (auto value : recv_counts_vec) { | ||||||||||||||
| total_recv += value; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| int64_t max_recv = 0; | ||||||||||||||
| int64_t max_send_total = 0; | ||||||||||||||
| for (int64_t rank = 0; rank < world_size; ++rank) { | ||||||||||||||
| int64_t total = 0; | ||||||||||||||
| for (int64_t sender = 0; sender < world_size; ++sender) { | ||||||||||||||
| total += counts_matrix[sender][rank]; | ||||||||||||||
| } | ||||||||||||||
| if (total > max_recv) { | ||||||||||||||
| max_recv = total; | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| for (int64_t rank = 0; rank < world_size; ++rank) { | ||||||||||||||
| int64_t total = 0; | ||||||||||||||
| for (int64_t dest = 0; dest < world_size; ++dest) { | ||||||||||||||
| total += counts_matrix[rank][dest]; | ||||||||||||||
| } | ||||||||||||||
| if (total > max_send_total) { | ||||||||||||||
| max_send_total = total; | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| int64_t max_send = 0; | ||||||||||||||
| for (auto value : send_counts_vec) { | ||||||||||||||
| if (value > max_send) { | ||||||||||||||
| max_send = value; | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| auto cpu_options = at::TensorOptions().dtype(at::kLong).device(at::kCPU); | ||||||||||||||
| auto send_offsets_cpu = at::empty({world_size}, cpu_options); | ||||||||||||||
| std::memcpy( | ||||||||||||||
| send_offsets_cpu.data_ptr<int64_t>(), | ||||||||||||||
| send_offsets_vec.data(), | ||||||||||||||
| world_size * sizeof(int64_t)); | ||||||||||||||
| auto recv_offsets_cpu = at::empty({world_size}, cpu_options); | ||||||||||||||
| std::memcpy( | ||||||||||||||
| recv_offsets_cpu.data_ptr<int64_t>(), | ||||||||||||||
| recv_offsets_vec.data(), | ||||||||||||||
| world_size * sizeof(int64_t)); | ||||||||||||||
| auto recv_counts_cpu = at::empty({world_size}, cpu_options); | ||||||||||||||
| std::memcpy( | ||||||||||||||
| recv_counts_cpu.data_ptr<int64_t>(), | ||||||||||||||
| recv_counts_vec.data(), | ||||||||||||||
| world_size * sizeof(int64_t)); | ||||||||||||||
|
|
||||||||||||||
| AlltoallvMetadata metadata; | ||||||||||||||
| metadata.send_counts = send_counts; | ||||||||||||||
| metadata.recv_counts = recv_counts_cpu.to(send_counts.device()); | ||||||||||||||
| metadata.send_offsets = send_offsets_cpu.to(send_counts.device()); | ||||||||||||||
| metadata.recv_offsets = recv_offsets_cpu.to(send_counts.device()); | ||||||||||||||
| metadata.total_recv = total_recv; | ||||||||||||||
| metadata.max_recv = max_recv; | ||||||||||||||
| metadata.max_send_total = max_send_total; | ||||||||||||||
| metadata.max_send_bytes = max_send; | ||||||||||||||
| metadata.world_size = world_size; | ||||||||||||||
| return metadata; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| void alltoallvWithCudaBackend( | ||||||||||||||
| const at::Tensor& send, | ||||||||||||||
| const at::Tensor& recv, | ||||||||||||||
| const AlltoallvMetadata& metadata, | ||||||||||||||
| const std::vector<void*>& recv_ptrs, | ||||||||||||||
| CUstream stream) { | ||||||||||||||
| NVF_CHECK(send.is_cuda(), "alltoallv send must be CUDA."); | ||||||||||||||
| NVF_CHECK(recv.is_cuda(), "alltoallv recv must be CUDA."); | ||||||||||||||
| NVF_CHECK( | ||||||||||||||
| (int64_t)recv_ptrs.size() == metadata.world_size, | ||||||||||||||
| "recv_ptrs size must match world size."); | ||||||||||||||
|
|
||||||||||||||
| auto cpu_options = at::TensorOptions().dtype(at::kLong).device(at::kCPU); | ||||||||||||||
| auto recv_ptrs_cpu = at::empty({metadata.world_size}, cpu_options); | ||||||||||||||
| auto* ptrs = recv_ptrs_cpu.data_ptr<int64_t>(); | ||||||||||||||
| for (int64_t rank = 0; rank < metadata.world_size; ++rank) { | ||||||||||||||
| ptrs[rank] = | ||||||||||||||
| static_cast<int64_t>(reinterpret_cast<uintptr_t>(recv_ptrs[rank])); | ||||||||||||||
| } | ||||||||||||||
| auto recv_ptrs_cuda = recv_ptrs_cpu.to(send.device()); | ||||||||||||||
|
|
||||||||||||||
| const int64_t elem_stride = | ||||||||||||||
| metadata.max_send_total > 0 ? send.numel() / metadata.max_send_total : 1; | ||||||||||||||
| NVF_CHECK( | ||||||||||||||
| metadata.max_send_total == 0 || | ||||||||||||||
| send.numel() % metadata.max_send_total == 0, | ||||||||||||||
| "alltoallv send numel must be divisible by max_send_total."); | ||||||||||||||
| NVF_CHECK( | ||||||||||||||
| metadata.max_recv == 0 || recv.numel() % metadata.max_recv == 0, | ||||||||||||||
| "alltoallv recv numel must be divisible by max_recv."); | ||||||||||||||
|
|
||||||||||||||
| auto send_offsets = metadata.send_offsets; | ||||||||||||||
| auto send_counts = metadata.send_counts; | ||||||||||||||
| auto recv_offsets = metadata.recv_offsets; | ||||||||||||||
| int64_t max_send_bytes = metadata.max_send_bytes; | ||||||||||||||
| if (elem_stride > 1) { | ||||||||||||||
| send_offsets = metadata.send_offsets * elem_stride; | ||||||||||||||
| send_counts = metadata.send_counts * elem_stride; | ||||||||||||||
| recv_offsets = metadata.recv_offsets * elem_stride; | ||||||||||||||
| max_send_bytes = metadata.max_send_bytes * elem_stride; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| launchAlltoallvKernel( | ||||||||||||||
| send.data_ptr(), | ||||||||||||||
| reinterpret_cast<const uint64_t*>(recv_ptrs_cuda.data_ptr<int64_t>()), | ||||||||||||||
| send_offsets.data_ptr<int64_t>(), | ||||||||||||||
| send_counts.data_ptr<int64_t>(), | ||||||||||||||
| recv_offsets.data_ptr<int64_t>(), | ||||||||||||||
| metadata.world_size, | ||||||||||||||
| send.element_size(), | ||||||||||||||
| max_send_bytes * send.element_size(), | ||||||||||||||
| stream); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| void alltoallvBarrier(const std::string& tag) { | ||||||||||||||
| Communicator& comm = Communicator::getInstance(); | ||||||||||||||
| comm.barrier(); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| } // namespace nvfuser | ||||||||||||||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why nvrtc? Can't we simply
alltoallv_kernel<<<...>>>?