Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[submodule "submodules/mlx"]
path = Source/Cmlx/mlx
path = Source/Cxxmlx/mlx
url = https://github.com/ml-explore/mlx
[submodule "submodules/mlx-c"]
path = Source/Cmlx/mlx-c
Expand Down
6 changes: 3 additions & 3 deletions MAINTENANCE.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,15 @@ This is important because `mlx` has some build-time source generation
pre-generating the source when updating the `mlx` version.

1. Update the `mlx` and `mlx-c` submodules via `git pull` or `git checkout ...`
- `Source/Cmlx/mlx`
- `Source/Cxxmlx/mlx`
- `Source/Cmlx/mlx-c`
2. Add any vendored dependencies as needed in `/vendor`

3. Regenerate any build-time source: `./tools/update-mlx.sh`
- this updates headers in Source/Cmlx/include
- this updates headers in Source/Cmlx/include-framework
- this generates various files in Source/Cmlx/mlx-generated
- this updates headers in Source/Cxxmlx/include
- this generates various files in Source/Cxxmlx/mlx-generated

4. Fix any build issues with SwiftPM build (opening Package.swift)
5. Fix any build issues with xcodeproj build (opening xcode/MLX.codeproj), see also [README.xcodeproj.md]
Expand Down Expand Up @@ -181,4 +182,3 @@ Settings, including header search paths are in xcode/xcconfig.
### MLX, etc.

These are just normal frameworks that link to Cmlx and others as needed. The source files are all swift and there are no special settings needed.

75 changes: 47 additions & 28 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,17 @@ let noCudaCmlxExcludes = [
]

#if os(Linux)
let platformExcludes: [String]
let cxxPlatformExcludes: [String]
let cmlxPlatformExcludes: [String]
let cxxSettings: [CXXSetting]
let linkerSettings: [LinkerSetting]
let mlxSwiftExcludes: [String]

