Skip to content
Merged
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ if(NOT MLX_BUILD_METAL)
${CMAKE_CURRENT_LIST_DIR}/Source/MLX/MLXArray+Metal.swift)
endif()

if(NOT MLX_BUILD_CUDA)
list(REMOVE_ITEM MLX-src ${CMAKE_CURRENT_LIST_DIR}/Source/MLX/GPU+CUDA.swift)
endif()

add_library(MLX STATIC ${MLX-src})
target_include_directories(MLX
PUBLIC ${CMAKE_CURRENT_LIST_DIR}/Source/Cmlx/include)
Expand Down
323 changes: 209 additions & 114 deletions Package.swift
Original file line number Diff line number Diff line change
@@ -1,89 +1,197 @@
// swift-tools-version: 5.12
// swift-tools-version: 6.3;(experimentalCGen)
// The swift-tools-version declares the minimum version of Swift required to build this package.
// Copyright © 2024 Apple Inc.

import PackageDescription

#if os(Linux)
let platformExcludes: [String] = [
// Linux specific excludes
"framework",
"include-framework",
"metal-cpp",
// Exclude Metal backend files on Linux, but keep no_metal.cpp for stubs
"mlx/mlx/backend/metal/allocator.cpp",
"mlx/mlx/backend/metal/binary.cpp",
"mlx/mlx/backend/metal/compiled.cpp",
"mlx/mlx/backend/metal/conv.cpp",
"mlx/mlx/backend/metal/copy.cpp",
"mlx/mlx/backend/metal/custom_kernel.cpp",
"mlx/mlx/backend/metal/device.cpp",
"mlx/mlx/backend/metal/device_info.cpp",
"mlx/mlx/backend/metal/distributed.cpp",
"mlx/mlx/backend/metal/eval.cpp",
"mlx/mlx/backend/metal/event.cpp",
"mlx/mlx/backend/metal/fence.cpp",
"mlx/mlx/backend/metal/fft.cpp",
"mlx/mlx/backend/metal/hadamard.cpp",
"mlx/mlx/backend/metal/indexing.cpp",
"mlx/mlx/backend/metal/jit_kernels.cpp",
"mlx/mlx/backend/metal/logsumexp.cpp",
"mlx/mlx/backend/metal/matmul.cpp",
"mlx/mlx/backend/metal/metal.cpp",
"mlx/mlx/backend/metal/normalization.cpp",
"mlx/mlx/backend/metal/primitives.cpp",
"mlx/mlx/backend/metal/quantized.cpp",
"mlx/mlx/backend/metal/reduce.cpp",
"mlx/mlx/backend/metal/resident.cpp",
"mlx/mlx/backend/metal/rope.cpp",
"mlx/mlx/backend/metal/scaled_dot_product_attention.cpp",
"mlx/mlx/backend/metal/scan.cpp",
"mlx/mlx/backend/metal/slicing.cpp",
"mlx/mlx/backend/metal/softmax.cpp",
"mlx/mlx/backend/metal/sort.cpp",
"mlx/mlx/backend/metal/ternary.cpp",
"mlx/mlx/backend/metal/unary.cpp",
"mlx/mlx/backend/metal/utils.cpp",
"mlx/mlx/backend/metal/kernels", // Exclude kernels directory
"mlx/mlx/backend/metal/jit", // Exclude jit directory

"mlx/mlx/backend/gpu", // Exclude GPU backend on Linux, use no_gpu instead
"mlx/mlx/backend/no_cpu", // Exclude no_cpu backend on Linux, use cpu instead
"mlx/mlx/backend/cpu/gemms/bnns.cpp", // macOS Accelerate version
"mlx-conditional",
"mlx-c/mlx/c/metal.cpp",

"mlx-c/mlx/c/fast.cpp", // Exclude on Linux - calls metal_kernel unconditionally
]

let cxxSettings: [CXXSetting] = []

