Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 56 additions & 7 deletions backends/cuda/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
# LICENSE file in the root directory of this source tree.

import os
import shutil
import typing
from importlib import resources
from typing import Any, Dict, final, List
from typing import Any, Dict, final, List, Optional

import torch
from executorch.backends.aoti.aoti_backend import AotiBackend
Expand Down Expand Up @@ -36,6 +37,57 @@ class CudaBackend(AotiBackend, BackendDetails):
def get_device_name(cls) -> str:
return "cuda"

@staticmethod
def _find_ptxas_for_version(cuda_version: str) -> Optional[str]: # noqa: C901
"""
Find ptxas binary that matches the expected CUDA version.
Returns the path to ptxas if found and version matches, None otherwise.
"""
expected_version_marker = f"/cuda-{cuda_version}/"

def _validate_ptxas_version(path: str) -> bool:
"""Check if ptxas at given path matches expected CUDA version."""
if not os.path.exists(path):
return False
resolved = os.path.realpath(path)
return expected_version_marker in resolved

# 1. Try PyTorch's CUDA_HOME
try:
from torch.utils.cpp_extension import CUDA_HOME

if CUDA_HOME:
ptxas_path = os.path.join(CUDA_HOME, "bin", "ptxas")
if _validate_ptxas_version(ptxas_path):
return ptxas_path
except ImportError:
pass

# 2. Try CUDA_HOME / CUDA_PATH environment variables
for env_var in ("CUDA_HOME", "CUDA_PATH", "CUDA_ROOT"):
cuda_home = os.environ.get(env_var)
if cuda_home:
ptxas_path = os.path.join(cuda_home, "bin", "ptxas")
if _validate_ptxas_version(ptxas_path):
return ptxas_path

# 3. Try versioned path directly
versioned_path = f"/usr/local/cuda-{cuda_version}/bin/ptxas"
if os.path.exists(versioned_path):
return versioned_path

# 4. Try system PATH via shutil.which
ptxas_in_path = shutil.which("ptxas")
if ptxas_in_path and _validate_ptxas_version(ptxas_in_path):
return ptxas_in_path

# 5. Try default symlink path as last resort
default_path = "/usr/local/cuda/bin/ptxas"
if _validate_ptxas_version(default_path):
return default_path

return None

@staticmethod
def _setup_cuda_environment_for_fatbin() -> bool:
"""
Expand All @@ -57,12 +109,9 @@ def _setup_cuda_environment_for_fatbin() -> bool:

# Set TRITON_PTXAS_PATH for CUDA 12.6+
if major == 12 and minor >= 6:
# Try versioned path first, fallback to symlinked path
ptxas_path = f"/usr/local/cuda-{cuda_version}/bin/ptxas"
if not os.path.exists(ptxas_path):
ptxas_path = "/usr/local/cuda/bin/ptxas"
if not os.path.exists(ptxas_path):
return False
ptxas_path = CudaBackend._find_ptxas_for_version(cuda_version)
if ptxas_path is None:
return False
os.environ["TRITON_PTXAS_PATH"] = ptxas_path

# Get compute capability of current CUDA device
Expand Down
Loading