diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index 01044d85f5f..dbbd79f4881 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -5,9 +5,10 @@ # LICENSE file in the root directory of this source tree. import os +import shutil import typing from importlib import resources -from typing import Any, Dict, final, List +from typing import Any, Dict, final, List, Optional import torch from executorch.backends.aoti.aoti_backend import AotiBackend @@ -36,6 +37,57 @@ class CudaBackend(AotiBackend, BackendDetails): def get_device_name(cls) -> str: return "cuda" + @staticmethod + def _find_ptxas_for_version(cuda_version: str) -> Optional[str]: # noqa: C901 + """ + Find ptxas binary that matches the expected CUDA version. + Returns the path to ptxas if found and version matches, None otherwise. + """ + expected_version_marker = f"/cuda-{cuda_version}/" + + def _validate_ptxas_version(path: str) -> bool: + """Check if ptxas at given path matches expected CUDA version.""" + if not os.path.exists(path): + return False + resolved = os.path.realpath(path) + return expected_version_marker in resolved + + # 1. Try PyTorch's CUDA_HOME + try: + from torch.utils.cpp_extension import CUDA_HOME + + if CUDA_HOME: + ptxas_path = os.path.join(CUDA_HOME, "bin", "ptxas") + if _validate_ptxas_version(ptxas_path): + return ptxas_path + except ImportError: + pass + + # 2. Try CUDA_HOME / CUDA_PATH environment variables + for env_var in ("CUDA_HOME", "CUDA_PATH", "CUDA_ROOT"): + cuda_home = os.environ.get(env_var) + if cuda_home: + ptxas_path = os.path.join(cuda_home, "bin", "ptxas") + if _validate_ptxas_version(ptxas_path): + return ptxas_path + + # 3. Try versioned path directly + versioned_path = f"/usr/local/cuda-{cuda_version}/bin/ptxas" + if os.path.exists(versioned_path): + return versioned_path + + # 4. Try system PATH via shutil.which + ptxas_in_path = shutil.which("ptxas") + if ptxas_in_path and _validate_ptxas_version(ptxas_in_path): + return ptxas_in_path + + # 5. Try default symlink path as last resort + default_path = "/usr/local/cuda/bin/ptxas" + if _validate_ptxas_version(default_path): + return default_path + + return None + @staticmethod def _setup_cuda_environment_for_fatbin() -> bool: """ @@ -57,12 +109,9 @@ def _setup_cuda_environment_for_fatbin() -> bool: # Set TRITON_PTXAS_PATH for CUDA 12.6+ if major == 12 and minor >= 6: - # Try versioned path first, fallback to symlinked path - ptxas_path = f"/usr/local/cuda-{cuda_version}/bin/ptxas" - if not os.path.exists(ptxas_path): - ptxas_path = "/usr/local/cuda/bin/ptxas" - if not os.path.exists(ptxas_path): - return False + ptxas_path = CudaBackend._find_ptxas_for_version(cuda_version) + if ptxas_path is None: + return False os.environ["TRITON_PTXAS_PATH"] = ptxas_path # Get compute capability of current CUDA device