let linkerSettings: [LinkerSetting] = [
.linkedLibrary("gfortran", .when(platforms: [.linux])),
.linkedLibrary("blas", .when(platforms: [.linux])),
.linkedLibrary("lapack", .when(platforms: [.linux])),
.linkedLibrary("openblas", .when(platforms: [.linux])),
]
let noMetalCmlxExcludes = [
// Exclude Metal backend files, but keep no_metal.cpp for stubs
// "mlx/mlx/backend/metal/no_metal.cpp",
"mlx/mlx/backend/metal/allocator.cpp",
"mlx/mlx/backend/metal/binary.cpp",
"mlx/mlx/backend/metal/compiled.cpp",
"mlx/mlx/backend/metal/conv.cpp",
"mlx/mlx/backend/metal/copy.cpp",
"mlx/mlx/backend/metal/custom_kernel.cpp",
"mlx/mlx/backend/metal/device.cpp",
"mlx/mlx/backend/metal/device_info.cpp",
"mlx/mlx/backend/metal/distributed.cpp",
"mlx/mlx/backend/metal/eval.cpp",
"mlx/mlx/backend/metal/event.cpp",
"mlx/mlx/backend/metal/fence.cpp",
"mlx/mlx/backend/metal/fft.cpp",
"mlx/mlx/backend/metal/hadamard.cpp",
"mlx/mlx/backend/metal/indexing.cpp",
"mlx/mlx/backend/metal/jit_kernels.cpp",
"mlx/mlx/backend/metal/logsumexp.cpp",
"mlx/mlx/backend/metal/matmul.cpp",
"mlx/mlx/backend/metal/metal.cpp",
"mlx/mlx/backend/metal/normalization.cpp",
"mlx/mlx/backend/metal/primitives.cpp",
"mlx/mlx/backend/metal/quantized.cpp",
"mlx/mlx/backend/metal/reduce.cpp",
"mlx/mlx/backend/metal/resident.cpp",
"mlx/mlx/backend/metal/rope.cpp",
"mlx/mlx/backend/metal/scaled_dot_product_attention.cpp",
"mlx/mlx/backend/metal/scan.cpp",
"mlx/mlx/backend/metal/slicing.cpp",
"mlx/mlx/backend/metal/softmax.cpp",
"mlx/mlx/backend/metal/sort.cpp",
"mlx/mlx/backend/metal/ternary.cpp",
"mlx/mlx/backend/metal/unary.cpp",
"mlx/mlx/backend/metal/utils.cpp",
"mlx/mlx/backend/metal/kernels", // Exclude kernels directory
"mlx/mlx/backend/metal/jit", // Exclude jit directory
]

let noCudaCmlxExcludes = [
// Exclude CUDA backend files, but keep no_cuda.cpp for stubs
// mlx/mlx/backend/cuda/no_cuda.cpp
"mlx/mlx/backend/cuda/allocator.cpp",
"mlx/mlx/backend/cuda/compiled.cpp",
"mlx/mlx/backend/cuda/conv.cpp",
"mlx/mlx/backend/cuda/cublas_utils.cpp",
"mlx/mlx/backend/cuda/cudnn_utils.cpp",
"mlx/mlx/backend/cuda/custom_kernel.cpp",
"mlx/mlx/backend/cuda/delayload.cpp",
"mlx/mlx/backend/cuda/device_info.cpp",
"mlx/mlx/backend/cuda/device.cpp",
"mlx/mlx/backend/cuda/eval.cpp",
"mlx/mlx/backend/cuda/fence.cpp",
"mlx/mlx/backend/cuda/indexing.cpp",
"mlx/mlx/backend/cuda/jit_module.cpp",
"mlx/mlx/backend/cuda/load.cpp",
"mlx/mlx/backend/cuda/matmul.cpp",
"mlx/mlx/backend/cuda/primitives.cpp",
"mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp",
"mlx/mlx/backend/cuda/slicing.cpp",
"mlx/mlx/backend/cuda/utils.cpp",
"mlx/mlx/backend/cuda/worker.cpp",
"mlx/mlx/backend/cuda/binary",
"mlx/mlx/backend/cuda/conv",
"mlx/mlx/backend/cuda/copy",
"mlx/mlx/backend/cuda/device",
"mlx/mlx/backend/cuda/gemms",
"mlx/mlx/backend/cuda/quantized",
"mlx/mlx/backend/cuda/reduce",
"mlx/mlx/backend/cuda/steel",
"mlx/mlx/backend/cuda/unary",
]

let mlxSwiftExcludes: [String] = [
"GPU+Metal.swift",
"MLXArray+Metal.swift",
"MLXFast.swift",
"MLXFastKernel.swift",
]
#if os(Linux)
let platformExcludes: [String]
let cxxSettings: [CXXSetting]
let linkerSettings: [LinkerSetting]
let mlxSwiftExcludes: [String]

