2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/optimum-executorch.txt
@@ -1 +1 @@
-d03e90c2cd9048e6d9a75285c0355f033cd016fc
+de4f3c4978b4d36cc0bb8f87c6877a4a040d7ae7
2 changes: 2 additions & 0 deletions backends/aoti/aoti_delegate_handle.h
@@ -10,6 +10,7 @@

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <string>

namespace executorch {
namespace backends {
@@ -85,6 +86,7 @@ struct AOTIDelegateHandle {
AOTInductorModelContainerHandle container_handle;
void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
// dependency
std::string method_name;

// Function pointers specific to this handle's shared library
AOTInductorModelContainerCreateWithDeviceFunc create_with_device;
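The new method_name field ties a loaded delegate instance to the method it serves, which is what lets the CUDA backend below match per-method options at execute time. A minimal sketch of the init-time bookkeeping, assuming the method name is available from the init context (the exact accessor is not shown in this diff):

    // Hypothetical sketch: record which method this delegate instance serves,
    // e.g. "encoder", so execute() can consult per-method backend options.
    auto* handle = new AOTIDelegateHandle();
    handle->method_name = context.get_method_name();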
106 changes: 94 additions & 12 deletions backends/cuda/runtime/cuda_backend.cpp
@@ -13,8 +13,10 @@
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <cstdio>

#include <array>
#include <cstring> // std::strcmp
#include <filesystem>
#include <fstream>
#include <mutex>
#include <string>
#include <vector>

@@ -35,20 +37,55 @@ using executorch::runtime::ArrayRef;
using executorch::runtime::Backend;
using executorch::runtime::BackendExecutionContext;
using executorch::runtime::BackendInitContext;
using executorch::runtime::BackendOption;
using executorch::runtime::BackendOptionContext;
using executorch::runtime::CompileSpec;
using executorch::runtime::DelegateHandle;
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::kMaxOptionValueLength;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::NamedDataMap;
using executorch::runtime::Result;
using executorch::runtime::Span;
using executorch::runtime::etensor::Tensor;

namespace {
constexpr char kSkipCopyOutputToCpuForMethod[] =
"skip_copy_output_to_cpu_for_method";
}

class ET_EXPERIMENTAL CudaBackend final
: public ::executorch::runtime::BackendInterface {
private:

void set_skip_copy_method(
const std::array<char, kMaxOptionValueLength>& raw) {
std::lock_guard<std::mutex> guard(skip_copy_method_mutex_);
skip_copy_method_ = std::string(raw.data());
}

std::array<char, kMaxOptionValueLength> get_skip_copy_method_as_option()
const {
std::array<char, kMaxOptionValueLength> out{};
std::string value;
{
std::lock_guard<std::mutex> guard(skip_copy_method_mutex_);
value = skip_copy_method_;
}
std::snprintf(out.data(), out.size(), "%s", value.c_str());
return out;
}

bool should_skip_copy_for_method(const std::string& method_name) const {
if (method_name.empty()) {
return false;
}
std::lock_guard<std::mutex> guard(skip_copy_method_mutex_);
return method_name == skip_copy_method_;
}

Error load_function_pointers_into_handle(
void* so_handle,
AOTIDelegateHandle* handle) const {
@@ -91,6 +128,38 @@ class ET_EXPERIMENTAL CudaBackend final
return 1;
}

Error set_option(
ET_UNUSED BackendOptionContext& context,
const executorch::runtime::Span<BackendOption>& backend_options)
override {
for (const auto& option : backend_options) {
if (std::strcmp(option.key, kSkipCopyOutputToCpuForMethod) == 0) {
if (auto* val = std::get_if<std::array<char, kMaxOptionValueLength>>(
&option.value)) {
set_skip_copy_method(*val);
} else {
ET_LOG(
Error,
"Option %s must be a method name string.",
kSkipCopyOutputToCpuForMethod);
return Error::InvalidArgument;
}
}
}
return Error::Ok;
}

Error get_option(
ET_UNUSED BackendOptionContext& context,
executorch::runtime::Span<BackendOption>& backend_options) override {
for (auto& option : backend_options) {
if (std::strcmp(option.key, kSkipCopyOutputToCpuForMethod) == 0) {
option.value = get_skip_copy_method_as_option();
}
}
return Error::Ok;
}

// Once per loaded binary blob
Result<DelegateHandle*> init(
BackendInitContext& context,
@@ -159,6 +228,7 @@ class ET_EXPERIMENTAL CudaBackend final
AOTIDelegateHandle* handle = new AOTIDelegateHandle();
handle->so_handle = lib_handle;
handle->so_path = so_path.string();
handle->method_name = method_name;

// Load function pointers specific to this handle's shared library
ET_CHECK_OK_OR_RETURN_ERROR(
@@ -303,18 +373,26 @@ class ET_EXPERIMENTAL CudaBackend final
"AOTInductorModelContainerRun failed with error code %d",
error);

// Copy GPU output results back to CPU output tensors
for (int i = 0; i < n_outputs; i++) {
auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
// For DYNAMIC_BOUND tensors we try to resize
ET_CHECK_OK_OR_RETURN_ERROR(
resize_tensor(*cpu_output_tensor, gpu_outputs[i]->sizes()),
"Error resizing tensor at output index %d",
i);
ET_CHECK_OK_OR_RETURN_ERROR(
aoti_torch_copy_(cpu_output_tensor, gpu_outputs[i], 0),
"Failed to copy GPU output %d back to CPU",
i);
const bool copy_outputs = !should_skip_copy_for_method(handle->method_name);

if (copy_outputs) {
// Copy GPU output results back to CPU output tensors
for (int i = 0; i < n_outputs; i++) {
auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
// For DYNAMIC_BOUND tensors we try to resize
ET_CHECK_OK_OR_RETURN_ERROR(
resize_tensor(*cpu_output_tensor, gpu_outputs[i]->sizes()),
"Error resizing tensor at output index %d",
i);
ET_CHECK_OK_OR_RETURN_ERROR(
aoti_torch_copy_(cpu_output_tensor, gpu_outputs[i], 0),
"Failed to copy GPU output %d back to CPU",
i);
}
} else {
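// Skip the device-to-host copy: hand the GPU-resident outputs to the caller.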
for (int i = 0; i < n_outputs; i++) {
args[i + n_inputs]->toTensor() = *gpu_outputs[i];
}
}

return Error::Ok;
@@ -365,6 +443,10 @@ class ET_EXPERIMENTAL CudaBackend final
delete handle;
clear_all_tensors();
}

private:
mutable std::mutex skip_copy_method_mutex_;
std::string skip_copy_method_;
};

} // namespace executorch::backends::cuda
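Together, set_option and get_option expose a runtime knob for keeping one method's outputs resident on the GPU. A hedged usage sketch from the client side, built from the calls that appear in this diff (the "encoder" method name is illustrative):

    #include <executorch/runtime/backend/options.h>

    // Ask the CUDA backend to leave "encoder" outputs on the GPU.
    executorch::runtime::BackendOptions<1> opts;
    opts.set_option("skip_copy_output_to_cpu_for_method", "encoder");
    auto err = executorch::runtime::set_option("CudaBackend", opts.view());

    // Setting an empty string restores copy-back for every method, since
    // should_skip_copy_for_method() never matches an empty name.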
16 changes: 16 additions & 0 deletions extension/asr/runner/CMakeLists.txt
@@ -35,6 +35,22 @@ set_target_properties(
extension_asr_runner PROPERTIES POSITION_INDEPENDENT_CODE ON
)

# If the project is configured to build with CUDA support, try to find a CUDA
# runtime (prefer the CUDAToolkit package). If found, expose a compile-time
# macro so sources can conditionally compile CUDA-aware code.
if(EXECUTORCH_BUILD_CUDA)
find_package(CUDAToolkit QUIET)
if(CUDAToolkit_FOUND)
target_compile_definitions(extension_asr_runner PUBLIC CUDA_AVAILABLE)
message(STATUS "CUDAToolkit found; defining CUDA_AVAILABLE for ASR runner")
else()
message(
STATUS
"CUDA requested (EXECUTORCH_BUILD_CUDA=ON) but no CUDA runtime found"
)
endif()
endif()

install(
TARGETS extension_asr_runner
EXPORT ExecuTorchTargets
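On the C++ side, the CUDA_AVAILABLE macro is meant to gate CUDA-aware code paths at compile time; a minimal sketch of the pattern (the log message is illustrative):

    #ifdef CUDA_AVAILABLE
    // Only compiled when EXECUTORCH_BUILD_CUDA=ON and a CUDA toolkit was found.
    ET_LOG(Info, "ASR runner built with CUDA support");
    #endif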
19 changes: 18 additions & 1 deletion extension/asr/runner/runner.cpp
@@ -16,6 +16,8 @@
#include <executorch/extension/llm/runner/util.h>
#include <executorch/extension/llm/sampler/util.h>
#include <executorch/extension/tensor/tensor_ptr_maker.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/backend/options.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/log.h>
@@ -107,7 +109,22 @@ Error AsrRunner::load() {

ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kDecoderMethodName));
decoder_method_loaded_ = true;

#ifdef CUDA_AVAILABLE
executorch::runtime::BackendOptions<1> backend_options;
// For the decoder, keep copying output from GPU to CPU for sampling.
// TODO: sample with a CUDA kernel so the decoder output copy can be
// skipped as well.
ET_CHECK_OK_OR_RETURN_ERROR(backend_options.set_option(
"skip_copy_output_to_cpu_for_method", kEncoderMethodName));
const auto opt_err =
executorch::runtime::set_option("CudaBackend", backend_options.view());
if (opt_err != ::executorch::runtime::Error::Ok) {
ET_LOG(
Error,
"Failed to set CUDA backend options: %d",
static_cast<int>(opt_err));
}
#endif
ET_CHECK_OK_OR_RETURN_ERROR(load_tokenizer());
auto eos_ids = get_eos_ids(tokenizer_.get(), module_.get());
if (!eos_ids.empty()) {
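If a caller later needs to confirm which method is configured to skip the copy, the backend's get_option path can be queried through the same options API; a hedged sketch, assuming the runtime exposes a get_option counterpart to the set_option call used above:

    executorch::runtime::BackendOptions<1> query;
    // Seed the key to read; CudaBackend::get_option overwrites the value
    // with the currently configured method name.
    query.set_option("skip_copy_output_to_cpu_for_method", "");
    const auto err =
        executorch::runtime::get_option("CudaBackend", query.view());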