if Context.environment["SPM_CUDA"] != "0" {
// Linux with CUDA

platformExcludes =
cxxPlatformExcludes =
[
"framework",
"include-framework",
"metal-cpp",

"mlx/mlx/backend/no_gpu",
Expand All @@ -100,7 +99,6 @@ let noCudaCmlxExcludes = [
"mlx/mlx/backend/no_cpu", // Exclude no_cpu backend on Linux, use cpu instead
"mlx/mlx/backend/cpu/gemms/bnns.cpp", // macOS Accelerate version
"mlx-conditional",
"mlx-c/mlx/c/metal.cpp",

"mlx/mlx/backend/cuda/delayload.cpp", // For Windows
"mlx/mlx/backend/cuda/quantized/qmm/qmm_impl_sm90_m128_n16_m1.cu",
Expand All @@ -115,6 +113,12 @@ let noCudaCmlxExcludes = [
"mlx/mlx/backend/cuda/quantized/qmm/fp_qmv.cu",
] + noMetalCmlxExcludes

cmlxPlatformExcludes = [
"framework",
"include-framework",
"mlx-c/mlx/c/metal.cpp",
]

cxxSettings = [
.unsafeFlags(["-I/usr/local/cuda/include"]),
.unsafeFlags(["-I/usr/local/cuda/include/cccl"]),
Expand Down Expand Up @@ -143,22 +147,24 @@ let noCudaCmlxExcludes = [
} else {
// Linux without CUDA (CPU only)

platformExcludes =
cxxPlatformExcludes =
[
"framework",
"include-framework",
"metal-cpp",

"mlx/mlx/backend/gpu", // Exclude GPU backend on Linux, use no_gpu instead
"mlx/mlx/backend/no_cpu", // Exclude no_cpu backend on Linux, use cpu instead
"mlx/mlx/backend/cpu/gemms/bnns.cpp", // macOS Accelerate version
"mlx-conditional",
"mlx-c/mlx/c/metal.cpp",

"mlx-c/mlx/c/fast.cpp", // Exclude on Linux - calls metal_kernel unconditionally

] + noMetalCmlxExcludes + noCudaCmlxExcludes

cmlxPlatformExcludes = [
"framework",
"include-framework",
"mlx-c/mlx/c/metal.cpp",
"mlx-c/mlx/c/fast.cpp", // Exclude on Linux - calls metal_kernel unconditionally
]

cxxSettings = []

linkerSettings = [
Expand All @@ -179,7 +185,7 @@ let noCudaCmlxExcludes = [
#else
// Apple's platforms with Metal

let platformExcludes: [String] =
let cxxPlatformExcludes: [String] =
[
"mlx/mlx/backend/cpu/compiled.cpp",

Expand All @@ -193,13 +199,15 @@ let noCudaCmlxExcludes = [
"mlx/mlx/backend/cpu/gemms/simd_bf16.cpp",
] + noCudaCmlxExcludes

let cmlxPlatformExcludes: [String] = []

let cxxSettings: [CXXSetting] = [
.headerSearchPath("metal-cpp"),

.define("MLX_USE_ACCELERATE"),
.define("ACCELERATE_NEW_LAPACK"),
.define("_METAL_"),
.define("SWIFTPM_BUNDLE", to: "\"mlx-swift_Cmlx\""),
.define("SWIFTPM_BUNDLE", to: "\"mlx-swift_Cxxmlx\""),
.define("METAL_PATH", to: "\"default.metallib\""),
]

Expand All @@ -214,18 +222,13 @@ let noCudaCmlxExcludes = [
]
#endif

let cmlx = Target.target(
name: "Cmlx",
path: "Source/Cmlx",
exclude: platformExcludes + [
let cxxmlx = Target.target(
name: "Cxxmlx",
path: "Source/Cxxmlx",
exclude: cxxPlatformExcludes + [
// vendor docs
"vendor-README.md",

// example code + mlx-c distributed
"mlx-c/examples",
"mlx-c/mlx/c/distributed.cpp",
"mlx-c/mlx/c/distributed_group.cpp",

// vendored library, include header only
"json",

Expand Down Expand Up @@ -276,14 +279,10 @@ let cmlx = Target.target(
"mlx/mlx/distributed/jaccl/ring.cpp",
"mlx/mlx/distributed/jaccl/utils.cpp",
],
cSettings: [
.headerSearchPath("mlx"),
.headerSearchPath("mlx-c"),
.headerSearchPath("mlx-generated/cuda"),
],
publicHeadersPath: "include",
cxxSettings: cxxSettings + [
.headerSearchPath("mlx"),
.headerSearchPath("mlx-c"),
.headerSearchPath("mlx-generated/cuda"),
.headerSearchPath("json/single_include/nlohmann"),
.headerSearchPath("fmt/include"),
.define("MLX_VERSION", to: "\"0.31.1\""),
Expand All @@ -294,6 +293,24 @@ let cmlx = Target.target(
],
)

let cmlx = Target.target(
name: "Cmlx",
dependencies: ["Cxxmlx"],
path: "Source/Cmlx",
exclude: cmlxPlatformExcludes + [
// example code + mlx-c distributed
"mlx-c/examples",
"mlx-c/mlx/c/distributed.cpp",
"mlx-c/mlx/c/distributed_group.cpp",
],
cSettings: [
.headerSearchPath("mlx-c")
],
cxxSettings: [
.headerSearchPath("mlx-c")
]
)

let package = Package(
name: "mlx-swift",

Expand All @@ -313,13 +330,15 @@ let package = Package(
.library(name: "MLXFFT", targets: ["MLXFFT"]),
.library(name: "MLXLinalg", targets: ["MLXLinalg"]),
.library(name: "MLXFast", targets: ["MLXFast"]),
.library(name: "Cxxmlx", targets: ["Cxxmlx"]),
],
dependencies: [
// for Complex type
.package(url: "https://github.com/apple/swift-numerics", from: "1.0.0"),
.package(url: "https://github.com/apple/swift-argument-parser", from: "1.0.0"),
],
targets: [
cxxmlx,
cmlx,
.testTarget(
name: "CmlxTests",
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ dependencies: [.product(name: "MLX", package: "mlx-swift"),

**Update the submodules**

The directories `Source/Cmlx/mlx` and `Source/Cmlx/mlx-c` are sourced as submodules.
The directories `Source/Cxxmlx/mlx` and `Source/Cmlx/mlx-c` are sourced as submodules.
Before you attempt to build the project locally, pull down the updates for those submodules:

```shell
Expand Down
3 changes: 1 addition & 2 deletions Source/Cmlx/CudaBuild.json → Source/Cxxmlx/CudaBuild.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"headerSearchPaths": [
"mlx",
"mlx-c",
"mlx-generated/cuda",
"json/single_include/nlohmann",
"fmt/include"
],
Expand All @@ -13,7 +13,6 @@
"mlx/mlx/backend/metal/jit",
"mlx/mlx/backend/no_cpu",
"mlx-conditional",
"mlx-c/examples",
"json",
"fmt/test",
"fmt/doc",
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
22 changes: 22 additions & 0 deletions Source/Cxxmlx/include/Cxxmlx.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#pragma once

#include <mlx/array.h>
#include <mlx/compile.h>
#include <mlx/device.h>
#include <mlx/dtype.h>
#include <mlx/einsum.h>
#include <mlx/fast.h>
#include <mlx/fft.h>
#include <mlx/io.h>
#include <mlx/linalg.h>
#include <mlx/memory.h>
#include <mlx/ops.h>
#include <mlx/random.h>
#include <mlx/stream.h>
#include <mlx/transforms.h>
#include <mlx/utils.h>
#include <mlx/version.h>

#if defined(__APPLE__)
#include <mlx/backend/metal/metal.h>
#endif
Loading