if Context.environment["SPM_CUDA"] != "0" {
// Linux with CUDA

platformExcludes =
[
"framework",
"include-framework",
"metal-cpp",

"mlx/mlx/backend/no_gpu",
"mlx/mlx/backend/cuda/no_cuda.cpp",
"mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp",
"mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp",
"mlx/mlx/backend/no_cpu", // Exclude no_cpu backend on Linux, use cpu instead
"mlx/mlx/backend/cpu/gemms/bnns.cpp", // macOS Accelerate version
"mlx-conditional",
"mlx-c/mlx/c/metal.cpp",

"mlx/mlx/backend/cuda/delayload.cpp", // For Windows
"mlx/mlx/backend/cuda/quantized/qmm/qmm_impl_sm90_m128_n16_m1.cu",
"mlx/mlx/backend/cuda/quantized/qmm/qmv.cu",
"mlx/mlx/backend/cuda/quantized/qmm/qmm_impl_sm90_m128_n32_m1.cu",
"mlx/mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh",
"mlx/mlx/backend/cuda/quantized/qmm/qmm_impl_sm90_m128_n64_m2.cu",
"mlx/mlx/backend/cuda/quantized/qmm/qmm.h",
"mlx/mlx/backend/cuda/quantized/qmm/qmm_impl_sm90_m128_n256_m2.cu",
"mlx/mlx/backend/cuda/quantized/qmm/qmm.cu",
"mlx/mlx/backend/cuda/quantized/qmm/qmm_impl_sm90_m128_n128_m2.cu",
"mlx/mlx/backend/cuda/quantized/qmm/fp_qmv.cu",
] + noMetalCmlxExcludes

cxxSettings = [
.unsafeFlags(["-I/usr/local/cuda/include"]),
.unsafeFlags(["-I/usr/local/cuda/include/cccl"]),
.define("MLX_CCCL_DIR", to: "\"/usr/local/cuda/include/cccl\""),
]

linkerSettings = [
.linkedLibrary("gfortran", .when(platforms: [.linux])),
.linkedLibrary("blas", .when(platforms: [.linux])),
.linkedLibrary("lapack", .when(platforms: [.linux])),
.linkedLibrary("openblas", .when(platforms: [.linux])),
.unsafeFlags(["-L/usr/local/cuda/lib64"]),
.unsafeFlags(["-L/usr/local/cuda/lib64/stubs"]),
.linkedLibrary("cudnn"),
.linkedLibrary("cublas"),
.linkedLibrary("cublasLt"),
.linkedLibrary("nvrtc"),
.linkedLibrary("cudart"),
.linkedLibrary("cuda"),
]

mlxSwiftExcludes = [
"GPU+Metal.swift",
"MLXArray+Metal.swift",
]
} else {
// Linux without CUDA (CPU only)

platformExcludes =
[
"framework",
"include-framework",
"metal-cpp",

"mlx/mlx/backend/gpu", // Exclude GPU backend on Linux, use no_gpu instead
"mlx/mlx/backend/no_cpu", // Exclude no_cpu backend on Linux, use cpu instead
"mlx/mlx/backend/cpu/gemms/bnns.cpp", // macOS Accelerate version
"mlx-conditional",
"mlx-c/mlx/c/metal.cpp",

"mlx-c/mlx/c/fast.cpp", // Exclude on Linux - calls metal_kernel unconditionally

] + noMetalCmlxExcludes + noCudaCmlxExcludes

cxxSettings = []

linkerSettings = [
.linkedLibrary("gfortran", .when(platforms: [.linux])),
.linkedLibrary("blas", .when(platforms: [.linux])),
.linkedLibrary("lapack", .when(platforms: [.linux])),
.linkedLibrary("openblas", .when(platforms: [.linux])),
]

mlxSwiftExcludes = [
"GPU+Metal.swift",
"GPU+CUDA.swift",
"MLXArray+Metal.swift",
"MLXFast.swift",
"MLXFastKernel.swift",
]
}
#else
let platformExcludes: [String] = [
"mlx/mlx/backend/cpu/compiled.cpp",
// Apple's platforms with Metal

// opt-out of these backends (using metal)
"mlx/mlx/backend/no_gpu",
"mlx/mlx/backend/no_cpu",
"mlx/mlx/backend/metal/no_metal.cpp",
let platformExcludes: [String] =
[
"mlx/mlx/backend/cpu/compiled.cpp",

// bnns instead of simd (accelerate)
"mlx/mlx/backend/cpu/gemms/simd_fp16.cpp",
"mlx/mlx/backend/cpu/gemms/simd_bf16.cpp",
]
// opt-out of these backends (using metal)
"mlx/mlx/backend/no_gpu",
"mlx/mlx/backend/no_cpu",
"mlx/mlx/backend/metal/no_metal.cpp",

// bnns instead of simd (accelerate)
"mlx/mlx/backend/cpu/gemms/simd_fp16.cpp",
"mlx/mlx/backend/cpu/gemms/simd_bf16.cpp",
] + noCudaCmlxExcludes

let cxxSettings: [CXXSetting] = [
.headerSearchPath("metal-cpp"),
Expand All @@ -101,7 +209,9 @@ import PackageDescription
.linkedFramework("Accelerate"),
]

let mlxSwiftExcludes: [String] = []
let mlxSwiftExcludes: [String] = [
"GPU+CUDA.swift"
]
#endif

let cmlx = Target.target(
Expand Down Expand Up @@ -147,40 +257,6 @@ let cmlx = Target.target(
"mlx/setup.py",
"mlx/tests",

// special handling for cuda -- we need to keep one file:
// mlx/mlx/backend/cuda/no_cuda.cpp

"mlx/mlx/backend/cuda/allocator.cpp",
"mlx/mlx/backend/cuda/compiled.cpp",
"mlx/mlx/backend/cuda/conv.cpp",
"mlx/mlx/backend/cuda/cublas_utils.cpp",
"mlx/mlx/backend/cuda/cudnn_utils.cpp",
"mlx/mlx/backend/cuda/custom_kernel.cpp",
"mlx/mlx/backend/cuda/delayload.cpp",
"mlx/mlx/backend/cuda/device_info.cpp",
"mlx/mlx/backend/cuda/device.cpp",
"mlx/mlx/backend/cuda/eval.cpp",
"mlx/mlx/backend/cuda/fence.cpp",
"mlx/mlx/backend/cuda/indexing.cpp",
"mlx/mlx/backend/cuda/jit_module.cpp",
"mlx/mlx/backend/cuda/load.cpp",
"mlx/mlx/backend/cuda/matmul.cpp",
"mlx/mlx/backend/cuda/primitives.cpp",
"mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp",
"mlx/mlx/backend/cuda/slicing.cpp",
"mlx/mlx/backend/cuda/utils.cpp",
"mlx/mlx/backend/cuda/worker.cpp",

"mlx/mlx/backend/cuda/binary",
"mlx/mlx/backend/cuda/conv",
"mlx/mlx/backend/cuda/copy",
"mlx/mlx/backend/cuda/device",
"mlx/mlx/backend/cuda/gemms",
"mlx/mlx/backend/cuda/quantized",
"mlx/mlx/backend/cuda/reduce",
"mlx/mlx/backend/cuda/steel",
"mlx/mlx/backend/cuda/unary",

// build variants (we are opting _out_ of these)
"mlx/mlx/io/no_safetensors.cpp",
"mlx/mlx/io/gguf.cpp",
Expand All @@ -203,6 +279,7 @@ let cmlx = Target.target(
cSettings: [
.headerSearchPath("mlx"),
.headerSearchPath("mlx-c"),
.headerSearchPath("mlx-generated/cuda"),
],
cxxSettings: cxxSettings + [
.headerSearchPath("mlx"),
Expand All @@ -211,7 +288,10 @@ let cmlx = Target.target(
.headerSearchPath("fmt/include"),
.define("MLX_VERSION", to: "\"0.31.1\""),
],
linkerSettings: linkerSettings
linkerSettings: linkerSettings,
plugins: [
.plugin(name: "CudaBuild")
],
)

let package = Package(
Expand All @@ -236,7 +316,8 @@ let package = Package(
],
dependencies: [
// for Complex type
.package(url: "https://github.com/apple/swift-numerics", from: "1.0.0")
.package(url: "https://github.com/apple/swift-numerics", from: "1.0.0"),
.package(url: "https://github.com/apple/swift-argument-parser", from: "1.0.0"),
],
targets: [
cmlx,
Expand Down Expand Up @@ -333,6 +414,20 @@ let package = Package(
path: "Source/Examples",
sources: ["CustomFunctionExampleSimple.swift"]
),
.executableTarget(
name: "encuda",
dependencies: [
.product(name: "ArgumentParser", package: "swift-argument-parser")
],
path: "Source/Encuda",
),
.plugin(
name: "CudaBuild",
capability: .buildTool(),
dependencies: [
.target(name: "encuda")
],
),
],
cxxLanguageStandard: .gnucxx20
)
Expand Down
Loading
Loading