diff --git a/.gitmodules b/.gitmodules index 4b9b6084..9428bb77 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,5 +1,5 @@ [submodule "submodules/mlx"] - path = Source/Cmlx/mlx + path = Source/Cxxmlx/mlx url = https://github.com/ml-explore/mlx [submodule "submodules/mlx-c"] path = Source/Cmlx/mlx-c diff --git a/MAINTENANCE.md b/MAINTENANCE.md index f71c5b26..90b8b27e 100644 --- a/MAINTENANCE.md +++ b/MAINTENANCE.md @@ -135,14 +135,15 @@ This is important because `mlx` has some build-time source generation pre-generating the source when updating the `mlx` version. 1. Update the `mlx` and `mlx-c` submodules via `git pull` or `git checkout ...` - - `Source/Cmlx/mlx` + - `Source/Cxxmlx/mlx` - `Source/Cmlx/mlx-c` 2. Add any vendored dependencies as needed in `/vendor` 3. Regenerate any build-time source: `./tools/update-mlx.sh` - this updates headers in Source/Cmlx/include - this updates headers in Source/Cmlx/include-framework - - this generates various files in Source/Cmlx/mlx-generated + - this updates headers in Source/Cxxmlx/include + - this generates various files in Source/Cxxmlx/mlx-generated 4. Fix any build issues with SwiftPM build (opening Package.swift) 5. Fix any build issues with xcodeproj build (opening xcode/MLX.codeproj), see also [README.xcodeproj.md] @@ -181,4 +182,3 @@ Settings, including header search paths are in xcode/xcconfig. ### MLX, etc. These are just normal frameworks that link to Cmlx and others as needed. The source files are all swift and there are no special settings needed. - diff --git a/Package.swift b/Package.swift index e36e0d86..1ef9797d 100644 --- a/Package.swift +++ b/Package.swift @@ -79,7 +79,8 @@ let noCudaCmlxExcludes = [ ] #if os(Linux) - let platformExcludes: [String] + let cxxPlatformExcludes: [String] + let cmlxPlatformExcludes: [String] let cxxSettings: [CXXSetting] let linkerSettings: [LinkerSetting] let mlxSwiftExcludes: [String] @@ -87,10 +88,8 @@ let noCudaCmlxExcludes = [ if Context.environment["SPM_CUDA"] != "0" { // Linux with CUDA - platformExcludes = + cxxPlatformExcludes = [ - "framework", - "include-framework", "metal-cpp", "mlx/mlx/backend/no_gpu", @@ -100,7 +99,6 @@ let noCudaCmlxExcludes = [ "mlx/mlx/backend/no_cpu", // Exclude no_cpu backend on Linux, use cpu instead "mlx/mlx/backend/cpu/gemms/bnns.cpp", // macOS Accelerate version "mlx-conditional", - "mlx-c/mlx/c/metal.cpp", "mlx/mlx/backend/cuda/delayload.cpp", // For Windows "mlx/mlx/backend/cuda/quantized/qmm/qmm_impl_sm90_m128_n16_m1.cu", @@ -115,6 +113,12 @@ let noCudaCmlxExcludes = [ "mlx/mlx/backend/cuda/quantized/qmm/fp_qmv.cu", ] + noMetalCmlxExcludes + cmlxPlatformExcludes = [ + "framework", + "include-framework", + "mlx-c/mlx/c/metal.cpp", + ] + cxxSettings = [ .unsafeFlags(["-I/usr/local/cuda/include"]), .unsafeFlags(["-I/usr/local/cuda/include/cccl"]), @@ -143,22 +147,24 @@ let noCudaCmlxExcludes = [ } else { // Linux without CUDA (CPU only) - platformExcludes = + cxxPlatformExcludes = [ - "framework", - "include-framework", "metal-cpp", "mlx/mlx/backend/gpu", // Exclude GPU backend on Linux, use no_gpu instead "mlx/mlx/backend/no_cpu", // Exclude no_cpu backend on Linux, use cpu instead "mlx/mlx/backend/cpu/gemms/bnns.cpp", // macOS Accelerate version "mlx-conditional", - "mlx-c/mlx/c/metal.cpp", - - "mlx-c/mlx/c/fast.cpp", // Exclude on Linux - calls metal_kernel unconditionally ] + noMetalCmlxExcludes + noCudaCmlxExcludes + cmlxPlatformExcludes = [ + "framework", + "include-framework", + "mlx-c/mlx/c/metal.cpp", + "mlx-c/mlx/c/fast.cpp", // Exclude on Linux - calls metal_kernel unconditionally + ] + cxxSettings = [] linkerSettings = [ @@ -179,7 +185,7 @@ let noCudaCmlxExcludes = [ #else // Apple's platforms with Metal - let platformExcludes: [String] = + let cxxPlatformExcludes: [String] = [ "mlx/mlx/backend/cpu/compiled.cpp", @@ -193,13 +199,15 @@ let noCudaCmlxExcludes = [ "mlx/mlx/backend/cpu/gemms/simd_bf16.cpp", ] + noCudaCmlxExcludes + let cmlxPlatformExcludes: [String] = [] + let cxxSettings: [CXXSetting] = [ .headerSearchPath("metal-cpp"), .define("MLX_USE_ACCELERATE"), .define("ACCELERATE_NEW_LAPACK"), .define("_METAL_"), - .define("SWIFTPM_BUNDLE", to: "\"mlx-swift_Cmlx\""), + .define("SWIFTPM_BUNDLE", to: "\"mlx-swift_Cxxmlx\""), .define("METAL_PATH", to: "\"default.metallib\""), ] @@ -214,18 +222,13 @@ let noCudaCmlxExcludes = [ ] #endif -let cmlx = Target.target( - name: "Cmlx", - path: "Source/Cmlx", - exclude: platformExcludes + [ +let cxxmlx = Target.target( + name: "Cxxmlx", + path: "Source/Cxxmlx", + exclude: cxxPlatformExcludes + [ // vendor docs "vendor-README.md", - // example code + mlx-c distributed - "mlx-c/examples", - "mlx-c/mlx/c/distributed.cpp", - "mlx-c/mlx/c/distributed_group.cpp", - // vendored library, include header only "json", @@ -276,14 +279,10 @@ let cmlx = Target.target( "mlx/mlx/distributed/jaccl/ring.cpp", "mlx/mlx/distributed/jaccl/utils.cpp", ], - cSettings: [ - .headerSearchPath("mlx"), - .headerSearchPath("mlx-c"), - .headerSearchPath("mlx-generated/cuda"), - ], + publicHeadersPath: "include", cxxSettings: cxxSettings + [ .headerSearchPath("mlx"), - .headerSearchPath("mlx-c"), + .headerSearchPath("mlx-generated/cuda"), .headerSearchPath("json/single_include/nlohmann"), .headerSearchPath("fmt/include"), .define("MLX_VERSION", to: "\"0.31.1\""), @@ -294,6 +293,24 @@ let cmlx = Target.target( ], ) +let cmlx = Target.target( + name: "Cmlx", + dependencies: ["Cxxmlx"], + path: "Source/Cmlx", + exclude: cmlxPlatformExcludes + [ + // example code + mlx-c distributed + "mlx-c/examples", + "mlx-c/mlx/c/distributed.cpp", + "mlx-c/mlx/c/distributed_group.cpp", + ], + cSettings: [ + .headerSearchPath("mlx-c") + ], + cxxSettings: [ + .headerSearchPath("mlx-c") + ] +) + let package = Package( name: "mlx-swift", @@ -313,6 +330,7 @@ let package = Package( .library(name: "MLXFFT", targets: ["MLXFFT"]), .library(name: "MLXLinalg", targets: ["MLXLinalg"]), .library(name: "MLXFast", targets: ["MLXFast"]), + .library(name: "Cxxmlx", targets: ["Cxxmlx"]), ], dependencies: [ // for Complex type @@ -320,6 +338,7 @@ let package = Package( .package(url: "https://github.com/apple/swift-argument-parser", from: "1.0.0"), ], targets: [ + cxxmlx, cmlx, .testTarget( name: "CmlxTests", diff --git a/README.md b/README.md index 540607ec..d8748366 100644 --- a/README.md +++ b/README.md @@ -91,7 +91,7 @@ dependencies: [.product(name: "MLX", package: "mlx-swift"), **Update the submodules** -The directories `Source/Cmlx/mlx` and `Source/Cmlx/mlx-c` are sourced as submodules. +The directories `Source/Cxxmlx/mlx` and `Source/Cmlx/mlx-c` are sourced as submodules. Before you attempt to build the project locally, pull down the updates for those submodules: ```shell diff --git a/Source/Cmlx/CudaBuild.json b/Source/Cxxmlx/CudaBuild.json similarity index 93% rename from Source/Cmlx/CudaBuild.json rename to Source/Cxxmlx/CudaBuild.json index 3ba88fc9..385cdfce 100644 --- a/Source/Cmlx/CudaBuild.json +++ b/Source/Cxxmlx/CudaBuild.json @@ -1,7 +1,7 @@ { "headerSearchPaths": [ "mlx", - "mlx-c", + "mlx-generated/cuda", "json/single_include/nlohmann", "fmt/include" ], @@ -13,7 +13,6 @@ "mlx/mlx/backend/metal/jit", "mlx/mlx/backend/no_cpu", "mlx-conditional", - "mlx-c/examples", "json", "fmt/test", "fmt/doc", diff --git a/Source/Cmlx/fmt/.clang-format b/Source/Cxxmlx/fmt/.clang-format similarity index 100% rename from Source/Cmlx/fmt/.clang-format rename to Source/Cxxmlx/fmt/.clang-format diff --git a/Source/Cmlx/fmt/.gitignore b/Source/Cxxmlx/fmt/.gitignore similarity index 100% rename from Source/Cmlx/fmt/.gitignore rename to Source/Cxxmlx/fmt/.gitignore diff --git a/Source/Cmlx/fmt/CMakeLists.txt b/Source/Cxxmlx/fmt/CMakeLists.txt similarity index 100% rename from Source/Cmlx/fmt/CMakeLists.txt rename to Source/Cxxmlx/fmt/CMakeLists.txt diff --git a/Source/Cmlx/fmt/CONTRIBUTING.md b/Source/Cxxmlx/fmt/CONTRIBUTING.md similarity index 100% rename from Source/Cmlx/fmt/CONTRIBUTING.md rename to Source/Cxxmlx/fmt/CONTRIBUTING.md diff --git a/Source/Cmlx/fmt/ChangeLog.md b/Source/Cxxmlx/fmt/ChangeLog.md similarity index 100% rename from Source/Cmlx/fmt/ChangeLog.md rename to Source/Cxxmlx/fmt/ChangeLog.md diff --git a/Source/Cmlx/fmt/LICENSE b/Source/Cxxmlx/fmt/LICENSE similarity index 100% rename from Source/Cmlx/fmt/LICENSE rename to Source/Cxxmlx/fmt/LICENSE diff --git a/Source/Cmlx/fmt/README.md b/Source/Cxxmlx/fmt/README.md similarity index 100% rename from Source/Cmlx/fmt/README.md rename to Source/Cxxmlx/fmt/README.md diff --git a/Source/Cmlx/fmt/doc/ChangeLog-old.md b/Source/Cxxmlx/fmt/doc/ChangeLog-old.md similarity index 100% rename from Source/Cmlx/fmt/doc/ChangeLog-old.md rename to Source/Cxxmlx/fmt/doc/ChangeLog-old.md diff --git a/Source/Cmlx/fmt/doc/api.md b/Source/Cxxmlx/fmt/doc/api.md similarity index 100% rename from Source/Cmlx/fmt/doc/api.md rename to Source/Cxxmlx/fmt/doc/api.md diff --git a/Source/Cmlx/fmt/doc/fmt.css b/Source/Cxxmlx/fmt/doc/fmt.css similarity index 100% rename from Source/Cmlx/fmt/doc/fmt.css rename to Source/Cxxmlx/fmt/doc/fmt.css diff --git a/Source/Cmlx/fmt/doc/fmt.js b/Source/Cxxmlx/fmt/doc/fmt.js similarity index 100% rename from Source/Cmlx/fmt/doc/fmt.js rename to Source/Cxxmlx/fmt/doc/fmt.js diff --git a/Source/Cmlx/fmt/doc/get-started.md b/Source/Cxxmlx/fmt/doc/get-started.md similarity index 100% rename from Source/Cmlx/fmt/doc/get-started.md rename to Source/Cxxmlx/fmt/doc/get-started.md diff --git a/Source/Cmlx/fmt/doc/index.md b/Source/Cxxmlx/fmt/doc/index.md similarity index 100% rename from Source/Cmlx/fmt/doc/index.md rename to Source/Cxxmlx/fmt/doc/index.md diff --git a/Source/Cmlx/fmt/doc/perf.svg b/Source/Cxxmlx/fmt/doc/perf.svg similarity index 100% rename from Source/Cmlx/fmt/doc/perf.svg rename to Source/Cxxmlx/fmt/doc/perf.svg diff --git a/Source/Cmlx/fmt/doc/python-license.txt b/Source/Cxxmlx/fmt/doc/python-license.txt similarity index 100% rename from Source/Cmlx/fmt/doc/python-license.txt rename to Source/Cxxmlx/fmt/doc/python-license.txt diff --git a/Source/Cmlx/fmt/doc/syntax.md b/Source/Cxxmlx/fmt/doc/syntax.md similarity index 100% rename from Source/Cmlx/fmt/doc/syntax.md rename to Source/Cxxmlx/fmt/doc/syntax.md diff --git a/Source/Cmlx/fmt/include/fmt/args.h b/Source/Cxxmlx/fmt/include/fmt/args.h similarity index 100% rename from Source/Cmlx/fmt/include/fmt/args.h rename to Source/Cxxmlx/fmt/include/fmt/args.h diff --git a/Source/Cmlx/fmt/include/fmt/base.h b/Source/Cxxmlx/fmt/include/fmt/base.h similarity index 100% rename from Source/Cmlx/fmt/include/fmt/base.h rename to Source/Cxxmlx/fmt/include/fmt/base.h diff --git a/Source/Cmlx/fmt/include/fmt/chrono.h b/Source/Cxxmlx/fmt/include/fmt/chrono.h similarity index 100% rename from Source/Cmlx/fmt/include/fmt/chrono.h rename to Source/Cxxmlx/fmt/include/fmt/chrono.h diff --git a/Source/Cmlx/fmt/include/fmt/color.h b/Source/Cxxmlx/fmt/include/fmt/color.h similarity index 100% rename from Source/Cmlx/fmt/include/fmt/color.h rename to Source/Cxxmlx/fmt/include/fmt/color.h diff --git a/Source/Cmlx/fmt/include/fmt/compile.h b/Source/Cxxmlx/fmt/include/fmt/compile.h similarity index 100% rename from Source/Cmlx/fmt/include/fmt/compile.h rename to Source/Cxxmlx/fmt/include/fmt/compile.h diff --git a/Source/Cmlx/fmt/include/fmt/core.h b/Source/Cxxmlx/fmt/include/fmt/core.h similarity index 100% rename from Source/Cmlx/fmt/include/fmt/core.h rename to Source/Cxxmlx/fmt/include/fmt/core.h diff --git a/Source/Cmlx/fmt/include/fmt/format-inl.h b/Source/Cxxmlx/fmt/include/fmt/format-inl.h similarity index 100% rename from Source/Cmlx/fmt/include/fmt/format-inl.h rename to Source/Cxxmlx/fmt/include/fmt/format-inl.h diff --git a/Source/Cmlx/fmt/include/fmt/format.h b/Source/Cxxmlx/fmt/include/fmt/format.h similarity index 100% rename from Source/Cmlx/fmt/include/fmt/format.h rename to Source/Cxxmlx/fmt/include/fmt/format.h diff --git a/Source/Cmlx/fmt/include/fmt/os.h b/Source/Cxxmlx/fmt/include/fmt/os.h similarity index 100% rename from Source/Cmlx/fmt/include/fmt/os.h rename to Source/Cxxmlx/fmt/include/fmt/os.h diff --git a/Source/Cmlx/fmt/include/fmt/ostream.h b/Source/Cxxmlx/fmt/include/fmt/ostream.h similarity index 100% rename from Source/Cmlx/fmt/include/fmt/ostream.h rename to Source/Cxxmlx/fmt/include/fmt/ostream.h diff --git a/Source/Cmlx/fmt/include/fmt/printf.h b/Source/Cxxmlx/fmt/include/fmt/printf.h similarity index 100% rename from Source/Cmlx/fmt/include/fmt/printf.h rename to Source/Cxxmlx/fmt/include/fmt/printf.h diff --git a/Source/Cmlx/fmt/include/fmt/ranges.h b/Source/Cxxmlx/fmt/include/fmt/ranges.h similarity index 100% rename from Source/Cmlx/fmt/include/fmt/ranges.h rename to Source/Cxxmlx/fmt/include/fmt/ranges.h diff --git a/Source/Cmlx/fmt/include/fmt/std.h b/Source/Cxxmlx/fmt/include/fmt/std.h similarity index 100% rename from Source/Cmlx/fmt/include/fmt/std.h rename to Source/Cxxmlx/fmt/include/fmt/std.h diff --git a/Source/Cmlx/fmt/include/fmt/xchar.h b/Source/Cxxmlx/fmt/include/fmt/xchar.h similarity index 100% rename from Source/Cmlx/fmt/include/fmt/xchar.h rename to Source/Cxxmlx/fmt/include/fmt/xchar.h diff --git a/Source/Cmlx/fmt/src/fmt.cc b/Source/Cxxmlx/fmt/src/fmt.cc similarity index 100% rename from Source/Cmlx/fmt/src/fmt.cc rename to Source/Cxxmlx/fmt/src/fmt.cc diff --git a/Source/Cmlx/fmt/src/format.cc b/Source/Cxxmlx/fmt/src/format.cc similarity index 100% rename from Source/Cmlx/fmt/src/format.cc rename to Source/Cxxmlx/fmt/src/format.cc diff --git a/Source/Cmlx/fmt/src/os.cc b/Source/Cxxmlx/fmt/src/os.cc similarity index 100% rename from Source/Cmlx/fmt/src/os.cc rename to Source/Cxxmlx/fmt/src/os.cc diff --git a/Source/Cmlx/fmt/support/Android.mk b/Source/Cxxmlx/fmt/support/Android.mk similarity index 100% rename from Source/Cmlx/fmt/support/Android.mk rename to Source/Cxxmlx/fmt/support/Android.mk diff --git a/Source/Cmlx/fmt/support/AndroidManifest.xml b/Source/Cxxmlx/fmt/support/AndroidManifest.xml similarity index 100% rename from Source/Cmlx/fmt/support/AndroidManifest.xml rename to Source/Cxxmlx/fmt/support/AndroidManifest.xml diff --git a/Source/Cmlx/fmt/support/C++.sublime-syntax b/Source/Cxxmlx/fmt/support/C++.sublime-syntax similarity index 100% rename from Source/Cmlx/fmt/support/C++.sublime-syntax rename to Source/Cxxmlx/fmt/support/C++.sublime-syntax diff --git a/Source/Cmlx/fmt/support/README b/Source/Cxxmlx/fmt/support/README similarity index 100% rename from Source/Cmlx/fmt/support/README rename to Source/Cxxmlx/fmt/support/README diff --git a/Source/Cmlx/fmt/support/Vagrantfile b/Source/Cxxmlx/fmt/support/Vagrantfile similarity index 100% rename from Source/Cmlx/fmt/support/Vagrantfile rename to Source/Cxxmlx/fmt/support/Vagrantfile diff --git a/Source/Cmlx/fmt/support/build.gradle b/Source/Cxxmlx/fmt/support/build.gradle similarity index 100% rename from Source/Cmlx/fmt/support/build.gradle rename to Source/Cxxmlx/fmt/support/build.gradle diff --git a/Source/Cmlx/fmt/support/check-commits b/Source/Cxxmlx/fmt/support/check-commits similarity index 100% rename from Source/Cmlx/fmt/support/check-commits rename to Source/Cxxmlx/fmt/support/check-commits diff --git a/Source/Cmlx/fmt/support/cmake/FindSetEnv.cmake b/Source/Cxxmlx/fmt/support/cmake/FindSetEnv.cmake similarity index 100% rename from Source/Cmlx/fmt/support/cmake/FindSetEnv.cmake rename to Source/Cxxmlx/fmt/support/cmake/FindSetEnv.cmake diff --git a/Source/Cmlx/fmt/support/cmake/JoinPaths.cmake b/Source/Cxxmlx/fmt/support/cmake/JoinPaths.cmake similarity index 100% rename from Source/Cmlx/fmt/support/cmake/JoinPaths.cmake rename to Source/Cxxmlx/fmt/support/cmake/JoinPaths.cmake diff --git a/Source/Cmlx/fmt/support/cmake/fmt-config.cmake.in b/Source/Cxxmlx/fmt/support/cmake/fmt-config.cmake.in similarity index 100% rename from Source/Cmlx/fmt/support/cmake/fmt-config.cmake.in rename to Source/Cxxmlx/fmt/support/cmake/fmt-config.cmake.in diff --git a/Source/Cmlx/fmt/support/cmake/fmt.pc.in b/Source/Cxxmlx/fmt/support/cmake/fmt.pc.in similarity index 100% rename from Source/Cmlx/fmt/support/cmake/fmt.pc.in rename to Source/Cxxmlx/fmt/support/cmake/fmt.pc.in diff --git a/Source/Cmlx/fmt/support/docopt.py b/Source/Cxxmlx/fmt/support/docopt.py similarity index 100% rename from Source/Cmlx/fmt/support/docopt.py rename to Source/Cxxmlx/fmt/support/docopt.py diff --git a/Source/Cmlx/fmt/support/mkdocs b/Source/Cxxmlx/fmt/support/mkdocs similarity index 100% rename from Source/Cmlx/fmt/support/mkdocs rename to Source/Cxxmlx/fmt/support/mkdocs diff --git a/Source/Cmlx/fmt/support/mkdocs.yml b/Source/Cxxmlx/fmt/support/mkdocs.yml similarity index 100% rename from Source/Cmlx/fmt/support/mkdocs.yml rename to Source/Cxxmlx/fmt/support/mkdocs.yml diff --git a/Source/Cmlx/fmt/support/printable.py b/Source/Cxxmlx/fmt/support/printable.py similarity index 100% rename from Source/Cmlx/fmt/support/printable.py rename to Source/Cxxmlx/fmt/support/printable.py diff --git a/Source/Cmlx/fmt/support/python/mkdocstrings_handlers/cxx/__init__.py b/Source/Cxxmlx/fmt/support/python/mkdocstrings_handlers/cxx/__init__.py similarity index 100% rename from Source/Cmlx/fmt/support/python/mkdocstrings_handlers/cxx/__init__.py rename to Source/Cxxmlx/fmt/support/python/mkdocstrings_handlers/cxx/__init__.py diff --git a/Source/Cmlx/fmt/support/python/mkdocstrings_handlers/cxx/templates/README b/Source/Cxxmlx/fmt/support/python/mkdocstrings_handlers/cxx/templates/README similarity index 100% rename from Source/Cmlx/fmt/support/python/mkdocstrings_handlers/cxx/templates/README rename to Source/Cxxmlx/fmt/support/python/mkdocstrings_handlers/cxx/templates/README diff --git a/Source/Cmlx/fmt/support/release.py b/Source/Cxxmlx/fmt/support/release.py similarity index 100% rename from Source/Cmlx/fmt/support/release.py rename to Source/Cxxmlx/fmt/support/release.py diff --git a/Source/Cmlx/fmt/test/CMakeLists.txt b/Source/Cxxmlx/fmt/test/CMakeLists.txt similarity index 100% rename from Source/Cmlx/fmt/test/CMakeLists.txt rename to Source/Cxxmlx/fmt/test/CMakeLists.txt diff --git a/Source/Cmlx/fmt/test/add-subdirectory-test/CMakeLists.txt b/Source/Cxxmlx/fmt/test/add-subdirectory-test/CMakeLists.txt similarity index 100% rename from Source/Cmlx/fmt/test/add-subdirectory-test/CMakeLists.txt rename to Source/Cxxmlx/fmt/test/add-subdirectory-test/CMakeLists.txt diff --git a/Source/Cmlx/fmt/test/add-subdirectory-test/main.cc b/Source/Cxxmlx/fmt/test/add-subdirectory-test/main.cc similarity index 100% rename from Source/Cmlx/fmt/test/add-subdirectory-test/main.cc rename to Source/Cxxmlx/fmt/test/add-subdirectory-test/main.cc diff --git a/Source/Cmlx/fmt/test/args-test.cc b/Source/Cxxmlx/fmt/test/args-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/args-test.cc rename to Source/Cxxmlx/fmt/test/args-test.cc diff --git a/Source/Cmlx/fmt/test/assert-test.cc b/Source/Cxxmlx/fmt/test/assert-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/assert-test.cc rename to Source/Cxxmlx/fmt/test/assert-test.cc diff --git a/Source/Cmlx/fmt/test/base-test.cc b/Source/Cxxmlx/fmt/test/base-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/base-test.cc rename to Source/Cxxmlx/fmt/test/base-test.cc diff --git a/Source/Cmlx/fmt/test/chrono-test.cc b/Source/Cxxmlx/fmt/test/chrono-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/chrono-test.cc rename to Source/Cxxmlx/fmt/test/chrono-test.cc diff --git a/Source/Cmlx/fmt/test/color-test.cc b/Source/Cxxmlx/fmt/test/color-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/color-test.cc rename to Source/Cxxmlx/fmt/test/color-test.cc diff --git a/Source/Cmlx/fmt/test/compile-error-test/CMakeLists.txt b/Source/Cxxmlx/fmt/test/compile-error-test/CMakeLists.txt similarity index 100% rename from Source/Cmlx/fmt/test/compile-error-test/CMakeLists.txt rename to Source/Cxxmlx/fmt/test/compile-error-test/CMakeLists.txt diff --git a/Source/Cmlx/fmt/test/compile-fp-test.cc b/Source/Cxxmlx/fmt/test/compile-fp-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/compile-fp-test.cc rename to Source/Cxxmlx/fmt/test/compile-fp-test.cc diff --git a/Source/Cmlx/fmt/test/compile-test.cc b/Source/Cxxmlx/fmt/test/compile-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/compile-test.cc rename to Source/Cxxmlx/fmt/test/compile-test.cc diff --git a/Source/Cmlx/fmt/test/cuda-test/CMakeLists.txt b/Source/Cxxmlx/fmt/test/cuda-test/CMakeLists.txt similarity index 100% rename from Source/Cmlx/fmt/test/cuda-test/CMakeLists.txt rename to Source/Cxxmlx/fmt/test/cuda-test/CMakeLists.txt diff --git a/Source/Cmlx/fmt/test/cuda-test/cpp14.cc b/Source/Cxxmlx/fmt/test/cuda-test/cpp14.cc similarity index 100% rename from Source/Cmlx/fmt/test/cuda-test/cpp14.cc rename to Source/Cxxmlx/fmt/test/cuda-test/cpp14.cc diff --git a/Source/Cmlx/fmt/test/cuda-test/cuda-cpp14.cu b/Source/Cxxmlx/fmt/test/cuda-test/cuda-cpp14.cu similarity index 100% rename from Source/Cmlx/fmt/test/cuda-test/cuda-cpp14.cu rename to Source/Cxxmlx/fmt/test/cuda-test/cuda-cpp14.cu diff --git a/Source/Cmlx/fmt/test/detect-stdfs.cc b/Source/Cxxmlx/fmt/test/detect-stdfs.cc similarity index 100% rename from Source/Cmlx/fmt/test/detect-stdfs.cc rename to Source/Cxxmlx/fmt/test/detect-stdfs.cc diff --git a/Source/Cmlx/fmt/test/enforce-checks-test.cc b/Source/Cxxmlx/fmt/test/enforce-checks-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/enforce-checks-test.cc rename to Source/Cxxmlx/fmt/test/enforce-checks-test.cc diff --git a/Source/Cmlx/fmt/test/find-package-test/CMakeLists.txt b/Source/Cxxmlx/fmt/test/find-package-test/CMakeLists.txt similarity index 100% rename from Source/Cmlx/fmt/test/find-package-test/CMakeLists.txt rename to Source/Cxxmlx/fmt/test/find-package-test/CMakeLists.txt diff --git a/Source/Cmlx/fmt/test/find-package-test/main.cc b/Source/Cxxmlx/fmt/test/find-package-test/main.cc similarity index 100% rename from Source/Cmlx/fmt/test/find-package-test/main.cc rename to Source/Cxxmlx/fmt/test/find-package-test/main.cc diff --git a/Source/Cmlx/fmt/test/format-impl-test.cc b/Source/Cxxmlx/fmt/test/format-impl-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/format-impl-test.cc rename to Source/Cxxmlx/fmt/test/format-impl-test.cc diff --git a/Source/Cmlx/fmt/test/format-test.cc b/Source/Cxxmlx/fmt/test/format-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/format-test.cc rename to Source/Cxxmlx/fmt/test/format-test.cc diff --git a/Source/Cmlx/fmt/test/fuzzing/.gitignore b/Source/Cxxmlx/fmt/test/fuzzing/.gitignore similarity index 100% rename from Source/Cmlx/fmt/test/fuzzing/.gitignore rename to Source/Cxxmlx/fmt/test/fuzzing/.gitignore diff --git a/Source/Cmlx/fmt/test/fuzzing/CMakeLists.txt b/Source/Cxxmlx/fmt/test/fuzzing/CMakeLists.txt similarity index 100% rename from Source/Cmlx/fmt/test/fuzzing/CMakeLists.txt rename to Source/Cxxmlx/fmt/test/fuzzing/CMakeLists.txt diff --git a/Source/Cmlx/fmt/test/fuzzing/README.md b/Source/Cxxmlx/fmt/test/fuzzing/README.md similarity index 100% rename from Source/Cmlx/fmt/test/fuzzing/README.md rename to Source/Cxxmlx/fmt/test/fuzzing/README.md diff --git a/Source/Cmlx/fmt/test/fuzzing/build.sh b/Source/Cxxmlx/fmt/test/fuzzing/build.sh similarity index 100% rename from Source/Cmlx/fmt/test/fuzzing/build.sh rename to Source/Cxxmlx/fmt/test/fuzzing/build.sh diff --git a/Source/Cmlx/fmt/test/fuzzing/chrono-duration.cc b/Source/Cxxmlx/fmt/test/fuzzing/chrono-duration.cc similarity index 100% rename from Source/Cmlx/fmt/test/fuzzing/chrono-duration.cc rename to Source/Cxxmlx/fmt/test/fuzzing/chrono-duration.cc diff --git a/Source/Cmlx/fmt/test/fuzzing/chrono-timepoint.cc b/Source/Cxxmlx/fmt/test/fuzzing/chrono-timepoint.cc similarity index 100% rename from Source/Cmlx/fmt/test/fuzzing/chrono-timepoint.cc rename to Source/Cxxmlx/fmt/test/fuzzing/chrono-timepoint.cc diff --git a/Source/Cmlx/fmt/test/fuzzing/float.cc b/Source/Cxxmlx/fmt/test/fuzzing/float.cc similarity index 100% rename from Source/Cmlx/fmt/test/fuzzing/float.cc rename to Source/Cxxmlx/fmt/test/fuzzing/float.cc diff --git a/Source/Cmlx/fmt/test/fuzzing/fuzzer-common.h b/Source/Cxxmlx/fmt/test/fuzzing/fuzzer-common.h similarity index 100% rename from Source/Cmlx/fmt/test/fuzzing/fuzzer-common.h rename to Source/Cxxmlx/fmt/test/fuzzing/fuzzer-common.h diff --git a/Source/Cmlx/fmt/test/fuzzing/main.cc b/Source/Cxxmlx/fmt/test/fuzzing/main.cc similarity index 100% rename from Source/Cmlx/fmt/test/fuzzing/main.cc rename to Source/Cxxmlx/fmt/test/fuzzing/main.cc diff --git a/Source/Cmlx/fmt/test/fuzzing/named-arg.cc b/Source/Cxxmlx/fmt/test/fuzzing/named-arg.cc similarity index 100% rename from Source/Cmlx/fmt/test/fuzzing/named-arg.cc rename to Source/Cxxmlx/fmt/test/fuzzing/named-arg.cc diff --git a/Source/Cmlx/fmt/test/fuzzing/one-arg.cc b/Source/Cxxmlx/fmt/test/fuzzing/one-arg.cc similarity index 100% rename from Source/Cmlx/fmt/test/fuzzing/one-arg.cc rename to Source/Cxxmlx/fmt/test/fuzzing/one-arg.cc diff --git a/Source/Cmlx/fmt/test/fuzzing/two-args.cc b/Source/Cxxmlx/fmt/test/fuzzing/two-args.cc similarity index 100% rename from Source/Cmlx/fmt/test/fuzzing/two-args.cc rename to Source/Cxxmlx/fmt/test/fuzzing/two-args.cc diff --git a/Source/Cmlx/fmt/test/gtest-extra-test.cc b/Source/Cxxmlx/fmt/test/gtest-extra-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/gtest-extra-test.cc rename to Source/Cxxmlx/fmt/test/gtest-extra-test.cc diff --git a/Source/Cmlx/fmt/test/gtest-extra.cc b/Source/Cxxmlx/fmt/test/gtest-extra.cc similarity index 100% rename from Source/Cmlx/fmt/test/gtest-extra.cc rename to Source/Cxxmlx/fmt/test/gtest-extra.cc diff --git a/Source/Cmlx/fmt/test/gtest-extra.h b/Source/Cxxmlx/fmt/test/gtest-extra.h similarity index 100% rename from Source/Cmlx/fmt/test/gtest-extra.h rename to Source/Cxxmlx/fmt/test/gtest-extra.h diff --git a/Source/Cmlx/fmt/test/gtest/.clang-format b/Source/Cxxmlx/fmt/test/gtest/.clang-format similarity index 100% rename from Source/Cmlx/fmt/test/gtest/.clang-format rename to Source/Cxxmlx/fmt/test/gtest/.clang-format diff --git a/Source/Cmlx/fmt/test/gtest/CMakeLists.txt b/Source/Cxxmlx/fmt/test/gtest/CMakeLists.txt similarity index 100% rename from Source/Cmlx/fmt/test/gtest/CMakeLists.txt rename to Source/Cxxmlx/fmt/test/gtest/CMakeLists.txt diff --git a/Source/Cmlx/fmt/test/gtest/gmock-gtest-all.cc b/Source/Cxxmlx/fmt/test/gtest/gmock-gtest-all.cc similarity index 100% rename from Source/Cmlx/fmt/test/gtest/gmock-gtest-all.cc rename to Source/Cxxmlx/fmt/test/gtest/gmock-gtest-all.cc diff --git a/Source/Cmlx/fmt/test/gtest/gmock/gmock.h b/Source/Cxxmlx/fmt/test/gtest/gmock/gmock.h similarity index 100% rename from Source/Cmlx/fmt/test/gtest/gmock/gmock.h rename to Source/Cxxmlx/fmt/test/gtest/gmock/gmock.h diff --git a/Source/Cmlx/fmt/test/gtest/gtest/gtest-spi.h b/Source/Cxxmlx/fmt/test/gtest/gtest/gtest-spi.h similarity index 100% rename from Source/Cmlx/fmt/test/gtest/gtest/gtest-spi.h rename to Source/Cxxmlx/fmt/test/gtest/gtest/gtest-spi.h diff --git a/Source/Cmlx/fmt/test/gtest/gtest/gtest.h b/Source/Cxxmlx/fmt/test/gtest/gtest/gtest.h similarity index 100% rename from Source/Cmlx/fmt/test/gtest/gtest/gtest.h rename to Source/Cxxmlx/fmt/test/gtest/gtest/gtest.h diff --git a/Source/Cmlx/fmt/test/header-only-test.cc b/Source/Cxxmlx/fmt/test/header-only-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/header-only-test.cc rename to Source/Cxxmlx/fmt/test/header-only-test.cc diff --git a/Source/Cmlx/fmt/test/mock-allocator.h b/Source/Cxxmlx/fmt/test/mock-allocator.h similarity index 100% rename from Source/Cmlx/fmt/test/mock-allocator.h rename to Source/Cxxmlx/fmt/test/mock-allocator.h diff --git a/Source/Cmlx/fmt/test/module-test.cc b/Source/Cxxmlx/fmt/test/module-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/module-test.cc rename to Source/Cxxmlx/fmt/test/module-test.cc diff --git a/Source/Cmlx/fmt/test/no-builtin-types-test.cc b/Source/Cxxmlx/fmt/test/no-builtin-types-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/no-builtin-types-test.cc rename to Source/Cxxmlx/fmt/test/no-builtin-types-test.cc diff --git a/Source/Cmlx/fmt/test/noexception-test.cc b/Source/Cxxmlx/fmt/test/noexception-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/noexception-test.cc rename to Source/Cxxmlx/fmt/test/noexception-test.cc diff --git a/Source/Cmlx/fmt/test/os-test.cc b/Source/Cxxmlx/fmt/test/os-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/os-test.cc rename to Source/Cxxmlx/fmt/test/os-test.cc diff --git a/Source/Cmlx/fmt/test/ostream-test.cc b/Source/Cxxmlx/fmt/test/ostream-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/ostream-test.cc rename to Source/Cxxmlx/fmt/test/ostream-test.cc diff --git a/Source/Cmlx/fmt/test/perf-sanity.cc b/Source/Cxxmlx/fmt/test/perf-sanity.cc similarity index 100% rename from Source/Cmlx/fmt/test/perf-sanity.cc rename to Source/Cxxmlx/fmt/test/perf-sanity.cc diff --git a/Source/Cmlx/fmt/test/posix-mock-test.cc b/Source/Cxxmlx/fmt/test/posix-mock-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/posix-mock-test.cc rename to Source/Cxxmlx/fmt/test/posix-mock-test.cc diff --git a/Source/Cmlx/fmt/test/posix-mock.h b/Source/Cxxmlx/fmt/test/posix-mock.h similarity index 100% rename from Source/Cmlx/fmt/test/posix-mock.h rename to Source/Cxxmlx/fmt/test/posix-mock.h diff --git a/Source/Cmlx/fmt/test/printf-test.cc b/Source/Cxxmlx/fmt/test/printf-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/printf-test.cc rename to Source/Cxxmlx/fmt/test/printf-test.cc diff --git a/Source/Cmlx/fmt/test/ranges-odr-test.cc b/Source/Cxxmlx/fmt/test/ranges-odr-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/ranges-odr-test.cc rename to Source/Cxxmlx/fmt/test/ranges-odr-test.cc diff --git a/Source/Cmlx/fmt/test/ranges-test.cc b/Source/Cxxmlx/fmt/test/ranges-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/ranges-test.cc rename to Source/Cxxmlx/fmt/test/ranges-test.cc diff --git a/Source/Cmlx/fmt/test/scan-test.cc b/Source/Cxxmlx/fmt/test/scan-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/scan-test.cc rename to Source/Cxxmlx/fmt/test/scan-test.cc diff --git a/Source/Cmlx/fmt/test/scan.h b/Source/Cxxmlx/fmt/test/scan.h similarity index 100% rename from Source/Cmlx/fmt/test/scan.h rename to Source/Cxxmlx/fmt/test/scan.h diff --git a/Source/Cmlx/fmt/test/static-export-test/CMakeLists.txt b/Source/Cxxmlx/fmt/test/static-export-test/CMakeLists.txt similarity index 100% rename from Source/Cmlx/fmt/test/static-export-test/CMakeLists.txt rename to Source/Cxxmlx/fmt/test/static-export-test/CMakeLists.txt diff --git a/Source/Cmlx/fmt/test/static-export-test/library.cc b/Source/Cxxmlx/fmt/test/static-export-test/library.cc similarity index 100% rename from Source/Cmlx/fmt/test/static-export-test/library.cc rename to Source/Cxxmlx/fmt/test/static-export-test/library.cc diff --git a/Source/Cmlx/fmt/test/static-export-test/main.cc b/Source/Cxxmlx/fmt/test/static-export-test/main.cc similarity index 100% rename from Source/Cmlx/fmt/test/static-export-test/main.cc rename to Source/Cxxmlx/fmt/test/static-export-test/main.cc diff --git a/Source/Cmlx/fmt/test/std-test.cc b/Source/Cxxmlx/fmt/test/std-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/std-test.cc rename to Source/Cxxmlx/fmt/test/std-test.cc diff --git a/Source/Cmlx/fmt/test/test-assert.h b/Source/Cxxmlx/fmt/test/test-assert.h similarity index 100% rename from Source/Cmlx/fmt/test/test-assert.h rename to Source/Cxxmlx/fmt/test/test-assert.h diff --git a/Source/Cmlx/fmt/test/test-main.cc b/Source/Cxxmlx/fmt/test/test-main.cc similarity index 100% rename from Source/Cmlx/fmt/test/test-main.cc rename to Source/Cxxmlx/fmt/test/test-main.cc diff --git a/Source/Cmlx/fmt/test/unicode-test.cc b/Source/Cxxmlx/fmt/test/unicode-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/unicode-test.cc rename to Source/Cxxmlx/fmt/test/unicode-test.cc diff --git a/Source/Cmlx/fmt/test/util.cc b/Source/Cxxmlx/fmt/test/util.cc similarity index 100% rename from Source/Cmlx/fmt/test/util.cc rename to Source/Cxxmlx/fmt/test/util.cc diff --git a/Source/Cmlx/fmt/test/util.h b/Source/Cxxmlx/fmt/test/util.h similarity index 100% rename from Source/Cmlx/fmt/test/util.h rename to Source/Cxxmlx/fmt/test/util.h diff --git a/Source/Cmlx/fmt/test/xchar-test.cc b/Source/Cxxmlx/fmt/test/xchar-test.cc similarity index 100% rename from Source/Cmlx/fmt/test/xchar-test.cc rename to Source/Cxxmlx/fmt/test/xchar-test.cc diff --git a/Source/Cxxmlx/include/Cxxmlx.h b/Source/Cxxmlx/include/Cxxmlx.h new file mode 100644 index 00000000..601056ff --- /dev/null +++ b/Source/Cxxmlx/include/Cxxmlx.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__APPLE__) +#include +#endif diff --git a/Source/Cmlx/metal-cpp/Foundation/Foundation.hpp b/Source/Cxxmlx/include/Foundation/Foundation.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/Foundation.hpp rename to Source/Cxxmlx/include/Foundation/Foundation.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSArray.hpp b/Source/Cxxmlx/include/Foundation/NSArray.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSArray.hpp rename to Source/Cxxmlx/include/Foundation/NSArray.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSAutoreleasePool.hpp b/Source/Cxxmlx/include/Foundation/NSAutoreleasePool.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSAutoreleasePool.hpp rename to Source/Cxxmlx/include/Foundation/NSAutoreleasePool.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSBundle.hpp b/Source/Cxxmlx/include/Foundation/NSBundle.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSBundle.hpp rename to Source/Cxxmlx/include/Foundation/NSBundle.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSData.hpp b/Source/Cxxmlx/include/Foundation/NSData.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSData.hpp rename to Source/Cxxmlx/include/Foundation/NSData.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSDate.hpp b/Source/Cxxmlx/include/Foundation/NSDate.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSDate.hpp rename to Source/Cxxmlx/include/Foundation/NSDate.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSDefines.hpp b/Source/Cxxmlx/include/Foundation/NSDefines.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSDefines.hpp rename to Source/Cxxmlx/include/Foundation/NSDefines.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSDictionary.hpp b/Source/Cxxmlx/include/Foundation/NSDictionary.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSDictionary.hpp rename to Source/Cxxmlx/include/Foundation/NSDictionary.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSEnumerator.hpp b/Source/Cxxmlx/include/Foundation/NSEnumerator.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSEnumerator.hpp rename to Source/Cxxmlx/include/Foundation/NSEnumerator.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSError.hpp b/Source/Cxxmlx/include/Foundation/NSError.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSError.hpp rename to Source/Cxxmlx/include/Foundation/NSError.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSLock.hpp b/Source/Cxxmlx/include/Foundation/NSLock.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSLock.hpp rename to Source/Cxxmlx/include/Foundation/NSLock.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSNotification.hpp b/Source/Cxxmlx/include/Foundation/NSNotification.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSNotification.hpp rename to Source/Cxxmlx/include/Foundation/NSNotification.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSNumber.hpp b/Source/Cxxmlx/include/Foundation/NSNumber.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSNumber.hpp rename to Source/Cxxmlx/include/Foundation/NSNumber.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSObjCRuntime.hpp b/Source/Cxxmlx/include/Foundation/NSObjCRuntime.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSObjCRuntime.hpp rename to Source/Cxxmlx/include/Foundation/NSObjCRuntime.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSObject.hpp b/Source/Cxxmlx/include/Foundation/NSObject.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSObject.hpp rename to Source/Cxxmlx/include/Foundation/NSObject.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSPrivate.hpp b/Source/Cxxmlx/include/Foundation/NSPrivate.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSPrivate.hpp rename to Source/Cxxmlx/include/Foundation/NSPrivate.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSProcessInfo.hpp b/Source/Cxxmlx/include/Foundation/NSProcessInfo.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSProcessInfo.hpp rename to Source/Cxxmlx/include/Foundation/NSProcessInfo.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSRange.hpp b/Source/Cxxmlx/include/Foundation/NSRange.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSRange.hpp rename to Source/Cxxmlx/include/Foundation/NSRange.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSSet.hpp b/Source/Cxxmlx/include/Foundation/NSSet.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSSet.hpp rename to Source/Cxxmlx/include/Foundation/NSSet.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSSharedPtr.hpp b/Source/Cxxmlx/include/Foundation/NSSharedPtr.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSSharedPtr.hpp rename to Source/Cxxmlx/include/Foundation/NSSharedPtr.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSString.hpp b/Source/Cxxmlx/include/Foundation/NSString.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSString.hpp rename to Source/Cxxmlx/include/Foundation/NSString.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSTypes.hpp b/Source/Cxxmlx/include/Foundation/NSTypes.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSTypes.hpp rename to Source/Cxxmlx/include/Foundation/NSTypes.hpp diff --git a/Source/Cmlx/metal-cpp/Foundation/NSURL.hpp b/Source/Cxxmlx/include/Foundation/NSURL.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Foundation/NSURL.hpp rename to Source/Cxxmlx/include/Foundation/NSURL.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4AccelerationStructure.hpp b/Source/Cxxmlx/include/Metal/MTL4AccelerationStructure.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4AccelerationStructure.hpp rename to Source/Cxxmlx/include/Metal/MTL4AccelerationStructure.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4Archive.hpp b/Source/Cxxmlx/include/Metal/MTL4Archive.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4Archive.hpp rename to Source/Cxxmlx/include/Metal/MTL4Archive.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4ArgumentTable.hpp b/Source/Cxxmlx/include/Metal/MTL4ArgumentTable.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4ArgumentTable.hpp rename to Source/Cxxmlx/include/Metal/MTL4ArgumentTable.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4BinaryFunction.hpp b/Source/Cxxmlx/include/Metal/MTL4BinaryFunction.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4BinaryFunction.hpp rename to Source/Cxxmlx/include/Metal/MTL4BinaryFunction.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4BinaryFunctionDescriptor.hpp b/Source/Cxxmlx/include/Metal/MTL4BinaryFunctionDescriptor.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4BinaryFunctionDescriptor.hpp rename to Source/Cxxmlx/include/Metal/MTL4BinaryFunctionDescriptor.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4CommandAllocator.hpp b/Source/Cxxmlx/include/Metal/MTL4CommandAllocator.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4CommandAllocator.hpp rename to Source/Cxxmlx/include/Metal/MTL4CommandAllocator.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4CommandBuffer.hpp b/Source/Cxxmlx/include/Metal/MTL4CommandBuffer.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4CommandBuffer.hpp rename to Source/Cxxmlx/include/Metal/MTL4CommandBuffer.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4CommandEncoder.hpp b/Source/Cxxmlx/include/Metal/MTL4CommandEncoder.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4CommandEncoder.hpp rename to Source/Cxxmlx/include/Metal/MTL4CommandEncoder.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4CommandQueue.hpp b/Source/Cxxmlx/include/Metal/MTL4CommandQueue.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4CommandQueue.hpp rename to Source/Cxxmlx/include/Metal/MTL4CommandQueue.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4CommitFeedback.hpp b/Source/Cxxmlx/include/Metal/MTL4CommitFeedback.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4CommitFeedback.hpp rename to Source/Cxxmlx/include/Metal/MTL4CommitFeedback.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4Compiler.hpp b/Source/Cxxmlx/include/Metal/MTL4Compiler.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4Compiler.hpp rename to Source/Cxxmlx/include/Metal/MTL4Compiler.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4CompilerTask.hpp b/Source/Cxxmlx/include/Metal/MTL4CompilerTask.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4CompilerTask.hpp rename to Source/Cxxmlx/include/Metal/MTL4CompilerTask.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4ComputeCommandEncoder.hpp b/Source/Cxxmlx/include/Metal/MTL4ComputeCommandEncoder.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4ComputeCommandEncoder.hpp rename to Source/Cxxmlx/include/Metal/MTL4ComputeCommandEncoder.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4ComputePipeline.hpp b/Source/Cxxmlx/include/Metal/MTL4ComputePipeline.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4ComputePipeline.hpp rename to Source/Cxxmlx/include/Metal/MTL4ComputePipeline.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4Counters.hpp b/Source/Cxxmlx/include/Metal/MTL4Counters.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4Counters.hpp rename to Source/Cxxmlx/include/Metal/MTL4Counters.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4FunctionDescriptor.hpp b/Source/Cxxmlx/include/Metal/MTL4FunctionDescriptor.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4FunctionDescriptor.hpp rename to Source/Cxxmlx/include/Metal/MTL4FunctionDescriptor.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4LibraryDescriptor.hpp b/Source/Cxxmlx/include/Metal/MTL4LibraryDescriptor.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4LibraryDescriptor.hpp rename to Source/Cxxmlx/include/Metal/MTL4LibraryDescriptor.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4LibraryFunctionDescriptor.hpp b/Source/Cxxmlx/include/Metal/MTL4LibraryFunctionDescriptor.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4LibraryFunctionDescriptor.hpp rename to Source/Cxxmlx/include/Metal/MTL4LibraryFunctionDescriptor.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4LinkingDescriptor.hpp b/Source/Cxxmlx/include/Metal/MTL4LinkingDescriptor.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4LinkingDescriptor.hpp rename to Source/Cxxmlx/include/Metal/MTL4LinkingDescriptor.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4MachineLearningCommandEncoder.hpp b/Source/Cxxmlx/include/Metal/MTL4MachineLearningCommandEncoder.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4MachineLearningCommandEncoder.hpp rename to Source/Cxxmlx/include/Metal/MTL4MachineLearningCommandEncoder.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4MachineLearningPipeline.hpp b/Source/Cxxmlx/include/Metal/MTL4MachineLearningPipeline.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4MachineLearningPipeline.hpp rename to Source/Cxxmlx/include/Metal/MTL4MachineLearningPipeline.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4MeshRenderPipeline.hpp b/Source/Cxxmlx/include/Metal/MTL4MeshRenderPipeline.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4MeshRenderPipeline.hpp rename to Source/Cxxmlx/include/Metal/MTL4MeshRenderPipeline.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4PipelineDataSetSerializer.hpp b/Source/Cxxmlx/include/Metal/MTL4PipelineDataSetSerializer.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4PipelineDataSetSerializer.hpp rename to Source/Cxxmlx/include/Metal/MTL4PipelineDataSetSerializer.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4PipelineState.hpp b/Source/Cxxmlx/include/Metal/MTL4PipelineState.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4PipelineState.hpp rename to Source/Cxxmlx/include/Metal/MTL4PipelineState.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4RenderCommandEncoder.hpp b/Source/Cxxmlx/include/Metal/MTL4RenderCommandEncoder.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4RenderCommandEncoder.hpp rename to Source/Cxxmlx/include/Metal/MTL4RenderCommandEncoder.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4RenderPass.hpp b/Source/Cxxmlx/include/Metal/MTL4RenderPass.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4RenderPass.hpp rename to Source/Cxxmlx/include/Metal/MTL4RenderPass.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4RenderPipeline.hpp b/Source/Cxxmlx/include/Metal/MTL4RenderPipeline.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4RenderPipeline.hpp rename to Source/Cxxmlx/include/Metal/MTL4RenderPipeline.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4SpecializedFunctionDescriptor.hpp b/Source/Cxxmlx/include/Metal/MTL4SpecializedFunctionDescriptor.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4SpecializedFunctionDescriptor.hpp rename to Source/Cxxmlx/include/Metal/MTL4SpecializedFunctionDescriptor.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4StitchedFunctionDescriptor.hpp b/Source/Cxxmlx/include/Metal/MTL4StitchedFunctionDescriptor.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4StitchedFunctionDescriptor.hpp rename to Source/Cxxmlx/include/Metal/MTL4StitchedFunctionDescriptor.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTL4TileRenderPipeline.hpp b/Source/Cxxmlx/include/Metal/MTL4TileRenderPipeline.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTL4TileRenderPipeline.hpp rename to Source/Cxxmlx/include/Metal/MTL4TileRenderPipeline.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLAccelerationStructure.hpp b/Source/Cxxmlx/include/Metal/MTLAccelerationStructure.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLAccelerationStructure.hpp rename to Source/Cxxmlx/include/Metal/MTLAccelerationStructure.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLAccelerationStructureCommandEncoder.hpp b/Source/Cxxmlx/include/Metal/MTLAccelerationStructureCommandEncoder.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLAccelerationStructureCommandEncoder.hpp rename to Source/Cxxmlx/include/Metal/MTLAccelerationStructureCommandEncoder.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLAccelerationStructureTypes.hpp b/Source/Cxxmlx/include/Metal/MTLAccelerationStructureTypes.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLAccelerationStructureTypes.hpp rename to Source/Cxxmlx/include/Metal/MTLAccelerationStructureTypes.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLAllocation.hpp b/Source/Cxxmlx/include/Metal/MTLAllocation.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLAllocation.hpp rename to Source/Cxxmlx/include/Metal/MTLAllocation.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLArgument.hpp b/Source/Cxxmlx/include/Metal/MTLArgument.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLArgument.hpp rename to Source/Cxxmlx/include/Metal/MTLArgument.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLArgumentEncoder.hpp b/Source/Cxxmlx/include/Metal/MTLArgumentEncoder.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLArgumentEncoder.hpp rename to Source/Cxxmlx/include/Metal/MTLArgumentEncoder.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLBinaryArchive.hpp b/Source/Cxxmlx/include/Metal/MTLBinaryArchive.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLBinaryArchive.hpp rename to Source/Cxxmlx/include/Metal/MTLBinaryArchive.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLBlitCommandEncoder.hpp b/Source/Cxxmlx/include/Metal/MTLBlitCommandEncoder.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLBlitCommandEncoder.hpp rename to Source/Cxxmlx/include/Metal/MTLBlitCommandEncoder.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLBlitPass.hpp b/Source/Cxxmlx/include/Metal/MTLBlitPass.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLBlitPass.hpp rename to Source/Cxxmlx/include/Metal/MTLBlitPass.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLBuffer.hpp b/Source/Cxxmlx/include/Metal/MTLBuffer.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLBuffer.hpp rename to Source/Cxxmlx/include/Metal/MTLBuffer.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLCaptureManager.hpp b/Source/Cxxmlx/include/Metal/MTLCaptureManager.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLCaptureManager.hpp rename to Source/Cxxmlx/include/Metal/MTLCaptureManager.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLCaptureScope.hpp b/Source/Cxxmlx/include/Metal/MTLCaptureScope.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLCaptureScope.hpp rename to Source/Cxxmlx/include/Metal/MTLCaptureScope.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLCommandBuffer.hpp b/Source/Cxxmlx/include/Metal/MTLCommandBuffer.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLCommandBuffer.hpp rename to Source/Cxxmlx/include/Metal/MTLCommandBuffer.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLCommandEncoder.hpp b/Source/Cxxmlx/include/Metal/MTLCommandEncoder.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLCommandEncoder.hpp rename to Source/Cxxmlx/include/Metal/MTLCommandEncoder.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLCommandQueue.hpp b/Source/Cxxmlx/include/Metal/MTLCommandQueue.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLCommandQueue.hpp rename to Source/Cxxmlx/include/Metal/MTLCommandQueue.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLComputeCommandEncoder.hpp b/Source/Cxxmlx/include/Metal/MTLComputeCommandEncoder.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLComputeCommandEncoder.hpp rename to Source/Cxxmlx/include/Metal/MTLComputeCommandEncoder.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLComputePass.hpp b/Source/Cxxmlx/include/Metal/MTLComputePass.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLComputePass.hpp rename to Source/Cxxmlx/include/Metal/MTLComputePass.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLComputePipeline.hpp b/Source/Cxxmlx/include/Metal/MTLComputePipeline.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLComputePipeline.hpp rename to Source/Cxxmlx/include/Metal/MTLComputePipeline.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLCounters.hpp b/Source/Cxxmlx/include/Metal/MTLCounters.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLCounters.hpp rename to Source/Cxxmlx/include/Metal/MTLCounters.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLDataType.hpp b/Source/Cxxmlx/include/Metal/MTLDataType.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLDataType.hpp rename to Source/Cxxmlx/include/Metal/MTLDataType.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLDefines.hpp b/Source/Cxxmlx/include/Metal/MTLDefines.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLDefines.hpp rename to Source/Cxxmlx/include/Metal/MTLDefines.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLDepthStencil.hpp b/Source/Cxxmlx/include/Metal/MTLDepthStencil.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLDepthStencil.hpp rename to Source/Cxxmlx/include/Metal/MTLDepthStencil.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLDevice.hpp b/Source/Cxxmlx/include/Metal/MTLDevice.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLDevice.hpp rename to Source/Cxxmlx/include/Metal/MTLDevice.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLDrawable.hpp b/Source/Cxxmlx/include/Metal/MTLDrawable.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLDrawable.hpp rename to Source/Cxxmlx/include/Metal/MTLDrawable.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLDynamicLibrary.hpp b/Source/Cxxmlx/include/Metal/MTLDynamicLibrary.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLDynamicLibrary.hpp rename to Source/Cxxmlx/include/Metal/MTLDynamicLibrary.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLEvent.hpp b/Source/Cxxmlx/include/Metal/MTLEvent.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLEvent.hpp rename to Source/Cxxmlx/include/Metal/MTLEvent.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLFence.hpp b/Source/Cxxmlx/include/Metal/MTLFence.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLFence.hpp rename to Source/Cxxmlx/include/Metal/MTLFence.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLFunctionConstantValues.hpp b/Source/Cxxmlx/include/Metal/MTLFunctionConstantValues.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLFunctionConstantValues.hpp rename to Source/Cxxmlx/include/Metal/MTLFunctionConstantValues.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLFunctionDescriptor.hpp b/Source/Cxxmlx/include/Metal/MTLFunctionDescriptor.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLFunctionDescriptor.hpp rename to Source/Cxxmlx/include/Metal/MTLFunctionDescriptor.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLFunctionHandle.hpp b/Source/Cxxmlx/include/Metal/MTLFunctionHandle.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLFunctionHandle.hpp rename to Source/Cxxmlx/include/Metal/MTLFunctionHandle.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLFunctionLog.hpp b/Source/Cxxmlx/include/Metal/MTLFunctionLog.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLFunctionLog.hpp rename to Source/Cxxmlx/include/Metal/MTLFunctionLog.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLFunctionStitching.hpp b/Source/Cxxmlx/include/Metal/MTLFunctionStitching.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLFunctionStitching.hpp rename to Source/Cxxmlx/include/Metal/MTLFunctionStitching.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLGPUAddress.hpp b/Source/Cxxmlx/include/Metal/MTLGPUAddress.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLGPUAddress.hpp rename to Source/Cxxmlx/include/Metal/MTLGPUAddress.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLHeaderBridge.hpp b/Source/Cxxmlx/include/Metal/MTLHeaderBridge.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLHeaderBridge.hpp rename to Source/Cxxmlx/include/Metal/MTLHeaderBridge.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLHeap.hpp b/Source/Cxxmlx/include/Metal/MTLHeap.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLHeap.hpp rename to Source/Cxxmlx/include/Metal/MTLHeap.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLIOCommandBuffer.hpp b/Source/Cxxmlx/include/Metal/MTLIOCommandBuffer.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLIOCommandBuffer.hpp rename to Source/Cxxmlx/include/Metal/MTLIOCommandBuffer.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLIOCommandQueue.hpp b/Source/Cxxmlx/include/Metal/MTLIOCommandQueue.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLIOCommandQueue.hpp rename to Source/Cxxmlx/include/Metal/MTLIOCommandQueue.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLIOCompressor.hpp b/Source/Cxxmlx/include/Metal/MTLIOCompressor.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLIOCompressor.hpp rename to Source/Cxxmlx/include/Metal/MTLIOCompressor.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLIndirectCommandBuffer.hpp b/Source/Cxxmlx/include/Metal/MTLIndirectCommandBuffer.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLIndirectCommandBuffer.hpp rename to Source/Cxxmlx/include/Metal/MTLIndirectCommandBuffer.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLIndirectCommandEncoder.hpp b/Source/Cxxmlx/include/Metal/MTLIndirectCommandEncoder.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLIndirectCommandEncoder.hpp rename to Source/Cxxmlx/include/Metal/MTLIndirectCommandEncoder.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLIntersectionFunctionTable.hpp b/Source/Cxxmlx/include/Metal/MTLIntersectionFunctionTable.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLIntersectionFunctionTable.hpp rename to Source/Cxxmlx/include/Metal/MTLIntersectionFunctionTable.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLLibrary.hpp b/Source/Cxxmlx/include/Metal/MTLLibrary.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLLibrary.hpp rename to Source/Cxxmlx/include/Metal/MTLLibrary.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLLinkedFunctions.hpp b/Source/Cxxmlx/include/Metal/MTLLinkedFunctions.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLLinkedFunctions.hpp rename to Source/Cxxmlx/include/Metal/MTLLinkedFunctions.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLLogState.hpp b/Source/Cxxmlx/include/Metal/MTLLogState.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLLogState.hpp rename to Source/Cxxmlx/include/Metal/MTLLogState.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLParallelRenderCommandEncoder.hpp b/Source/Cxxmlx/include/Metal/MTLParallelRenderCommandEncoder.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLParallelRenderCommandEncoder.hpp rename to Source/Cxxmlx/include/Metal/MTLParallelRenderCommandEncoder.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLPipeline.hpp b/Source/Cxxmlx/include/Metal/MTLPipeline.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLPipeline.hpp rename to Source/Cxxmlx/include/Metal/MTLPipeline.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLPixelFormat.hpp b/Source/Cxxmlx/include/Metal/MTLPixelFormat.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLPixelFormat.hpp rename to Source/Cxxmlx/include/Metal/MTLPixelFormat.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLPrivate.hpp b/Source/Cxxmlx/include/Metal/MTLPrivate.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLPrivate.hpp rename to Source/Cxxmlx/include/Metal/MTLPrivate.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLRasterizationRate.hpp b/Source/Cxxmlx/include/Metal/MTLRasterizationRate.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLRasterizationRate.hpp rename to Source/Cxxmlx/include/Metal/MTLRasterizationRate.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLRenderCommandEncoder.hpp b/Source/Cxxmlx/include/Metal/MTLRenderCommandEncoder.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLRenderCommandEncoder.hpp rename to Source/Cxxmlx/include/Metal/MTLRenderCommandEncoder.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLRenderPass.hpp b/Source/Cxxmlx/include/Metal/MTLRenderPass.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLRenderPass.hpp rename to Source/Cxxmlx/include/Metal/MTLRenderPass.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLRenderPipeline.hpp b/Source/Cxxmlx/include/Metal/MTLRenderPipeline.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLRenderPipeline.hpp rename to Source/Cxxmlx/include/Metal/MTLRenderPipeline.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLResidencySet.hpp b/Source/Cxxmlx/include/Metal/MTLResidencySet.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLResidencySet.hpp rename to Source/Cxxmlx/include/Metal/MTLResidencySet.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLResource.hpp b/Source/Cxxmlx/include/Metal/MTLResource.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLResource.hpp rename to Source/Cxxmlx/include/Metal/MTLResource.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLResourceStateCommandEncoder.hpp b/Source/Cxxmlx/include/Metal/MTLResourceStateCommandEncoder.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLResourceStateCommandEncoder.hpp rename to Source/Cxxmlx/include/Metal/MTLResourceStateCommandEncoder.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLResourceStatePass.hpp b/Source/Cxxmlx/include/Metal/MTLResourceStatePass.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLResourceStatePass.hpp rename to Source/Cxxmlx/include/Metal/MTLResourceStatePass.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLResourceViewPool.hpp b/Source/Cxxmlx/include/Metal/MTLResourceViewPool.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLResourceViewPool.hpp rename to Source/Cxxmlx/include/Metal/MTLResourceViewPool.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLSampler.hpp b/Source/Cxxmlx/include/Metal/MTLSampler.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLSampler.hpp rename to Source/Cxxmlx/include/Metal/MTLSampler.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLStageInputOutputDescriptor.hpp b/Source/Cxxmlx/include/Metal/MTLStageInputOutputDescriptor.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLStageInputOutputDescriptor.hpp rename to Source/Cxxmlx/include/Metal/MTLStageInputOutputDescriptor.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLTensor.hpp b/Source/Cxxmlx/include/Metal/MTLTensor.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLTensor.hpp rename to Source/Cxxmlx/include/Metal/MTLTensor.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLTexture.hpp b/Source/Cxxmlx/include/Metal/MTLTexture.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLTexture.hpp rename to Source/Cxxmlx/include/Metal/MTLTexture.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLTextureViewPool.hpp b/Source/Cxxmlx/include/Metal/MTLTextureViewPool.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLTextureViewPool.hpp rename to Source/Cxxmlx/include/Metal/MTLTextureViewPool.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLTypes.hpp b/Source/Cxxmlx/include/Metal/MTLTypes.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLTypes.hpp rename to Source/Cxxmlx/include/Metal/MTLTypes.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLVersion.hpp b/Source/Cxxmlx/include/Metal/MTLVersion.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLVersion.hpp rename to Source/Cxxmlx/include/Metal/MTLVersion.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLVertexDescriptor.hpp b/Source/Cxxmlx/include/Metal/MTLVertexDescriptor.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLVertexDescriptor.hpp rename to Source/Cxxmlx/include/Metal/MTLVertexDescriptor.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/MTLVisibleFunctionTable.hpp b/Source/Cxxmlx/include/Metal/MTLVisibleFunctionTable.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/MTLVisibleFunctionTable.hpp rename to Source/Cxxmlx/include/Metal/MTLVisibleFunctionTable.hpp diff --git a/Source/Cmlx/metal-cpp/Metal/Metal.hpp b/Source/Cxxmlx/include/Metal/Metal.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/Metal/Metal.hpp rename to Source/Cxxmlx/include/Metal/Metal.hpp diff --git a/Source/Cmlx/metal-cpp/MetalFX/MTL4FXFrameInterpolator.hpp b/Source/Cxxmlx/include/MetalFX/MTL4FXFrameInterpolator.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/MetalFX/MTL4FXFrameInterpolator.hpp rename to Source/Cxxmlx/include/MetalFX/MTL4FXFrameInterpolator.hpp diff --git a/Source/Cmlx/metal-cpp/MetalFX/MTL4FXSpatialScaler.hpp b/Source/Cxxmlx/include/MetalFX/MTL4FXSpatialScaler.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/MetalFX/MTL4FXSpatialScaler.hpp rename to Source/Cxxmlx/include/MetalFX/MTL4FXSpatialScaler.hpp diff --git a/Source/Cmlx/metal-cpp/MetalFX/MTL4FXTemporalDenoisedScaler.hpp b/Source/Cxxmlx/include/MetalFX/MTL4FXTemporalDenoisedScaler.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/MetalFX/MTL4FXTemporalDenoisedScaler.hpp rename to Source/Cxxmlx/include/MetalFX/MTL4FXTemporalDenoisedScaler.hpp diff --git a/Source/Cmlx/metal-cpp/MetalFX/MTL4FXTemporalScaler.hpp b/Source/Cxxmlx/include/MetalFX/MTL4FXTemporalScaler.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/MetalFX/MTL4FXTemporalScaler.hpp rename to Source/Cxxmlx/include/MetalFX/MTL4FXTemporalScaler.hpp diff --git a/Source/Cmlx/metal-cpp/MetalFX/MTLFXDefines.hpp b/Source/Cxxmlx/include/MetalFX/MTLFXDefines.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/MetalFX/MTLFXDefines.hpp rename to Source/Cxxmlx/include/MetalFX/MTLFXDefines.hpp diff --git a/Source/Cmlx/metal-cpp/MetalFX/MTLFXFrameInterpolator.hpp b/Source/Cxxmlx/include/MetalFX/MTLFXFrameInterpolator.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/MetalFX/MTLFXFrameInterpolator.hpp rename to Source/Cxxmlx/include/MetalFX/MTLFXFrameInterpolator.hpp diff --git a/Source/Cmlx/metal-cpp/MetalFX/MTLFXPrivate.hpp b/Source/Cxxmlx/include/MetalFX/MTLFXPrivate.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/MetalFX/MTLFXPrivate.hpp rename to Source/Cxxmlx/include/MetalFX/MTLFXPrivate.hpp diff --git a/Source/Cmlx/metal-cpp/MetalFX/MTLFXSpatialScaler.hpp b/Source/Cxxmlx/include/MetalFX/MTLFXSpatialScaler.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/MetalFX/MTLFXSpatialScaler.hpp rename to Source/Cxxmlx/include/MetalFX/MTLFXSpatialScaler.hpp diff --git a/Source/Cmlx/metal-cpp/MetalFX/MTLFXTemporalDenoisedScaler.hpp b/Source/Cxxmlx/include/MetalFX/MTLFXTemporalDenoisedScaler.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/MetalFX/MTLFXTemporalDenoisedScaler.hpp rename to Source/Cxxmlx/include/MetalFX/MTLFXTemporalDenoisedScaler.hpp diff --git a/Source/Cmlx/metal-cpp/MetalFX/MTLFXTemporalScaler.hpp b/Source/Cxxmlx/include/MetalFX/MTLFXTemporalScaler.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/MetalFX/MTLFXTemporalScaler.hpp rename to Source/Cxxmlx/include/MetalFX/MTLFXTemporalScaler.hpp diff --git a/Source/Cmlx/metal-cpp/MetalFX/MetalFX.hpp b/Source/Cxxmlx/include/MetalFX/MetalFX.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/MetalFX/MetalFX.hpp rename to Source/Cxxmlx/include/MetalFX/MetalFX.hpp diff --git a/Source/Cmlx/metal-cpp/QuartzCore/CADefines.hpp b/Source/Cxxmlx/include/QuartzCore/CADefines.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/QuartzCore/CADefines.hpp rename to Source/Cxxmlx/include/QuartzCore/CADefines.hpp diff --git a/Source/Cmlx/metal-cpp/QuartzCore/CAMetalDrawable.hpp b/Source/Cxxmlx/include/QuartzCore/CAMetalDrawable.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/QuartzCore/CAMetalDrawable.hpp rename to Source/Cxxmlx/include/QuartzCore/CAMetalDrawable.hpp diff --git a/Source/Cmlx/metal-cpp/QuartzCore/CAMetalLayer.hpp b/Source/Cxxmlx/include/QuartzCore/CAMetalLayer.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/QuartzCore/CAMetalLayer.hpp rename to Source/Cxxmlx/include/QuartzCore/CAMetalLayer.hpp diff --git a/Source/Cmlx/metal-cpp/QuartzCore/CAPrivate.hpp b/Source/Cxxmlx/include/QuartzCore/CAPrivate.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/QuartzCore/CAPrivate.hpp rename to Source/Cxxmlx/include/QuartzCore/CAPrivate.hpp diff --git a/Source/Cmlx/metal-cpp/QuartzCore/QuartzCore.hpp b/Source/Cxxmlx/include/QuartzCore/QuartzCore.hpp similarity index 100% rename from Source/Cmlx/metal-cpp/QuartzCore/QuartzCore.hpp rename to Source/Cxxmlx/include/QuartzCore/QuartzCore.hpp diff --git a/Source/Cxxmlx/include/mlx/3rdparty/pocketfft.h b/Source/Cxxmlx/include/mlx/3rdparty/pocketfft.h new file mode 100644 index 00000000..03a45897 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/3rdparty/pocketfft.h @@ -0,0 +1,3581 @@ +/* +This file is part of pocketfft. + +Copyright (C) 2010-2022 Max-Planck-Society +Copyright (C) 2019-2020 Peter Bell + +For the odd-sized DCT-IV transforms: + Copyright (C) 2003, 2007-14 Matteo Frigo + Copyright (C) 2003, 2007-14 Massachusetts Institute of Technology + +Authors: Martin Reinecke, Peter Bell + +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the documentation and/or + other materials provided with the distribution. +* Neither the name of the copyright holder nor the names of its contributors may + be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#ifndef POCKETFFT_HDRONLY_H +#define POCKETFFT_HDRONLY_H + +#ifndef __cplusplus +#error This file is C++ and requires a C++ compiler. +#endif + +#if !(__cplusplus >= 201103L || _MSVC_LANG+0L >= 201103L) +#error This file requires at least C++11 support. +#endif + +#ifndef POCKETFFT_CACHE_SIZE +#define POCKETFFT_CACHE_SIZE 0 +#endif + +#include +#include +#include +#include +#include +#include +#include +#if POCKETFFT_CACHE_SIZE!=0 +#include +#include +#endif + +#ifndef POCKETFFT_NO_MULTITHREADING +#include +#include +#include +#include +#include +#include +#include + +#ifdef POCKETFFT_PTHREADS +# include +#endif +#endif + +#if defined(__GNUC__) +#define POCKETFFT_NOINLINE __attribute__((noinline)) +#define POCKETFFT_RESTRICT __restrict__ +#elif defined(_MSC_VER) +#define POCKETFFT_NOINLINE __declspec(noinline) +#define POCKETFFT_RESTRICT __restrict +#else +#define POCKETFFT_NOINLINE +#define POCKETFFT_RESTRICT +#endif + +namespace pocketfft { + +namespace detail { +using std::size_t; +using std::ptrdiff_t; + +// Always use std:: for functions +template T cos(T) = delete; +template T sin(T) = delete; +template T sqrt(T) = delete; + +using shape_t = std::vector; +using stride_t = std::vector; + +constexpr bool FORWARD = true, + BACKWARD = false; + +// only enable vector support for gcc>=5.0 and clang>=5.0 +#ifndef POCKETFFT_NO_VECTORS +#define POCKETFFT_NO_VECTORS +#if defined(__INTEL_COMPILER) +// do nothing. This is necessary because this compiler also sets __GNUC__. +#elif defined(__clang__) +// AppleClang has their own version numbering +#ifdef __apple_build_version__ +# if (__clang_major__ > 9) || (__clang_major__ == 9 && __clang_minor__ >= 1) +# undef POCKETFFT_NO_VECTORS +# endif +#elif __clang_major__ >= 5 +# undef POCKETFFT_NO_VECTORS +#endif +#elif defined(__GNUC__) +#if __GNUC__>=5 +#undef POCKETFFT_NO_VECTORS +#endif +#endif +#endif + +template struct VLEN { static constexpr size_t val=1; }; + +#ifndef POCKETFFT_NO_VECTORS +#if (defined(__AVX512F__)) +template<> struct VLEN { static constexpr size_t val=16; }; +template<> struct VLEN { static constexpr size_t val=8; }; +#elif (defined(__AVX__)) +template<> struct VLEN { static constexpr size_t val=8; }; +template<> struct VLEN { static constexpr size_t val=4; }; +#elif (defined(__SSE2__)) +template<> struct VLEN { static constexpr size_t val=4; }; +template<> struct VLEN { static constexpr size_t val=2; }; +#elif (defined(__VSX__)) +template<> struct VLEN { static constexpr size_t val=4; }; +template<> struct VLEN { static constexpr size_t val=2; }; +#elif (defined(__ARM_NEON__) || defined(__ARM_NEON)) +template<> struct VLEN { static constexpr size_t val=4; }; +template<> struct VLEN { static constexpr size_t val=2; }; +#else +#define POCKETFFT_NO_VECTORS +#endif +#endif + +// the __MINGW32__ part in the conditional below works around the problem that +// the standard C++ library on Windows does not provide aligned_alloc() even +// though the MinGW compiler and MSVC may advertise C++17 compliance. +#if (__cplusplus >= 201703L) && (!defined(__MINGW32__)) && (!defined(_MSC_VER)) +inline void *aligned_alloc(size_t align, size_t size) + { + // aligned_alloc() requires that the requested size is a multiple of "align" + void *ptr = ::aligned_alloc(align,(size+align-1)&(~(align-1))); + if (!ptr) throw std::bad_alloc(); + return ptr; + } +inline void aligned_dealloc(void *ptr) + { free(ptr); } +#else // portable emulation +inline void *aligned_alloc(size_t align, size_t size) + { + align = std::max(align, alignof(max_align_t)); + void *ptr = malloc(size+align); + if (!ptr) throw std::bad_alloc(); + void *res = reinterpret_cast + ((reinterpret_cast(ptr) & ~(uintptr_t(align-1))) + uintptr_t(align)); + (reinterpret_cast(res))[-1] = ptr; + return res; + } +inline void aligned_dealloc(void *ptr) + { if (ptr) free((reinterpret_cast(ptr))[-1]); } +#endif + +template class arr + { + private: + T *p; + size_t sz; + +#if defined(POCKETFFT_NO_VECTORS) + static T *ralloc(size_t num) + { + if (num==0) return nullptr; + void *res = malloc(num*sizeof(T)); + if (!res) throw std::bad_alloc(); + return reinterpret_cast(res); + } + static void dealloc(T *ptr) + { free(ptr); } +#else + static T *ralloc(size_t num) + { + if (num==0) return nullptr; + void *ptr = aligned_alloc(64, num*sizeof(T)); + return static_cast(ptr); + } + static void dealloc(T *ptr) + { aligned_dealloc(ptr); } +#endif + + public: + arr() : p(0), sz(0) {} + arr(size_t n) : p(ralloc(n)), sz(n) {} + arr(arr &&other) + : p(other.p), sz(other.sz) + { other.p=nullptr; other.sz=0; } + ~arr() { dealloc(p); } + + void resize(size_t n) + { + if (n==sz) return; + dealloc(p); + p = ralloc(n); + sz = n; + } + + T &operator[](size_t idx) { return p[idx]; } + const T &operator[](size_t idx) const { return p[idx]; } + + T *data() { return p; } + const T *data() const { return p; } + + size_t size() const { return sz; } + }; + +template struct cmplx { + T r, i; + cmplx() {} + cmplx(T r_, T i_) : r(r_), i(i_) {} + void Set(T r_, T i_) { r=r_; i=i_; } + void Set(T r_) { r=r_; i=T(0); } + cmplx &operator+= (const cmplx &other) + { r+=other.r; i+=other.i; return *this; } + templatecmplx &operator*= (T2 other) + { r*=other; i*=other; return *this; } + templatecmplx &operator*= (const cmplx &other) + { + T tmp = r*other.r - i*other.i; + i = r*other.i + i*other.r; + r = tmp; + return *this; + } + templatecmplx &operator+= (const cmplx &other) + { r+=other.r; i+=other.i; return *this; } + templatecmplx &operator-= (const cmplx &other) + { r-=other.r; i-=other.i; return *this; } + template auto operator* (const T2 &other) const + -> cmplx + { return {r*other, i*other}; } + template auto operator+ (const cmplx &other) const + -> cmplx + { return {r+other.r, i+other.i}; } + template auto operator- (const cmplx &other) const + -> cmplx + { return {r-other.r, i-other.i}; } + template auto operator* (const cmplx &other) const + -> cmplx + { return {r*other.r-i*other.i, r*other.i + i*other.r}; } + template auto special_mul (const cmplx &other) const + -> cmplx + { + using Tres = cmplx; + return fwd ? Tres(r*other.r+i*other.i, i*other.r-r*other.i) + : Tres(r*other.r-i*other.i, r*other.i+i*other.r); + } +}; +template inline void PM(T &a, T &b, T c, T d) + { a=c+d; b=c-d; } +template inline void PMINPLACE(T &a, T &b) + { T t = a; a+=b; b=t-b; } +template inline void MPINPLACE(T &a, T &b) + { T t = a; a-=b; b=t+b; } +template cmplx conj(const cmplx &a) + { return {a.r, -a.i}; } +template void special_mul (const cmplx &v1, const cmplx &v2, cmplx &res) + { + res = fwd ? cmplx(v1.r*v2.r+v1.i*v2.i, v1.i*v2.r-v1.r*v2.i) + : cmplx(v1.r*v2.r-v1.i*v2.i, v1.r*v2.i+v1.i*v2.r); + } + +template void ROT90(cmplx &a) + { auto tmp_=a.r; a.r=-a.i; a.i=tmp_; } +template void ROTX90(cmplx &a) + { auto tmp_= fwd ? -a.r : a.r; a.r = fwd ? a.i : -a.i; a.i=tmp_; } + +// +// twiddle factor section +// +template class sincos_2pibyn + { + private: + using Thigh = typename std::conditional<(sizeof(T)>sizeof(double)), T, double>::type; + size_t N, mask, shift; + arr> v1, v2; + + static cmplx calc(size_t x, size_t n, Thigh ang) + { + x<<=3; + if (x<4*n) // first half + { + if (x<2*n) // first quadrant + { + if (x(std::cos(Thigh(x)*ang), std::sin(Thigh(x)*ang)); + return cmplx(std::sin(Thigh(2*n-x)*ang), std::cos(Thigh(2*n-x)*ang)); + } + else // second quadrant + { + x-=2*n; + if (x(-std::sin(Thigh(x)*ang), std::cos(Thigh(x)*ang)); + return cmplx(-std::cos(Thigh(2*n-x)*ang), std::sin(Thigh(2*n-x)*ang)); + } + } + else + { + x=8*n-x; + if (x<2*n) // third quadrant + { + if (x(std::cos(Thigh(x)*ang), -std::sin(Thigh(x)*ang)); + return cmplx(std::sin(Thigh(2*n-x)*ang), -std::cos(Thigh(2*n-x)*ang)); + } + else // fourth quadrant + { + x-=2*n; + if (x(-std::sin(Thigh(x)*ang), -std::cos(Thigh(x)*ang)); + return cmplx(-std::cos(Thigh(2*n-x)*ang), -std::sin(Thigh(2*n-x)*ang)); + } + } + } + + public: + POCKETFFT_NOINLINE sincos_2pibyn(size_t n) + : N(n) + { + constexpr auto pi = 3.141592653589793238462643383279502884197L; + Thigh ang = Thigh(0.25L*pi/n); + size_t nval = (n+2)/2; + shift = 1; + while((size_t(1)< operator[](size_t idx) const + { + if (2*idx<=N) + { + auto x1=v1[idx&mask], x2=v2[idx>>shift]; + return cmplx(T(x1.r*x2.r-x1.i*x2.i), T(x1.r*x2.i+x1.i*x2.r)); + } + idx = N-idx; + auto x1=v1[idx&mask], x2=v2[idx>>shift]; + return cmplx(T(x1.r*x2.r-x1.i*x2.i), -T(x1.r*x2.i+x1.i*x2.r)); + } + }; + +struct util // hack to avoid duplicate symbols + { + static POCKETFFT_NOINLINE size_t largest_prime_factor (size_t n) + { + size_t res=1; + while ((n&1)==0) + { res=2; n>>=1; } + for (size_t x=3; x*x<=n; x+=2) + while ((n%x)==0) + { res=x; n/=x; } + if (n>1) res=n; + return res; + } + + static POCKETFFT_NOINLINE double cost_guess (size_t n) + { + constexpr double lfp=1.1; // penalty for non-hardcoded larger factors + size_t ni=n; + double result=0.; + while ((n&1)==0) + { result+=2; n>>=1; } + for (size_t x=3; x*x<=n; x+=2) + while ((n%x)==0) + { + result+= (x<=5) ? double(x) : lfp*double(x); // penalize larger prime factors + n/=x; + } + if (n>1) result+=(n<=5) ? double(n) : lfp*double(n); + return result*double(ni); + } + + /* returns the smallest composite of 2, 3, 5, 7 and 11 which is >= n */ + static POCKETFFT_NOINLINE size_t good_size_cmplx(size_t n) + { + if (n<=12) return n; + + size_t bestfac=2*n; + for (size_t f11=1; f11n) + { + if (x>=1; + } + else + return n; + } + } + return bestfac; + } + + /* returns the smallest composite of 2, 3, 5 which is >= n */ + static POCKETFFT_NOINLINE size_t good_size_real(size_t n) + { + if (n<=6) return n; + + size_t bestfac=2*n; + for (size_t f5=1; f5n) + { + if (x>=1; + } + else + return n; + } + } + return bestfac; + } + + static size_t prod(const shape_t &shape) + { + size_t res=1; + for (auto sz: shape) + res*=sz; + return res; + } + + static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, bool inplace) + { + auto ndim = shape.size(); + if (ndim<1) throw std::runtime_error("ndim must be >= 1"); + if ((stride_in.size()!=ndim) || (stride_out.size()!=ndim)) + throw std::runtime_error("stride dimension mismatch"); + if (inplace && (stride_in!=stride_out)) + throw std::runtime_error("stride mismatch"); + } + + static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, bool inplace, + const shape_t &axes) + { + sanity_check(shape, stride_in, stride_out, inplace); + auto ndim = shape.size(); + shape_t tmp(ndim,0); + for (auto ax : axes) + { + if (ax>=ndim) throw std::invalid_argument("bad axis number"); + if (++tmp[ax]>1) throw std::invalid_argument("axis specified repeatedly"); + } + } + + static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, bool inplace, + size_t axis) + { + sanity_check(shape, stride_in, stride_out, inplace); + if (axis>=shape.size()) throw std::invalid_argument("bad axis number"); + } + +#ifdef POCKETFFT_NO_MULTITHREADING + static size_t thread_count (size_t /*nthreads*/, const shape_t &/*shape*/, + size_t /*axis*/, size_t /*vlen*/) + { return 1; } +#else + static size_t thread_count (size_t nthreads, const shape_t &shape, + size_t axis, size_t vlen) + { + if (nthreads==1) return 1; + size_t size = prod(shape); + size_t parallel = size / (shape[axis] * vlen); + if (shape[axis] < 1000) + parallel /= 4; + size_t max_threads = nthreads == 0 ? + std::thread::hardware_concurrency() : nthreads; + return std::max(size_t(1), std::min(parallel, max_threads)); + } +#endif + }; + +namespace threading { + +#ifdef POCKETFFT_NO_MULTITHREADING + +constexpr inline size_t thread_id() { return 0; } +constexpr inline size_t num_threads() { return 1; } + +template +void thread_map(size_t /* nthreads */, Func f) + { f(); } + +#else + +inline size_t &thread_id() + { + static thread_local size_t thread_id_=0; + return thread_id_; + } +inline size_t &num_threads() + { + static thread_local size_t num_threads_=1; + return num_threads_; + } +static const size_t max_threads = std::max(1u, std::thread::hardware_concurrency()); + +class latch + { + std::atomic num_left_; + std::mutex mut_; + std::condition_variable completed_; + using lock_t = std::unique_lock; + + public: + latch(size_t n): num_left_(n) {} + + void count_down() + { + lock_t lock(mut_); + if (--num_left_) + return; + completed_.notify_all(); + } + + void wait() + { + lock_t lock(mut_); + completed_.wait(lock, [this]{ return is_ready(); }); + } + bool is_ready() { return num_left_ == 0; } + }; + +template class concurrent_queue + { + std::queue q_; + std::mutex mut_; + std::atomic size_; + using lock_t = std::lock_guard; + + public: + + void push(T val) + { + lock_t lock(mut_); + ++size_; + q_.push(std::move(val)); + } + + bool try_pop(T &val) + { + if (size_ == 0) return false; + lock_t lock(mut_); + // Queue might have been emptied while we acquired the lock + if (q_.empty()) return false; + + val = std::move(q_.front()); + --size_; + q_.pop(); + return true; + } + + bool empty() const { return size_==0; } + }; + +// C++ allocator with support for over-aligned types +template struct aligned_allocator + { + using value_type = T; + template + aligned_allocator(const aligned_allocator&) {} + aligned_allocator() = default; + + T *allocate(size_t n) + { + void* mem = aligned_alloc(alignof(T), n*sizeof(T)); + return static_cast(mem); + } + + void deallocate(T *p, size_t /*n*/) + { aligned_dealloc(p); } + }; + +class thread_pool + { + // A reasonable guess, probably close enough for most hardware + static constexpr size_t cache_line_size = 64; + struct alignas(cache_line_size) worker + { + std::thread thread; + std::condition_variable work_ready; + std::mutex mut; + std::atomic_flag busy_flag = ATOMIC_FLAG_INIT; + std::function work; + + void worker_main( + std::atomic &shutdown_flag, + std::atomic &unscheduled_tasks, + concurrent_queue> &overflow_work) + { + using lock_t = std::unique_lock; + bool expect_work = true; + while (!shutdown_flag || expect_work) + { + std::function local_work; + if (expect_work || unscheduled_tasks == 0) + { + lock_t lock(mut); + // Wait until there is work to be executed + work_ready.wait(lock, [&]{ return (work || shutdown_flag); }); + local_work.swap(work); + expect_work = false; + } + + bool marked_busy = false; + if (local_work) + { + marked_busy = true; + local_work(); + } + + if (!overflow_work.empty()) + { + if (!marked_busy && busy_flag.test_and_set()) + { + expect_work = true; + continue; + } + marked_busy = true; + + while (overflow_work.try_pop(local_work)) + { + --unscheduled_tasks; + local_work(); + } + } + + if (marked_busy) busy_flag.clear(); + } + } + }; + + concurrent_queue> overflow_work_; + std::mutex mut_; + std::vector> workers_; + std::atomic shutdown_; + std::atomic unscheduled_tasks_; + using lock_t = std::lock_guard; + + void create_threads() + { + lock_t lock(mut_); + size_t nthreads=workers_.size(); + for (size_t i=0; ibusy_flag.clear(); + worker->work = nullptr; + worker->thread = std::thread([worker, this] + { + worker->worker_main(shutdown_, unscheduled_tasks_, overflow_work_); + }); + } + catch (...) + { + shutdown_locked(); + throw; + } + } + } + + void shutdown_locked() + { + shutdown_ = true; + for (auto &worker : workers_) + worker.work_ready.notify_all(); + + for (auto &worker : workers_) + if (worker.thread.joinable()) + worker.thread.join(); + } + + public: + explicit thread_pool(size_t nthreads): + workers_(nthreads) + { create_threads(); } + + thread_pool(): thread_pool(max_threads) {} + + ~thread_pool() { shutdown(); } + + void submit(std::function work) + { + lock_t lock(mut_); + if (shutdown_) + throw std::runtime_error("Work item submitted after shutdown"); + + ++unscheduled_tasks_; + + // First check for any idle workers and wake those + for (auto &worker : workers_) + if (!worker.busy_flag.test_and_set()) + { + --unscheduled_tasks_; + { + lock_t lock(worker.mut); + worker.work = std::move(work); + } + worker.work_ready.notify_one(); + return; + } + + // If no workers were idle, push onto the overflow queue for later + overflow_work_.push(std::move(work)); + } + + void shutdown() + { + lock_t lock(mut_); + shutdown_locked(); + } + + void restart() + { + shutdown_ = false; + create_threads(); + } + }; + +inline thread_pool & get_pool() + { + static thread_pool pool; +#ifdef POCKETFFT_PTHREADS + static std::once_flag f; + std::call_once(f, + []{ + pthread_atfork( + +[]{ get_pool().shutdown(); }, // prepare + +[]{ get_pool().restart(); }, // parent + +[]{ get_pool().restart(); } // child + ); + }); +#endif + + return pool; + } + +/** Map a function f over nthreads */ +template +void thread_map(size_t nthreads, Func f) + { + if (nthreads == 0) + nthreads = max_threads; + + if (nthreads == 1) + { f(); return; } + + auto & pool = get_pool(); + latch counter(nthreads); + std::exception_ptr ex; + std::mutex ex_mut; + for (size_t i=0; i lock(ex_mut); + ex = std::current_exception(); + } + counter.count_down(); + }); + } + counter.wait(); + if (ex) + std::rethrow_exception(ex); + } + +#endif + +} + +// +// complex FFTPACK transforms +// + +template class cfftp + { + private: + struct fctdata + { + size_t fct; + cmplx *tw, *tws; + }; + + size_t length; + arr> mem; + std::vector fact; + + void add_factor(size_t factor) + { fact.push_back({factor, nullptr, nullptr}); } + +template void pass2 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+2*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k(CC(i,0,k)-CC(i,1,k),WA(0,i),CH(i,k,1)); + } + } + } + +#define POCKETFFT_PREP3(idx) \ + T t0 = CC(idx,0,k), t1, t2; \ + PM (t1,t2,CC(idx,1,k),CC(idx,2,k)); \ + CH(idx,k,0)=t0+t1; +#define POCKETFFT_PARTSTEP3a(u1,u2,twr,twi) \ + { \ + T ca=t0+t1*twr; \ + T cb{-t2.i*twi, t2.r*twi}; \ + PM(CH(0,k,u1),CH(0,k,u2),ca,cb) ;\ + } +#define POCKETFFT_PARTSTEP3b(u1,u2,twr,twi) \ + { \ + T ca=t0+t1*twr; \ + T cb{-t2.i*twi, t2.r*twi}; \ + special_mul(ca+cb,WA(u1-1,i),CH(i,k,u1)); \ + special_mul(ca-cb,WA(u2-1,i),CH(i,k,u2)); \ + } +template void pass3 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tw1r=-0.5, + tw1i= (fwd ? -1: 1) * T0(0.8660254037844386467637231707529362L); + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+3*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k void pass4 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+4*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k(t4); + PM(CH(0,k,0),CH(0,k,2),t2,t3); + PM(CH(0,k,1),CH(0,k,3),t1,t4); + } + else + for (size_t k=0; k(t4); + PM(CH(0,k,0),CH(0,k,2),t2,t3); + PM(CH(0,k,1),CH(0,k,3),t1,t4); + } + for (size_t i=1; i(t4); + CH(i,k,0) = t2+t3; + special_mul(t1+t4,WA(0,i),CH(i,k,1)); + special_mul(t2-t3,WA(1,i),CH(i,k,2)); + special_mul(t1-t4,WA(2,i),CH(i,k,3)); + } + } + } + +#define POCKETFFT_PREP5(idx) \ + T t0 = CC(idx,0,k), t1, t2, t3, t4; \ + PM (t1,t4,CC(idx,1,k),CC(idx,4,k)); \ + PM (t2,t3,CC(idx,2,k),CC(idx,3,k)); \ + CH(idx,k,0).r=t0.r+t1.r+t2.r; \ + CH(idx,k,0).i=t0.i+t1.i+t2.i; + +#define POCKETFFT_PARTSTEP5a(u1,u2,twar,twbr,twai,twbi) \ + { \ + T ca,cb; \ + ca.r=t0.r+twar*t1.r+twbr*t2.r; \ + ca.i=t0.i+twar*t1.i+twbr*t2.i; \ + cb.i=twai*t4.r twbi*t3.r; \ + cb.r=-(twai*t4.i twbi*t3.i); \ + PM(CH(0,k,u1),CH(0,k,u2),ca,cb); \ + } + +#define POCKETFFT_PARTSTEP5b(u1,u2,twar,twbr,twai,twbi) \ + { \ + T ca,cb,da,db; \ + ca.r=t0.r+twar*t1.r+twbr*t2.r; \ + ca.i=t0.i+twar*t1.i+twbr*t2.i; \ + cb.i=twai*t4.r twbi*t3.r; \ + cb.r=-(twai*t4.i twbi*t3.i); \ + special_mul(ca+cb,WA(u1-1,i),CH(i,k,u1)); \ + special_mul(ca-cb,WA(u2-1,i),CH(i,k,u2)); \ + } +template void pass5 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tw1r= T0(0.3090169943749474241022934171828191L), + tw1i= (fwd ? -1: 1) * T0(0.9510565162951535721164393333793821L), + tw2r= T0(-0.8090169943749474241022934171828191L), + tw2i= (fwd ? -1: 1) * T0(0.5877852522924731291687059546390728L); + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+5*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k(da,WA(u1-1,i),CH(i,k,u1)); \ + special_mul(db,WA(u2-1,i),CH(i,k,u2)); \ + } + +template void pass7(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tw1r= T0(0.6234898018587335305250048840042398L), + tw1i= (fwd ? -1 : 1) * T0(0.7818314824680298087084445266740578L), + tw2r= T0(-0.2225209339563144042889025644967948L), + tw2i= (fwd ? -1 : 1) * T0(0.9749279121818236070181316829939312L), + tw3r= T0(-0.9009688679024191262361023195074451L), + tw3i= (fwd ? -1 : 1) * T0(0.433883739117558120475768332848359L); + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+7*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k void ROTX45(T &a) const + { + constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L); + if (fwd) + { auto tmp_=a.r; a.r=hsqt2*(a.r+a.i); a.i=hsqt2*(a.i-tmp_); } + else + { auto tmp_=a.r; a.r=hsqt2*(a.r-a.i); a.i=hsqt2*(a.i+tmp_); } + } +template void ROTX135(T &a) const + { + constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L); + if (fwd) + { auto tmp_=a.r; a.r=hsqt2*(a.i-a.r); a.i=hsqt2*(-tmp_-a.i); } + else + { auto tmp_=a.r; a.r=hsqt2*(-a.r-a.i); a.i=hsqt2*(tmp_-a.i); } + } + +template void pass8 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+8*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k(a3); + + ROTX90(a7); + PMINPLACE(a5,a7); + ROTX45(a5); + ROTX135(a7); + + PM(a0,a4,CC(0,0,k),CC(0,4,k)); + PM(a2,a6,CC(0,2,k),CC(0,6,k)); + PM(CH(0,k,0),CH(0,k,4),a0+a2,a1); + PM(CH(0,k,2),CH(0,k,6),a0-a2,a3); + ROTX90(a6); + PM(CH(0,k,1),CH(0,k,5),a4+a6,a5); + PM(CH(0,k,3),CH(0,k,7),a4-a6,a7); + } + else + for (size_t k=0; k(a3); + + ROTX90(a7); + PMINPLACE(a5,a7); + ROTX45(a5); + ROTX135(a7); + + PM(a0,a4,CC(0,0,k),CC(0,4,k)); + PM(a2,a6,CC(0,2,k),CC(0,6,k)); + PM(CH(0,k,0),CH(0,k,4),a0+a2,a1); + PM(CH(0,k,2),CH(0,k,6),a0-a2,a3); + ROTX90(a6); + PM(CH(0,k,1),CH(0,k,5),a4+a6,a5); + PM(CH(0,k,3),CH(0,k,7),a4-a6,a7); + } + for (size_t i=1; i(a7); + PMINPLACE(a1,a3); + ROTX90(a3); + PMINPLACE(a5,a7); + ROTX45(a5); + ROTX135(a7); + PM(a0,a4,CC(i,0,k),CC(i,4,k)); + PM(a2,a6,CC(i,2,k),CC(i,6,k)); + PMINPLACE(a0,a2); + CH(i,k,0) = a0+a1; + special_mul(a0-a1,WA(3,i),CH(i,k,4)); + special_mul(a2+a3,WA(1,i),CH(i,k,2)); + special_mul(a2-a3,WA(5,i),CH(i,k,6)); + ROTX90(a6); + PMINPLACE(a4,a6); + special_mul(a4+a5,WA(0,i),CH(i,k,1)); + special_mul(a4-a5,WA(4,i),CH(i,k,5)); + special_mul(a6+a7,WA(2,i),CH(i,k,3)); + special_mul(a6-a7,WA(6,i),CH(i,k,7)); + } + } + } + + +#define POCKETFFT_PREP11(idx) \ + T t1 = CC(idx,0,k), t2, t3, t4, t5, t6, t7, t8, t9, t10, t11; \ + PM (t2,t11,CC(idx,1,k),CC(idx,10,k)); \ + PM (t3,t10,CC(idx,2,k),CC(idx, 9,k)); \ + PM (t4,t9 ,CC(idx,3,k),CC(idx, 8,k)); \ + PM (t5,t8 ,CC(idx,4,k),CC(idx, 7,k)); \ + PM (t6,t7 ,CC(idx,5,k),CC(idx, 6,k)); \ + CH(idx,k,0).r=t1.r+t2.r+t3.r+t4.r+t5.r+t6.r; \ + CH(idx,k,0).i=t1.i+t2.i+t3.i+t4.i+t5.i+t6.i; + +#define POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,out1,out2) \ + { \ + T ca = t1 + t2*x1 + t3*x2 + t4*x3 + t5*x4 +t6*x5, \ + cb; \ + cb.i=y1*t11.r y2*t10.r y3*t9.r y4*t8.r y5*t7.r; \ + cb.r=-(y1*t11.i y2*t10.i y3*t9.i y4*t8.i y5*t7.i ); \ + PM(out1,out2,ca,cb); \ + } +#define POCKETFFT_PARTSTEP11a(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5) \ + POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,CH(0,k,u1),CH(0,k,u2)) +#define POCKETFFT_PARTSTEP11(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5) \ + { \ + T da,db; \ + POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,da,db) \ + special_mul(da,WA(u1-1,i),CH(i,k,u1)); \ + special_mul(db,WA(u2-1,i),CH(i,k,u2)); \ + } + +template void pass11 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tw1r= T0(0.8412535328311811688618116489193677L), + tw1i= (fwd ? -1 : 1) * T0(0.5406408174555975821076359543186917L), + tw2r= T0(0.4154150130018864255292741492296232L), + tw2i= (fwd ? -1 : 1) * T0(0.9096319953545183714117153830790285L), + tw3r= T0(-0.1423148382732851404437926686163697L), + tw3i= (fwd ? -1 : 1) * T0(0.9898214418809327323760920377767188L), + tw4r= T0(-0.6548607339452850640569250724662936L), + tw4i= (fwd ? -1 : 1) * T0(0.7557495743542582837740358439723444L), + tw5r= T0(-0.9594929736144973898903680570663277L), + tw5i= (fwd ? -1 : 1) * T0(0.2817325568414296977114179153466169L); + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+11*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k void passg (size_t ido, size_t ip, + size_t l1, T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa, + const cmplx * POCKETFFT_RESTRICT csarr) const + { + const size_t cdim=ip; + size_t ipph = (ip+1)/2; + size_t idl1 = ido*l1; + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+cdim*c)]; }; + auto CX = [cc, ido, l1](size_t a, size_t b, size_t c) -> T& + { return cc[a+ido*(b+l1*c)]; }; + auto CX2 = [cc, idl1](size_t a, size_t b) -> T& + { return cc[a+idl1*b]; }; + auto CH2 = [ch, idl1](size_t a, size_t b) -> const T& + { return ch[a+idl1*b]; }; + + arr> wal(ip); + wal[0] = cmplx(1., 0.); + for (size_t i=1; i(csarr[i].r,fwd ? -csarr[i].i : csarr[i].i); + + for (size_t k=0; kip) iwal-=ip; + cmplx xwal=wal[iwal]; + iwal+=l; if (iwal>ip) iwal-=ip; + cmplx xwal2=wal[iwal]; + for (size_t ik=0; ikip) iwal-=ip; + cmplx xwal=wal[iwal]; + for (size_t ik=0; ik(x1,wa[idij],CX(i,k,j)); + idij=(jc-1)*(ido-1)+i-1; + special_mul(x2,wa[idij],CX(i,k,jc)); + } + } + } + } + +template void pass_all(T c[], T0 fct) const + { + if (length==1) { c[0]*=fct; return; } + size_t l1=1; + arr ch(length); + T *p1=c, *p2=ch.data(); + + for(size_t k1=0; k1 (ido, l1, p1, p2, fact[k1].tw); + else if(ip==8) + pass8(ido, l1, p1, p2, fact[k1].tw); + else if(ip==2) + pass2(ido, l1, p1, p2, fact[k1].tw); + else if(ip==3) + pass3 (ido, l1, p1, p2, fact[k1].tw); + else if(ip==5) + pass5 (ido, l1, p1, p2, fact[k1].tw); + else if(ip==7) + pass7 (ido, l1, p1, p2, fact[k1].tw); + else if(ip==11) + pass11 (ido, l1, p1, p2, fact[k1].tw); + else + { + passg(ido, ip, l1, p1, p2, fact[k1].tw, fact[k1].tws); + std::swap(p1,p2); + } + std::swap(p1,p2); + l1=l2; + } + if (p1!=c) + { + if (fct!=1.) + for (size_t i=0; i void exec(T c[], T0 fct, bool fwd) const + { fwd ? pass_all(c, fct) : pass_all(c, fct); } + + private: + POCKETFFT_NOINLINE void factorize() + { + size_t len=length; + while ((len&7)==0) + { add_factor(8); len>>=3; } + while ((len&3)==0) + { add_factor(4); len>>=2; } + if ((len&1)==0) + { + len>>=1; + // factor 2 should be at the front of the factor list + add_factor(2); + std::swap(fact[0].fct, fact.back().fct); + } + for (size_t divisor=3; divisor*divisor<=len; divisor+=2) + while ((len%divisor)==0) + { + add_factor(divisor); + len/=divisor; + } + if (len>1) add_factor(len); + } + + size_t twsize() const + { + size_t twsize=0, l1=1; + for (size_t k=0; k11) + twsize+=ip; + l1*=ip; + } + return twsize; + } + + void comp_twiddle() + { + sincos_2pibyn twiddle(length); + size_t l1=1; + size_t memofs=0; + for (size_t k=0; k11) + { + fact[k].tws=mem.data()+memofs; + memofs+=ip; + for (size_t j=0; j class rfftp + { + private: + struct fctdata + { + size_t fct; + T0 *tw, *tws; + }; + + size_t length; + arr mem; + std::vector fact; + + void add_factor(size_t factor) + { fact.push_back({factor, nullptr, nullptr}); } + +/* (a+ib) = conj(c+id) * (e+if) */ +template inline void MULPM + (T1 &a, T1 &b, T2 c, T2 d, T3 e, T3 f) const + { a=c*e+d*f; b=c*f-d*e; } + +template void radf2 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+2*c)]; }; + + for (size_t k=0; k void radf3(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 taur=-0.5, taui=T0(0.8660254037844386467637231707529362L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+3*c)]; }; + + for (size_t k=0; k void radf4(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+4*c)]; }; + + for (size_t k=0; k void radf5(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tr11= T0(0.3090169943749474241022934171828191L), + ti11= T0(0.9510565162951535721164393333793821L), + tr12= T0(-0.8090169943749474241022934171828191L), + ti12= T0(0.5877852522924731291687059546390728L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+5*c)]; }; + + for (size_t k=0; k void radfg(size_t ido, size_t ip, size_t l1, + T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa, const T0 * POCKETFFT_RESTRICT csarr) const + { + const size_t cdim=ip; + size_t ipph=(ip+1)/2; + size_t idl1 = ido*l1; + + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> T& + { return cc[a+ido*(b+cdim*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return ch[a+ido*(b+l1*c)]; }; + auto C1 = [cc,ido,l1] (size_t a, size_t b, size_t c) -> T& + { return cc[a+ido*(b+l1*c)]; }; + auto C2 = [cc,idl1] (size_t a, size_t b) -> T& + { return cc[a+idl1*b]; }; + auto CH2 = [ch,idl1] (size_t a, size_t b) -> T& + { return ch[a+idl1*b]; }; + + if (ido>1) + { + for (size_t j=1, jc=ip-1; j=ip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar3=csarr[2*iang], ai3=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar4=csarr[2*iang], ai4=csarr[2*iang+1]; + for (size_t ik=0; ik=ip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + for (size_t ik=0; ik=ip) iang-=ip; + T0 ar=csarr[2*iang], ai=csarr[2*iang+1]; + for (size_t ik=0; ik void radb2(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+2*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + + for (size_t k=0; k void radb3(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 taur=-0.5, taui=T0(0.8660254037844386467637231707529362L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+3*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + + for (size_t k=0; k void radb4(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+4*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + + for (size_t k=0; k void radb5(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tr11= T0(0.3090169943749474241022934171828191L), + ti11= T0(0.9510565162951535721164393333793821L), + tr12= T0(-0.8090169943749474241022934171828191L), + ti12= T0(0.5877852522924731291687059546390728L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+5*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + + for (size_t k=0; k void radbg(size_t ido, size_t ip, size_t l1, + T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa, const T0 * POCKETFFT_RESTRICT csarr) const + { + const size_t cdim=ip; + size_t ipph=(ip+1)/ 2; + size_t idl1 = ido*l1; + + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+cdim*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto C1 = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto C2 = [cc,idl1](size_t a, size_t b) -> T& + { return cc[a+idl1*b]; }; + auto CH2 = [ch,idl1](size_t a, size_t b) -> T& + { return ch[a+idl1*b]; }; + + for (size_t k=0; kip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar3=csarr[2*iang], ai3=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar4=csarr[2*iang], ai4=csarr[2*iang+1]; + for (size_t ik=0; ikip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + for (size_t ik=0; ikip) iang-=ip; + T0 war=csarr[2*iang], wai=csarr[2*iang+1]; + for (size_t ik=0; ik void copy_and_norm(T *c, T *p1, T0 fct) const + { + if (p1!=c) + { + if (fct!=1.) + for (size_t i=0; i void exec(T c[], T0 fct, bool r2hc) const + { + if (length==1) { c[0]*=fct; return; } + size_t nf=fact.size(); + arr ch(length); + T *p1=c, *p2=ch.data(); + + if (r2hc) + for(size_t k1=0, l1=length; k1>=2; } + if ((len%2)==0) + { + len>>=1; + // factor 2 should be at the front of the factor list + add_factor(2); + std::swap(fact[0].fct, fact.back().fct); + } + for (size_t divisor=3; divisor*divisor<=len; divisor+=2) + while ((len%divisor)==0) + { + add_factor(divisor); + len/=divisor; + } + if (len>1) add_factor(len); + } + + size_t twsize() const + { + size_t twsz=0, l1=1; + for (size_t k=0; k5) twsz+=2*ip; + l1*=ip; + } + return twsz; + } + + void comp_twiddle() + { + sincos_2pibyn twid(length); + size_t l1=1; + T0 *ptr=mem.data(); + for (size_t k=0; k5) // special factors required by *g functions + { + fact[k].tws=ptr; ptr+=2*ip; + fact[k].tws[0] = 1.; + fact[k].tws[1] = 0.; + for (size_t i=2, ic=2*ip-2; i<=ic; i+=2, ic-=2) + { + fact[k].tws[i ] = twid[i/2*(length/ip)].r; + fact[k].tws[i+1] = twid[i/2*(length/ip)].i; + fact[k].tws[ic] = twid[i/2*(length/ip)].r; + fact[k].tws[ic+1] = -twid[i/2*(length/ip)].i; + } + } + l1*=ip; + } + } + + public: + POCKETFFT_NOINLINE rfftp(size_t length_) + : length(length_) + { + if (length==0) throw std::runtime_error("zero-length FFT requested"); + if (length==1) return; + factorize(); + mem.resize(twsize()); + comp_twiddle(); + } +}; + +// +// complex Bluestein transforms +// + +template class fftblue + { + private: + size_t n, n2; + cfftp plan; + arr> mem; + cmplx *bk, *bkf; + + template void fft(cmplx c[], T0 fct) const + { + arr> akf(n2); + + /* initialize a_k and FFT it */ + for (size_t m=0; m(c[m],bk[m],akf[m]); + auto zero = akf[0]*T0(0); + for (size_t m=n; m(bkf[0]); + for (size_t m=1; m<(n2+1)/2; ++m) + { + akf[m] = akf[m].template special_mul(bkf[m]); + akf[n2-m] = akf[n2-m].template special_mul(bkf[m]); + } + if ((n2&1)==0) + akf[n2/2] = akf[n2/2].template special_mul(bkf[n2/2]); + + /* inverse FFT */ + plan.exec (akf.data(),1.,false); + + /* multiply by b_k */ + for (size_t m=0; m(bk[m])*fct; + } + + public: + POCKETFFT_NOINLINE fftblue(size_t length) + : n(length), n2(util::good_size_cmplx(n*2-1)), plan(n2), mem(n+n2/2+1), + bk(mem.data()), bkf(mem.data()+n) + { + /* initialize b_k */ + sincos_2pibyn tmp(2*n); + bk[0].Set(1, 0); + + size_t coeff=0; + for (size_t m=1; m=2*n) coeff-=2*n; + bk[m] = tmp[coeff]; + } + + /* initialize the zero-padded, Fourier transformed b_k. Add normalisation. */ + arr> tbkf(n2); + T0 xn2 = T0(1)/T0(n2); + tbkf[0] = bk[0]*xn2; + for (size_t m=1; m void exec(cmplx c[], T0 fct, bool fwd) const + { fwd ? fft(c,fct) : fft(c,fct); } + + template void exec_r(T c[], T0 fct, bool fwd) + { + arr> tmp(n); + if (fwd) + { + auto zero = T0(0)*c[0]; + for (size_t m=0; m(tmp.data(),fct); + c[0] = tmp[0].r; + std::copy_n (&tmp[1].r, n-1, &c[1]); + } + else + { + tmp[0].Set(c[0],c[0]*0); + std::copy_n (c+1, n-1, &tmp[1].r); + if ((n&1)==0) tmp[n/2].i=T0(0)*c[0]; + for (size_t m=1; 2*m(tmp.data(),fct); + for (size_t m=0; m class pocketfft_c + { + private: + std::unique_ptr> packplan; + std::unique_ptr> blueplan; + size_t len; + + public: + POCKETFFT_NOINLINE pocketfft_c(size_t length) + : len(length) + { + if (length==0) throw std::runtime_error("zero-length FFT requested"); + size_t tmp = (length<50) ? 0 : util::largest_prime_factor(length); + if (tmp*tmp <= length) + { + packplan=std::unique_ptr>(new cfftp(length)); + return; + } + double comp1 = util::cost_guess(length); + double comp2 = 2*util::cost_guess(util::good_size_cmplx(2*length-1)); + comp2*=1.5; /* fudge factor that appears to give good overall performance */ + if (comp2>(new fftblue(length)); + else + packplan=std::unique_ptr>(new cfftp(length)); + } + + template POCKETFFT_NOINLINE void exec(cmplx c[], T0 fct, bool fwd) const + { packplan ? packplan->exec(c,fct,fwd) : blueplan->exec(c,fct,fwd); } + + size_t length() const { return len; } + }; + +// +// flexible (FFTPACK/Bluestein) real-valued 1D transform +// + +template class pocketfft_r + { + private: + std::unique_ptr> packplan; + std::unique_ptr> blueplan; + size_t len; + + public: + POCKETFFT_NOINLINE pocketfft_r(size_t length) + : len(length) + { + if (length==0) throw std::runtime_error("zero-length FFT requested"); + size_t tmp = (length<50) ? 0 : util::largest_prime_factor(length); + if (tmp*tmp <= length) + { + packplan=std::unique_ptr>(new rfftp(length)); + return; + } + double comp1 = 0.5*util::cost_guess(length); + double comp2 = 2*util::cost_guess(util::good_size_cmplx(2*length-1)); + comp2*=1.5; /* fudge factor that appears to give good overall performance */ + if (comp2>(new fftblue(length)); + else + packplan=std::unique_ptr>(new rfftp(length)); + } + + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool fwd) const + { packplan ? packplan->exec(c,fct,fwd) : blueplan->exec_r(c,fct,fwd); } + + size_t length() const { return len; } + }; + + +// +// sine/cosine transforms +// + +template class T_dct1 + { + private: + pocketfft_r fftplan; + + public: + POCKETFFT_NOINLINE T_dct1(size_t length) + : fftplan(2*(length-1)) {} + + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, + int /*type*/, bool /*cosine*/) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + size_t N=fftplan.length(), n=N/2+1; + if (ortho) + { c[0]*=sqrt2; c[n-1]*=sqrt2; } + arr tmp(N); + tmp[0] = c[0]; + for (size_t i=1; i class T_dst1 + { + private: + pocketfft_r fftplan; + + public: + POCKETFFT_NOINLINE T_dst1(size_t length) + : fftplan(2*(length+1)) {} + + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, + bool /*ortho*/, int /*type*/, bool /*cosine*/) const + { + size_t N=fftplan.length(), n=N/2-1; + arr tmp(N); + tmp[0] = tmp[n+1] = c[0]*0; + for (size_t i=0; i class T_dcst23 + { + private: + pocketfft_r fftplan; + std::vector twiddle; + + public: + POCKETFFT_NOINLINE T_dcst23(size_t length) + : fftplan(length), twiddle(length) + { + sincos_2pibyn tw(4*length); + for (size_t i=0; i POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, + int type, bool cosine) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + size_t N=length(); + size_t NS2 = (N+1)/2; + if (type==2) + { + if (!cosine) + for (size_t k=1; k class T_dcst4 + { + private: + size_t N; + std::unique_ptr> fft; + std::unique_ptr> rfft; + arr> C2; + + public: + POCKETFFT_NOINLINE T_dcst4(size_t length) + : N(length), + fft((N&1) ? nullptr : new pocketfft_c(N/2)), + rfft((N&1)? new pocketfft_r(N) : nullptr), + C2((N&1) ? 0 : N/2) + { + if ((N&1)==0) + { + sincos_2pibyn tw(16*N); + for (size_t i=0; i POCKETFFT_NOINLINE void exec(T c[], T0 fct, + bool /*ortho*/, int /*type*/, bool cosine) const + { + size_t n2 = N/2; + if (!cosine) + for (size_t k=0, kc=N-1; k y(N); + { + size_t i=0, m=n2; + for (; mexec(y.data(), fct, true); + { + auto SGN = [](size_t i) + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + return (i&2) ? -sqrt2 : sqrt2; + }; + c[n2] = y[0]*SGN(n2+1); + size_t i=0, i1=1, k=1; + for (; k> y(n2); + for(size_t i=0; iexec(y.data(), fct, true); + for(size_t i=0, ic=n2-1; i std::shared_ptr get_plan(size_t length) + { +#if POCKETFFT_CACHE_SIZE==0 + return std::make_shared(length); +#else + constexpr size_t nmax=POCKETFFT_CACHE_SIZE; + static std::array, nmax> cache; + static std::array last_access{{0}}; + static size_t access_counter = 0; + static std::mutex mut; + + auto find_in_cache = [&]() -> std::shared_ptr + { + for (size_t i=0; ilength()==length)) + { + // no need to update if this is already the most recent entry + if (last_access[i]!=access_counter) + { + last_access[i] = ++access_counter; + // Guard against overflow + if (access_counter == 0) + last_access.fill(0); + } + return cache[i]; + } + + return nullptr; + }; + + { + std::lock_guard lock(mut); + auto p = find_in_cache(); + if (p) return p; + } + auto plan = std::make_shared(length); + { + std::lock_guard lock(mut); + auto p = find_in_cache(); + if (p) return p; + + size_t lru = 0; + for (size_t i=1; i class cndarr: public arr_info + { + protected: + const char *d; + + public: + cndarr(const void *data_, const shape_t &shape_, const stride_t &stride_) + : arr_info(shape_, stride_), + d(reinterpret_cast(data_)) {} + const T &operator[](ptrdiff_t ofs) const + { return *reinterpret_cast(d+ofs); } + }; + +template class ndarr: public cndarr + { + public: + ndarr(void *data_, const shape_t &shape_, const stride_t &stride_) + : cndarr::cndarr(const_cast(data_), shape_, stride_) + {} + T &operator[](ptrdiff_t ofs) + { return *reinterpret_cast(const_cast(cndarr::d+ofs)); } + }; + +template class multi_iter + { + private: + shape_t pos; + const arr_info &iarr, &oarr; + ptrdiff_t p_ii, p_i[N], str_i, p_oi, p_o[N], str_o; + size_t idim, rem; + + void advance_i() + { + for (int i_=int(pos.size())-1; i_>=0; --i_) + { + auto i = size_t(i_); + if (i==idim) continue; + p_ii += iarr.stride(i); + p_oi += oarr.stride(i); + if (++pos[i] < iarr.shape(i)) + return; + pos[i] = 0; + p_ii -= ptrdiff_t(iarr.shape(i))*iarr.stride(i); + p_oi -= ptrdiff_t(oarr.shape(i))*oarr.stride(i); + } + } + + public: + multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_) + : pos(iarr_.ndim(), 0), iarr(iarr_), oarr(oarr_), p_ii(0), + str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)), + idim(idim_), rem(iarr.size()/iarr.shape(idim)) + { + auto nshares = threading::num_threads(); + if (nshares==1) return; + if (nshares==0) throw std::runtime_error("can't run with zero threads"); + auto myshare = threading::thread_id(); + if (myshare>=nshares) throw std::runtime_error("impossible share requested"); + size_t nbase = rem/nshares; + size_t additional = rem%nshares; + size_t lo = myshare*nbase + ((myshare=0; --i_) + { + auto i = size_t(i_); + p += arr.stride(i); + if (++pos[i] < arr.shape(i)) + return; + pos[i] = 0; + p -= ptrdiff_t(arr.shape(i))*arr.stride(i); + } + } + ptrdiff_t ofs() const { return p; } + size_t remaining() const { return rem; } + }; + +class rev_iter + { + private: + shape_t pos; + const arr_info &arr; + std::vector rev_axis; + std::vector rev_jump; + size_t last_axis, last_size; + shape_t shp; + ptrdiff_t p, rp; + size_t rem; + + public: + rev_iter(const arr_info &arr_, const shape_t &axes) + : pos(arr_.ndim(), 0), arr(arr_), rev_axis(arr_.ndim(), 0), + rev_jump(arr_.ndim(), 1), p(0), rp(0) + { + for (auto ax: axes) + rev_axis[ax]=1; + last_axis = axes.back(); + last_size = arr.shape(last_axis)/2 + 1; + shp = arr.shape(); + shp[last_axis] = last_size; + rem=1; + for (auto i: shp) + rem *= i; + } + void advance() + { + --rem; + for (int i_=int(pos.size())-1; i_>=0; --i_) + { + auto i = size_t(i_); + p += arr.stride(i); + if (!rev_axis[i]) + rp += arr.stride(i); + else + { + rp -= arr.stride(i); + if (rev_jump[i]) + { + rp += ptrdiff_t(arr.shape(i))*arr.stride(i); + rev_jump[i] = 0; + } + } + if (++pos[i] < shp[i]) + return; + pos[i] = 0; + p -= ptrdiff_t(shp[i])*arr.stride(i); + if (rev_axis[i]) + { + rp -= ptrdiff_t(arr.shape(i)-shp[i])*arr.stride(i); + rev_jump[i] = 1; + } + else + rp -= ptrdiff_t(shp[i])*arr.stride(i); + } + } + ptrdiff_t ofs() const { return p; } + ptrdiff_t rev_ofs() const { return rp; } + size_t remaining() const { return rem; } + }; + +template struct VTYPE {}; +template using vtype_t = typename VTYPE::type; + +#ifndef POCKETFFT_NO_VECTORS +template<> struct VTYPE + { + using type = float __attribute__ ((vector_size (VLEN::val*sizeof(float)))); + }; +template<> struct VTYPE + { + using type = double __attribute__ ((vector_size (VLEN::val*sizeof(double)))); + }; +template<> struct VTYPE + { + using type = long double __attribute__ ((vector_size (VLEN::val*sizeof(long double)))); + }; +#endif + +template arr alloc_tmp(const shape_t &shape, + size_t axsize, size_t elemsize) + { + auto othersize = util::prod(shape)/axsize; + auto tmpsize = axsize*((othersize>=VLEN::val) ? VLEN::val : 1); + return arr(tmpsize*elemsize); + } +template arr alloc_tmp(const shape_t &shape, + const shape_t &axes, size_t elemsize) + { + size_t fullsize=util::prod(shape); + size_t tmpsize=0; + for (size_t i=0; i=VLEN::val) ? VLEN::val : 1); + if (sz>tmpsize) tmpsize=sz; + } + return arr(tmpsize*elemsize); + } + +template void copy_input(const multi_iter &it, + const cndarr> &src, cmplx> *POCKETFFT_RESTRICT dst) + { + for (size_t i=0; i void copy_input(const multi_iter &it, + const cndarr &src, vtype_t *POCKETFFT_RESTRICT dst) + { + for (size_t i=0; i void copy_input(const multi_iter &it, + const cndarr &src, T *POCKETFFT_RESTRICT dst) + { + if (dst == &src[it.iofs(0)]) return; // in-place + for (size_t i=0; i void copy_output(const multi_iter &it, + const cmplx> *POCKETFFT_RESTRICT src, ndarr> &dst) + { + for (size_t i=0; i void copy_output(const multi_iter &it, + const vtype_t *POCKETFFT_RESTRICT src, ndarr &dst) + { + for (size_t i=0; i void copy_output(const multi_iter &it, + const T *POCKETFFT_RESTRICT src, ndarr &dst) + { + if (src == &dst[it.oofs(0)]) return; // in-place + for (size_t i=0; i struct add_vec { using type = vtype_t; }; +template struct add_vec> + { using type = cmplx>; }; +template using add_vec_t = typename add_vec::type; + +template +POCKETFFT_NOINLINE void general_nd(const cndarr &in, ndarr &out, + const shape_t &axes, T0 fct, size_t nthreads, const Exec & exec, + const bool allow_inplace=true) + { + std::shared_ptr plan; + + for (size_t iax=0; iaxlength())) + plan = get_plan(len); + + threading::thread_map( + util::thread_count(nthreads, in.shape(), axes[iax], VLEN::val), + [&] { + constexpr auto vlen = VLEN::val; + auto storage = alloc_tmp(in.shape(), len, sizeof(T)); + const auto &tin(iax==0? in : out); + multi_iter it(tin, out, axes[iax]); +#ifndef POCKETFFT_NO_VECTORS + if (vlen>1) + while (it.remaining()>=vlen) + { + it.advance(vlen); + auto tdatav = reinterpret_cast *>(storage.data()); + exec(it, tin, out, tdatav, *plan, fct); + } +#endif + while (it.remaining()>0) + { + it.advance(1); + auto buf = allow_inplace && it.stride_out() == sizeof(T) ? + &out[it.oofs(0)] : reinterpret_cast(storage.data()); + exec(it, tin, out, buf, *plan, fct); + } + }); // end of parallel region + fct = T0(1); // factor has been applied, use 1 for remaining axes + } + } + +struct ExecC2C + { + bool forward; + + template void operator () ( + const multi_iter &it, const cndarr> &in, + ndarr> &out, T * buf, const pocketfft_c &plan, T0 fct) const + { + copy_input(it, in, buf); + plan.exec(buf, fct, forward); + copy_output(it, buf, out); + } + }; + +template void copy_hartley(const multi_iter &it, + const vtype_t *POCKETFFT_RESTRICT src, ndarr &dst) + { + for (size_t j=0; j void copy_hartley(const multi_iter &it, + const T *POCKETFFT_RESTRICT src, ndarr &dst) + { + dst[it.oofs(0)] = src[0]; + size_t i=1, i1=1, i2=it.length_out()-1; + for (i=1; i void operator () ( + const multi_iter &it, const cndarr &in, ndarr &out, + T * buf, const pocketfft_r &plan, T0 fct) const + { + copy_input(it, in, buf); + plan.exec(buf, fct, true); + copy_hartley(it, buf, out); + } + }; + +struct ExecDcst + { + bool ortho; + int type; + bool cosine; + + template + void operator () (const multi_iter &it, const cndarr &in, + ndarr &out, T * buf, const Tplan &plan, T0 fct) const + { + copy_input(it, in, buf); + plan.exec(buf, fct, ortho, type, cosine); + copy_output(it, buf, out); + } + }; + +template POCKETFFT_NOINLINE void general_r2c( + const cndarr &in, ndarr> &out, size_t axis, bool forward, T fct, + size_t nthreads) + { + auto plan = get_plan>(in.shape(axis)); + size_t len=in.shape(axis); + threading::thread_map( + util::thread_count(nthreads, in.shape(), axis, VLEN::val), + [&] { + constexpr auto vlen = VLEN::val; + auto storage = alloc_tmp(in.shape(), len, sizeof(T)); + multi_iter it(in, out, axis); +#ifndef POCKETFFT_NO_VECTORS + if (vlen>1) + while (it.remaining()>=vlen) + { + it.advance(vlen); + auto tdatav = reinterpret_cast *>(storage.data()); + copy_input(it, in, tdatav); + plan->exec(tdatav, fct, true); + for (size_t j=0; j0) + { + it.advance(1); + auto tdata = reinterpret_cast(storage.data()); + copy_input(it, in, tdata); + plan->exec(tdata, fct, true); + out[it.oofs(0)].Set(tdata[0]); + size_t i=1, ii=1; + if (forward) + for (; i POCKETFFT_NOINLINE void general_c2r( + const cndarr> &in, ndarr &out, size_t axis, bool forward, T fct, + size_t nthreads) + { + auto plan = get_plan>(out.shape(axis)); + size_t len=out.shape(axis); + threading::thread_map( + util::thread_count(nthreads, in.shape(), axis, VLEN::val), + [&] { + constexpr auto vlen = VLEN::val; + auto storage = alloc_tmp(out.shape(), len, sizeof(T)); + multi_iter it(in, out, axis); +#ifndef POCKETFFT_NO_VECTORS + if (vlen>1) + while (it.remaining()>=vlen) + { + it.advance(vlen); + auto tdatav = reinterpret_cast *>(storage.data()); + for (size_t j=0; jexec(tdatav, fct, false); + copy_output(it, tdatav, out); + } +#endif + while (it.remaining()>0) + { + it.advance(1); + auto tdata = reinterpret_cast(storage.data()); + tdata[0]=in[it.iofs(0)].r; + { + size_t i=1, ii=1; + if (forward) + for (; iexec(tdata, fct, false); + copy_output(it, tdata, out); + } + }); // end of parallel region + } + +struct ExecR2R + { + bool r2h, forward; + + template void operator () ( + const multi_iter &it, const cndarr &in, ndarr &out, T * buf, + const pocketfft_r &plan, T0 fct) const + { + copy_input(it, in, buf); + if ((!r2h) && forward) + for (size_t i=2; i void c2c(const shape_t &shape, const stride_t &stride_in, + const stride_t &stride_out, const shape_t &axes, bool forward, + const std::complex *data_in, std::complex *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr> ain(data_in, shape, stride_in); + ndarr> aout(data_out, shape, stride_out); + general_nd>(ain, aout, axes, fct, nthreads, ExecC2C{forward}); + } + +template void dct(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1) + { + if ((type<1) || (type>4)) throw std::invalid_argument("invalid DCT type"); + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + const ExecDcst exec{ortho, type, true}; + if (type==1) + general_nd>(ain, aout, axes, fct, nthreads, exec); + else if (type==4) + general_nd>(ain, aout, axes, fct, nthreads, exec); + else + general_nd>(ain, aout, axes, fct, nthreads, exec); + } + +template void dst(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1) + { + if ((type<1) || (type>4)) throw std::invalid_argument("invalid DST type"); + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + const ExecDcst exec{ortho, type, false}; + if (type==1) + general_nd>(ain, aout, axes, fct, nthreads, exec); + else if (type==4) + general_nd>(ain, aout, axes, fct, nthreads, exec); + else + general_nd>(ain, aout, axes, fct, nthreads, exec); + } + +template void r2c(const shape_t &shape_in, + const stride_t &stride_in, const stride_t &stride_out, size_t axis, + bool forward, const T *data_in, std::complex *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape_in)==0) return; + util::sanity_check(shape_in, stride_in, stride_out, false, axis); + cndarr ain(data_in, shape_in, stride_in); + shape_t shape_out(shape_in); + shape_out[axis] = shape_in[axis]/2 + 1; + ndarr> aout(data_out, shape_out, stride_out); + general_r2c(ain, aout, axis, forward, fct, nthreads); + } + +template void r2c(const shape_t &shape_in, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + bool forward, const T *data_in, std::complex *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape_in)==0) return; + util::sanity_check(shape_in, stride_in, stride_out, false, axes); + r2c(shape_in, stride_in, stride_out, axes.back(), forward, data_in, data_out, + fct, nthreads); + if (axes.size()==1) return; + + shape_t shape_out(shape_in); + shape_out[axes.back()] = shape_in[axes.back()]/2 + 1; + auto newaxes = shape_t{axes.begin(), --axes.end()}; + c2c(shape_out, stride_out, stride_out, newaxes, forward, data_out, data_out, + T(1), nthreads); + } + +template void c2r(const shape_t &shape_out, + const stride_t &stride_in, const stride_t &stride_out, size_t axis, + bool forward, const std::complex *data_in, T *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape_out)==0) return; + util::sanity_check(shape_out, stride_in, stride_out, false, axis); + shape_t shape_in(shape_out); + shape_in[axis] = shape_out[axis]/2 + 1; + cndarr> ain(data_in, shape_in, stride_in); + ndarr aout(data_out, shape_out, stride_out); + general_c2r(ain, aout, axis, forward, fct, nthreads); + } + +template void c2r(const shape_t &shape_out, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + bool forward, const std::complex *data_in, T *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape_out)==0) return; + if (axes.size()==1) + return c2r(shape_out, stride_in, stride_out, axes[0], forward, + data_in, data_out, fct, nthreads); + util::sanity_check(shape_out, stride_in, stride_out, false, axes); + auto shape_in = shape_out; + shape_in[axes.back()] = shape_out[axes.back()]/2 + 1; + auto nval = util::prod(shape_in); + stride_t stride_inter(shape_in.size()); + stride_inter.back() = sizeof(cmplx); + for (int i=int(shape_in.size())-2; i>=0; --i) + stride_inter[size_t(i)] = + stride_inter[size_t(i+1)]*ptrdiff_t(shape_in[size_t(i+1)]); + arr> tmp(nval); + auto newaxes = shape_t{axes.begin(), --axes.end()}; + c2c(shape_in, stride_in, stride_inter, newaxes, forward, data_in, tmp.data(), + T(1), nthreads); + c2r(shape_out, stride_inter, stride_out, axes.back(), forward, + tmp.data(), data_out, fct, nthreads); + } + +template void r2r_fftpack(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + bool real2hermitian, bool forward, const T *data_in, T *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + general_nd>(ain, aout, axes, fct, nthreads, + ExecR2R{real2hermitian, forward}); + } + +template void r2r_separable_hartley(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + const T *data_in, T *data_out, T fct, size_t nthreads=1) + { + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + general_nd>(ain, aout, axes, fct, nthreads, ExecHartley{}, + false); + } + +template void r2r_genuine_hartley(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + const T *data_in, T *data_out, T fct, size_t nthreads=1) + { + if (util::prod(shape)==0) return; + if (axes.size()==1) + return r2r_separable_hartley(shape, stride_in, stride_out, axes, data_in, + data_out, fct, nthreads); + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + shape_t tshp(shape); + tshp[axes.back()] = tshp[axes.back()]/2+1; + arr> tdata(util::prod(tshp)); + stride_t tstride(shape.size()); + tstride.back()=sizeof(std::complex); + for (size_t i=tstride.size()-1; i>0; --i) + tstride[i-1]=tstride[i]*ptrdiff_t(tshp[i]); + r2c(shape, stride_in, tstride, axes, true, data_in, tdata.data(), fct, nthreads); + cndarr> atmp(tdata.data(), tshp, tstride); + ndarr aout(data_out, shape, stride_out); + simple_iter iin(atmp); + rev_iter iout(aout, axes); + while(iin.remaining()>0) + { + auto v = atmp[iin.ofs()]; + aout[iout.ofs()] = v.r+v.i; + aout[iout.rev_ofs()] = v.r-v.i; + iin.advance(); iout.advance(); + } + } + +} // namespace detail + +using detail::FORWARD; +using detail::BACKWARD; +using detail::shape_t; +using detail::stride_t; +using detail::c2c; +using detail::c2r; +using detail::r2c; +using detail::r2r_fftpack; +using detail::r2r_separable_hartley; +using detail::r2r_genuine_hartley; +using detail::dct; +using detail::dst; + +} // namespace pocketfft + +#undef POCKETFFT_NOINLINE +#undef POCKETFFT_RESTRICT + +#endif // POCKETFFT_HDRONLY_H diff --git a/Source/Cxxmlx/include/mlx/allocator.h b/Source/Cxxmlx/include/mlx/allocator.h new file mode 100644 index 00000000..824deac2 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/allocator.h @@ -0,0 +1,75 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +#include "mlx/api.h" + +namespace mlx::core::allocator { + +// Simple wrapper around buffer pointers +// WARNING: Only Buffer objects constructed from and those that wrap +// raw pointers from mlx::allocator are supported. +class MLX_API Buffer { + private: + void* ptr_; + + public: + explicit Buffer(void* ptr) : ptr_(ptr) {}; + + // Get the raw data pointer from the buffer + void* raw_ptr(); + + // Get the buffer pointer from the buffer + const void* ptr() const { + return ptr_; + }; + void* ptr() { + return ptr_; + }; +}; + +class MLX_API Allocator { + /** Abstract base class for a memory allocator. */ + public: + virtual Buffer malloc(size_t size) = 0; + virtual void free(Buffer buffer) = 0; + virtual size_t size(Buffer buffer) const = 0; + virtual Buffer make_buffer(void* ptr, size_t size) { + return Buffer{nullptr}; + }; + virtual void release(Buffer buffer) {} + + Allocator() = default; + Allocator(const Allocator& other) = delete; + Allocator(Allocator&& other) = delete; + Allocator& operator=(const Allocator& other) = delete; + Allocator& operator=(Allocator&& other) = delete; + virtual ~Allocator() = default; +}; + +MLX_API Allocator& allocator(); + +inline Buffer malloc(size_t size) { + return allocator().malloc(size); +} + +inline void free(Buffer buffer) { + allocator().free(buffer); +} + +// Make a Buffer from a raw pointer of the given size without a copy. If a +// no-copy conversion is not possible then the returned buffer.ptr() will be +// nullptr. Any buffer created with this function must be released with +// release(buffer) +inline Buffer make_buffer(void* ptr, size_t size) { + return allocator().make_buffer(ptr, size); +}; + +// Release a buffer from the allocator made with make_buffer +inline void release(Buffer buffer) { + allocator().release(buffer); +} + +} // namespace mlx::core::allocator diff --git a/Source/Cxxmlx/include/mlx/api.h b/Source/Cxxmlx/include/mlx/api.h new file mode 100644 index 00000000..8aed0910 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/api.h @@ -0,0 +1,29 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +// MLX_API macro for controlling symbol visibility, must add for public APIs. +// +// Usage: +// MLX_API void some_function(...); +// class MLX_API SomeClass { ... }; + +#if defined(MLX_STATIC) + +// Static library build - no import/export decorations needed +#define MLX_API + +#else + +// Shared library build. +#if defined(_WIN32) +#if defined(MLX_EXPORT) +#define MLX_API __declspec(dllexport) +#else +#define MLX_API __declspec(dllimport) +#endif // defined(MLX_EXPORT) +#else +#define MLX_API __attribute__((visibility("default"))) +#endif // defined(_WIN32) + +#endif // defined(MLX_STATIC) diff --git a/Source/Cxxmlx/include/mlx/array.h b/Source/Cxxmlx/include/mlx/array.h new file mode 100644 index 00000000..60d5e50b --- /dev/null +++ b/Source/Cxxmlx/include/mlx/array.h @@ -0,0 +1,647 @@ +// Copyright © 2023 Apple Inc. +#pragma once + +#include +#include +#include +#include +#include + +#include "mlx/allocator.h" +#include "mlx/api.h" +#include "mlx/dtype.h" +#include "mlx/event.h" +#include "mlx/small_vector.h" + +namespace mlx::core { + +// Forward declaration +class Primitive; + +using Deleter = std::function; +using ShapeElem = int32_t; +using Shape = SmallVector; +using Strides = SmallVector; + +class MLX_API array { + /* An array is really a node in a graph. It contains a shared ArrayDesc + * object */ + + public: + /** Construct a scalar array with zero dimensions. */ + template + explicit array(T val, Dtype dtype = TypeToDtype()); + + /* Special case since std::complex can't be implicitly converted to other + * types. */ + explicit array(const std::complex& val, Dtype dtype = complex64); + + template + explicit array( + It data, + Shape shape, + Dtype dtype = + TypeToDtype::value_type>()); + + template + explicit array(std::initializer_list data, Dtype dtype = TypeToDtype()); + + /* Special case so empty lists default to float32. */ + explicit array(std::initializer_list data); + + /* Special case so array({}, type) is an empty array. */ + explicit array(std::initializer_list data, Dtype dtype); + + template + explicit array( + std::initializer_list data, + Shape shape, + Dtype dtype = TypeToDtype()); + + /* Build an array from a raw pointer. The constructor will attempt to use the + * input data without a copy. The deleter will be called when the array no + * longer needs the underlying memory - after the array is destroyed in the + * no-copy case and after the copy otherwise. */ + explicit array( + void* data, + Shape shape, + Dtype dtype, + const std::function& deleter); + + /* Build an array from a buffer */ + explicit array( + allocator::Buffer data, + Shape shape, + Dtype dtype, + Deleter deleter = allocator::free); + + /** Assignment to rvalue does not compile. */ + array& operator=(const array& other) && = delete; + array& operator=(array&& other) && = delete; + + /** Default copy and move constructors otherwise. */ + array& operator=(array&& other) & = default; + array(const array& other) = default; + array(array&& other) = default; + + array& operator=(const array& other) & { + if (this->id() != other.id()) { + this->array_desc_ = other.array_desc_; + } + return *this; + } + + /** The size of the array's datatype in bytes. */ + size_t itemsize() const { + return size_of(dtype()); + } + + /** The number of elements in the array. */ + size_t size() const { + return array_desc_->size; + } + + /** The number of bytes in the array. */ + size_t nbytes() const { + return size() * itemsize(); + } + + /** The number of dimensions of the array. */ + size_t ndim() const { + return array_desc_->shape.size(); + } + + /** The shape of the array as a vector of integers. */ + const Shape& shape() const { + return array_desc_->shape; + } + + /** + * Get the size of the corresponding dimension. + * + * This function supports negative indexing and provides + * bounds checking. */ + auto shape(int dim) const { + return shape().at(dim < 0 ? dim + static_cast(ndim()) : dim); + } + + /** The strides of the array. */ + const Strides& strides() const { + return array_desc_->strides; + } + + /** + * Get the stride of the corresponding dimension. + * + * This function supports negative indexing and provides + * bounds checking. */ + auto strides(int dim) const { + return strides().at(dim < 0 ? dim + static_cast(ndim()) : dim); + } + + /** Get the arrays data type. */ + Dtype dtype() const { + return array_desc_->dtype; + } + + /** Evaluate the array. */ + void eval(); + + /** Get the value from a scalar array. */ + template + T item(); + + template + T item() const; + + struct MLX_API ArrayIterator { + using iterator_category = std::random_access_iterator_tag; + using difference_type = size_t; + using value_type = const array; + using reference = value_type; + + explicit ArrayIterator(const array& arr, int idx = 0); + + reference operator*() const; + + ArrayIterator& operator+(difference_type diff) { + idx += diff; + return *this; + } + + ArrayIterator& operator++() { + idx++; + return *this; + } + + friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) { + return a.arr.id() == b.arr.id() && a.idx == b.idx; + } + friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) { + return !(a == b); + } + + private: + const array& arr; + int idx; + }; + + ArrayIterator begin() const { + return ArrayIterator(*this); + } + ArrayIterator end() const { + return ArrayIterator(*this, shape(0)); + } + + /** + * The following methods should be used with caution. + * They are intended for use by the backend implementation and the + * API may change. + */ + + array( + Shape shape, + Dtype dtype, + std::shared_ptr primitive, + std::vector inputs); + + static std::vector make_arrays( + std::vector shapes, + const std::vector& dtypes, + const std::shared_ptr& primitive, + const std::vector& inputs); + + /** + * Get a new array that refers to the same data as the input but with a + * non-owning pointer to it. Note the array is detached from the graph and has + * no inputs, siblings or primitive. + */ + static array unsafe_weak_copy(const array& other); + + /** A unique identifier for an array. */ + std::uintptr_t id() const { + return reinterpret_cast(array_desc_.get()); + } + + /** A unique identifier for an arrays primitive. */ + std::uintptr_t primitive_id() const { + return reinterpret_cast(array_desc_->primitive.get()); + } + + struct Data { + allocator::Buffer buffer; + Deleter d; + Data(allocator::Buffer buffer, Deleter d = allocator::free) + : buffer(buffer), d(d) {} + // Not copyable + Data(const Data& d) = delete; + Data& operator=(const Data& d) = delete; + Data(Data&& o) : buffer(o.buffer), d(o.d) { + o.buffer = allocator::Buffer(nullptr); + o.d = [](allocator::Buffer) {}; + } + ~Data() { + d(buffer); + } + }; + + struct Flags { + // True iff there are no gaps in the underlying data. Each item + // in the underlying data buffer belongs to at least one index. + // + // True iff: + // prod(shape[i] for i in range(ndim) if strides[i] > 0) == data_size() + bool contiguous : 1; + + // True iff: + // strides[-1] == 1 and + // all(strides[i] == (shape[i+1]*strides[i+1]) or shape[i] == 1 for i in + // range(ndim - 1)) + bool row_contiguous : 1; + + // True iff: + // strides[0] == 1 and + // all(strides[i] == (shape[i-1]*strides[i-1]) or shape[i] == 1 for i in + // range(1, ndim)) + bool col_contiguous : 1; + }; + + /** The array's primitive. */ + Primitive& primitive() const { + return *(array_desc_->primitive); + } + + /** A shared pointer to the array's primitive. */ + std::shared_ptr& primitive_ptr() const { + return array_desc_->primitive; + } + + /** Check if the array has an attached primitive or is a leaf node. */ + bool has_primitive() const { + return array_desc_->primitive != nullptr; + } + + /** The array's inputs. */ + const std::vector& inputs() const { + return array_desc_->inputs; + } + + std::vector& inputs() { + return array_desc_->inputs; + } + + /** True indicates the arrays buffer is safe to reuse */ + bool is_donatable() const { + return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1); + } + + /** The array's siblings. */ + const std::vector& siblings() const { + return array_desc_->siblings; + } + + /** The array's siblings. */ + std::vector& siblings() { + return array_desc_->siblings; + } + + /** The array's position in the sibling list. */ + int sibling_position() const { + return array_desc_->position; + } + + void set_siblings(std::vector siblings, uint16_t position) { + array_desc_->siblings = std::move(siblings); + array_desc_->position = position; + } + + /** The outputs of the array's primitive (i.e. this array and + * its siblings) in the order the primitive expects. */ + std::vector outputs() const { + auto idx = array_desc_->position; + std::vector outputs; + outputs.reserve(siblings().size() + 1); + outputs.insert(outputs.end(), siblings().begin(), siblings().begin() + idx); + outputs.push_back(*this); + outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end()); + return outputs; + } + + /** Detach the array from the graph. */ + void detach(); + + /** Get the Flags bit-field. */ + const Flags& flags() const { + return array_desc_->flags; + } + + /** The size (in elements) of the underlying buffer the array points to. + * + * This can be different than the actual size of the array if the array has + * been broadcast or irregularly strided. If ``first`` is the offset into + * the data buffer of the first element of the array (i.e. the offset + * corresponding to ``arr[0, 0, ...]``) and last is the offset into the + * data buffer of the last element of the array (i.e. the offset + * corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``. + * Note, ``data_size`` is in units of ``item_size`` (not bytes). + **/ + size_t data_size() const { + return array_desc_->data_size; + } + + allocator::Buffer& buffer() { + return array_desc_->data->buffer; + } + const allocator::Buffer& buffer() const { + return array_desc_->data->buffer; + } + + size_t buffer_size() const { + return allocator::allocator().size(buffer()); + } + + // Return the shared pointer to the array::Data struct + const std::shared_ptr& data_shared_ptr() const { + return array_desc_->data; + } + + // Return a raw pointer to the arrays data. This function may do a copy if + // the underlying buffer is not accessible on the CPU. When accessing the + // data for GPU kernels, be sure to use the correct method / function for the + // given backend to access the GPU pointer. + template + T* data() { + return reinterpret_cast( + (static_cast(buffer().raw_ptr()) + array_desc_->offset)); + } + + template + const T* data() const { + return const_cast(*this).data(); + } + + int64_t offset() const { + return array_desc_->offset; + } + + enum Status { + // The output of a computation which has not been scheduled. + // For example, the status of `x` in `auto x = a + b`. + unscheduled, + + // The array's `eval_*` function has been run, but the computation is not + // necessarily complete. The array will have memory allocated and if it is + // not a tracer then it will be detached from the graph. + evaluated, + + // If the array is the output of a computation then the computation + // is complete. Constant arrays are always available (e.g. `array({1, 2, + // 3})`) + available + }; + + // Check if the array is safe to read. + bool is_available() const; + + // Wait on the array to be available. After this `is_available` returns + // `true`. + void wait(); + + Status status() const { + return array_desc_->status; + } + + void set_status(Status s) const { + array_desc_->status = s; + } + + // Get the array's shared event + Event& event() const { + return array_desc_->event; + } + + // Attach an event to a not yet evaluated array + void attach_event(Event e) const { + array_desc_->event = std::move(e); + } + + void detach_event() const { + array_desc_->event = Event{}; + } + + // Mark the array as a tracer array (true) or not. + void set_tracer(bool is_tracer) { + array_desc_->is_tracer = is_tracer; + } + // Check if the array is a tracer array + bool is_tracer() const; + + void set_data(allocator::Buffer buffer, Deleter d = allocator::free); + + void set_data( + allocator::Buffer buffer, + size_t data_size, + Strides strides, + Flags flags, + Deleter d = allocator::free); + + void copy_shared_buffer( + const array& other, + const Strides& strides, + Flags flags, + size_t data_size, + int64_t offset = 0); + + void copy_shared_buffer(const array& other); + + void overwrite_descriptor(const array& other) { + array_desc_ = other.array_desc_; + } + + ~array(); + + private: + // Initialize the arrays data + template + void init(const It src); + + struct MLX_API ArrayDesc { + Shape shape; + Strides strides; + size_t size; + Dtype dtype; + std::shared_ptr primitive; + + Status status; + + // An event on the array used for synchronization + Event event; + + // Indicates an array is being used in a graph transform + // and should not be detached from the graph + bool is_tracer{false}; + + // This is a shared pointer so that *different* arrays + // can share the underlying data buffer. + std::shared_ptr data; + + // Offset from beginning of data pointer + int64_t offset{0}; + + // The size in elements of the data buffer the array accesses + size_t data_size{0}; + + // Contains useful meta data about the array + Flags flags{true, true, true}; + + std::vector inputs; + // An array to keep track of the siblings from a multi-output + // primitive. + std::vector siblings; + // The arrays position in the output list + uint32_t position{0}; + + explicit ArrayDesc(Shape shape, Dtype dtype); + + explicit ArrayDesc( + Shape shape, + Dtype dtype, + std::shared_ptr primitive, + std::vector inputs); + + ~ArrayDesc(); + + private: + // Initialize size, strides, and other metadata + void init(); + }; + + // The ArrayDesc contains the details of the materialized array including the + // shape, strides, the data type. It also includes + // the primitive which knows how to compute the array's data from its inputs + // and the list of array's inputs for the primitive. + std::shared_ptr array_desc_; +}; + +template +array::array(T val, Dtype dtype /* = TypeToDtype() */) + : array_desc_(std::make_shared(Shape{}, dtype)) { + init(&val); +} + +template +array::array( + It data, + Shape shape, + Dtype dtype /* = TypeToDtype::value_type>() */) : + array_desc_(std::make_shared(std::move(shape), dtype)) { + init(data); +} + +template +array::array( + std::initializer_list data, + Dtype dtype /* = TypeToDtype() */) + : array_desc_( + std::make_shared( + Shape{static_cast(data.size())}, + dtype)) { + init(data.begin()); +} + +template +array::array( + std::initializer_list data, + Shape shape, + Dtype dtype /* = TypeToDtype() */) + : array_desc_(std::make_shared(std::move(shape), dtype)) { + if (data.size() != size()) { + throw std::invalid_argument( + "Data size and provided shape mismatch in array construction."); + } + init(data.begin()); +} + +template +T array::item() { + if (size() != 1) { + throw std::invalid_argument("item can only be called on arrays of size 1."); + } + eval(); + return *data(); +} + +template +T array::item() const { + if (size() != 1) { + throw std::invalid_argument("item can only be called on arrays of size 1."); + } + if (status() == Status::unscheduled) { + throw std::invalid_argument( + "item() const can only be called on evaled arrays"); + } + const_cast(this)->eval(); + return *data(); +} + +template +void array::init(It src) { + set_data(allocator::malloc(size() * size_of(dtype()))); + switch (dtype()) { + case bool_: + std::copy(src, src + size(), data()); + break; + case uint8: + std::copy(src, src + size(), data()); + break; + case uint16: + std::copy(src, src + size(), data()); + break; + case uint32: + std::copy(src, src + size(), data()); + break; + case uint64: + std::copy(src, src + size(), data()); + break; + case int8: + std::copy(src, src + size(), data()); + break; + case int16: + std::copy(src, src + size(), data()); + break; + case int32: + std::copy(src, src + size(), data()); + break; + case int64: + std::copy(src, src + size(), data()); + break; + case float16: + std::copy(src, src + size(), data()); + break; + case float32: + std::copy(src, src + size(), data()); + break; + case float64: + std::copy(src, src + size(), data()); + break; + case bfloat16: + std::copy(src, src + size(), data()); + break; + case complex64: + std::copy(src, src + size(), data()); + break; + } +} + +/* Utilities for determining whether a template parameter is array. */ +template +inline constexpr bool is_array_v = + std::is_same_v>, array>; + +template +inline constexpr bool is_arrays_v = (is_array_v && ...); + +template +using enable_for_arrays_t = typename std::enable_if_t>; + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/common/binary.h b/Source/Cxxmlx/include/mlx/backend/common/binary.h new file mode 100644 index 00000000..78607ef0 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/common/binary.h @@ -0,0 +1,97 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include "mlx/allocator.h" +#include "mlx/array.h" +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +enum class BinaryOpType { + ScalarScalar, + ScalarVector, + VectorScalar, + VectorVector, + General, +}; + +inline BinaryOpType get_binary_op_type(const array& a, const array& b) { + BinaryOpType bopt; + if (a.data_size() == 1 && b.data_size() == 1) { + bopt = BinaryOpType::ScalarScalar; + } else if (a.data_size() == 1 && b.flags().contiguous) { + bopt = BinaryOpType::ScalarVector; + } else if (b.data_size() == 1 && a.flags().contiguous) { + bopt = BinaryOpType::VectorScalar; + } else if ( + (a.flags().row_contiguous && b.flags().row_contiguous) || + (a.flags().col_contiguous && b.flags().col_contiguous)) { + bopt = BinaryOpType::VectorVector; + } else { + bopt = BinaryOpType::General; + } + return bopt; +} + +inline void set_binary_op_output_data( + const array& a, + const array& b, + array& out, + BinaryOpType bopt, + std::function mallocfn = allocator::malloc) { + bool b_donatable = is_donatable(b, out); + bool a_donatable = is_donatable(a, out); + switch (bopt) { + case BinaryOpType::ScalarScalar: + out.set_data(mallocfn(out.itemsize()), 1, a.strides(), a.flags()); + break; + case BinaryOpType::ScalarVector: + if (b_donatable) { + out.copy_shared_buffer(b); + } else { + out.set_data( + mallocfn(b.data_size() * out.itemsize()), + b.data_size(), + b.strides(), + b.flags()); + } + break; + case BinaryOpType::VectorScalar: + if (a_donatable) { + out.copy_shared_buffer(a); + } else { + out.set_data( + mallocfn(a.data_size() * out.itemsize()), + a.data_size(), + a.strides(), + a.flags()); + } + break; + case BinaryOpType::VectorVector: + if (a_donatable) { + out.copy_shared_buffer(a); + } else if (b_donatable) { + out.copy_shared_buffer(b); + } else { + out.set_data( + mallocfn(a.data_size() * out.itemsize()), + a.data_size(), + a.strides(), + a.flags()); + } + break; + case BinaryOpType::General: + if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) { + out.copy_shared_buffer(a); + } else if ( + b_donatable && b.flags().row_contiguous && b.size() == out.size()) { + out.copy_shared_buffer(b); + } else { + out.set_data(mallocfn(out.nbytes())); + } + break; + } +} + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/common/broadcasting.h b/Source/Cxxmlx/include/mlx/backend/common/broadcasting.h new file mode 100644 index 00000000..29651e90 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/common/broadcasting.h @@ -0,0 +1,11 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +void broadcast(const array& in, array& out); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/common/buffer_cache.h b/Source/Cxxmlx/include/mlx/backend/common/buffer_cache.h new file mode 100644 index 00000000..27648b2f --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/common/buffer_cache.h @@ -0,0 +1,158 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include +#include + +namespace mlx::core { + +template +class BufferCache { + public: + BufferCache( + size_t page_size, + std::function get_size, + std::function free) + : page_size_(page_size), + get_size_(std::move(get_size)), + free_(std::move(free)) {} + + ~BufferCache() { + clear(); + } + + BufferCache(const BufferCache&) = delete; + BufferCache& operator=(const BufferCache&) = delete; + + T* reuse_from_cache(size_t size) { + // Find the closest buffer in pool. + auto it = buffer_pool_.lower_bound(size); + if (it == buffer_pool_.end() || + it->first >= std::min(2 * size, size + 2 * page_size_)) { + return nullptr; + } + + // Collect from the cache. + T* buf = it->second->buf; + pool_size_ -= it->first; + + // Remove from record. + remove_from_list(it->second); + buffer_pool_.erase(it); + return buf; + } + + void recycle_to_cache(T* buf) { + assert(buf); + // Add to cache. + BufferHolder* bh = new BufferHolder(buf); + add_at_head(bh); + size_t size = get_size_(buf); + pool_size_ += size; + buffer_pool_.emplace(size, bh); + } + + int release_cached_buffers(size_t min_bytes_to_free) { + if (min_bytes_to_free >= 0.9 * pool_size_) { + return clear(); + } else { + int n_release = 0; + size_t total_bytes_freed = 0; + + while (tail_ && (total_bytes_freed < min_bytes_to_free)) { + // Release buffer. + size_t size = get_size_(tail_->buf); + total_bytes_freed += size; + free_(tail_->buf); + n_release++; + + // Remove from record. + auto its = buffer_pool_.equal_range(size); + auto it = std::find_if(its.first, its.second, [this](const auto& el) { + return el.second == tail_; + }); + assert(it != buffer_pool_.end()); + buffer_pool_.erase(it); + remove_from_list(tail_); + } + + pool_size_ -= total_bytes_freed; + return n_release; + } + } + + int clear() { + int n_release = 0; + for (auto& [size, holder] : buffer_pool_) { + free_(holder->buf); + n_release++; + delete holder; + } + buffer_pool_.clear(); + pool_size_ = 0; + head_ = nullptr; + tail_ = nullptr; + return n_release; + } + + size_t cache_size() const { + return pool_size_; + } + + size_t page_size() const { + return page_size_; + } + + private: + struct BufferHolder { + public: + explicit BufferHolder(T* buf_) : buf(buf_) {} + + BufferHolder* prev{nullptr}; + BufferHolder* next{nullptr}; + T* buf; + }; + + void add_at_head(BufferHolder* to_add) { + if (!head_) { + head_ = to_add; + tail_ = to_add; + } else { + head_->prev = to_add; + to_add->next = head_; + head_ = to_add; + } + } + + void remove_from_list(BufferHolder* to_remove) { + if (to_remove->prev && to_remove->next) { // if middle + to_remove->prev->next = to_remove->next; + to_remove->next->prev = to_remove->prev; + } else if (to_remove->prev && to_remove == tail_) { // if tail + tail_ = to_remove->prev; + tail_->next = nullptr; + } else if (to_remove == head_ && to_remove->next) { // if head + head_ = to_remove->next; + head_->prev = nullptr; + } else if (to_remove == head_ && to_remove == tail_) { // if only element + head_ = nullptr; + tail_ = nullptr; + } + + delete to_remove; + } + + std::multimap buffer_pool_; + BufferHolder* head_{nullptr}; + BufferHolder* tail_{nullptr}; + size_t pool_size_{0}; + + const size_t page_size_; + std::function get_size_; + std::function free_; +}; + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/common/compiled.h b/Source/Cxxmlx/include/mlx/backend/common/compiled.h new file mode 100644 index 00000000..3be37133 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/common/compiled.h @@ -0,0 +1,77 @@ +// Copyright © 2023-2024 Apple Inc. +#pragma once + +#include +#include + +#include "mlx/array.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +inline bool is_static_cast(const Primitive& p) { + return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType)); +} + +std::string get_type_string(Dtype d); + +template +void print_float_constant(std::ostream& os, const array& x) { + auto old_precision = os.precision(); + if constexpr (std::is_same_v) { + os << std::setprecision(std::numeric_limits::digits10 + 1); + } else { + os << std::setprecision(std::numeric_limits::digits10 + 1); + } + os << x.item() << std::setprecision(old_precision); +} + +template +void print_int_constant(std::ostream& os, const array& x) { + os << x.item(); +} + +template +void print_complex_constant(std::ostream& os, const array& x) { + auto old_precision = os.precision(); + T constant = x.item(); + + os << get_type_string(x.dtype()) << "(" + << std::setprecision(std::numeric_limits::digits10 + 1) + << constant.real() << ", " << constant.imag() << ")" + << std::setprecision(old_precision); +} + +void print_constant(std::ostream& os, const array& x); + +inline bool is_scalar(const array& x) { + return x.ndim() == 0; +} + +// Check if we can use a contiguous operation given inputs and the output shape +bool compiled_check_contiguity( + const std::vector& inputs, + const Shape& shape); + +// Allocate space for the outputs possibly with input donation +void compiled_allocate_outputs( + const std::vector& inputs, + std::vector& outputs, + const std::function& is_constant, + bool contiguous, + const std::function& mallocfn = + allocator::malloc); + +// Collapse contiguous dims ignoring scalars and constants. +std::tuple> compiled_collapse_contiguous_dims( + const std::vector& inputs, + const array& out, + const std::function& is_constant); + +// Return whether the kernel should use large index. +bool compiled_use_large_index( + const std::vector& inputs, + const std::vector& outputs, + bool contiguous); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/common/copy.h b/Source/Cxxmlx/include/mlx/backend/common/copy.h new file mode 100644 index 00000000..859ce041 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/common/copy.h @@ -0,0 +1,50 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +enum class CopyType { + // Copy a raw scalar input into the full contiguous output + Scalar, + + // Copy the raw input buffer contiguously into a raw output buffer of the same + // size + Vector, + + // Copy the full virtual input to the full contiguous output + General, + + // Copy the full virtual input to the full virtual output. We assume the + // input and output have the same shape. + GeneralGeneral +}; + +inline bool set_copy_output_data( + const array& in, + array& out, + CopyType ctype, + std::function mallocfn = allocator::malloc) { + if (ctype == CopyType::Vector) { + // If the input is donateable, we are doing a vector copy and the types + // have the same size, then the input buffer can hold the output. + if (is_donatable(in, out)) { + out.copy_shared_buffer(in); + return true; + } else { + out.set_data( + mallocfn(in.data_size() * out.itemsize()), + in.data_size(), + in.strides(), + in.flags()); + return false; + } + } else { + out.set_data(mallocfn(out.nbytes())); + return false; + } +} + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/common/hadamard.h b/Source/Cxxmlx/include/mlx/backend/common/hadamard.h new file mode 100644 index 00000000..ba5c4e41 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/common/hadamard.h @@ -0,0 +1,109 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include + +#include "mlx/utils.h" + +namespace mlx::core { + +// From http://neilsloane.com/hadamard/ +constexpr std::string_view h12 = R"( ++-++++++++++ +--+-+-+-+-+- ++++-++----++ ++---+--+-++- ++++++-++---- ++-+---+--+-+ +++--+++-++-- ++--++---+--+ +++----+++-++ ++--+-++---+- +++++----+++- ++-+--+-++--- +)"; + +constexpr std::string_view h20 = R"( ++----+----++--++-++- +-+----+---+++---+-++ +--+----+---+++-+-+-+ +---+----+---+++++-+- +----+----++--++-++-+ +-+++++-----+--+++--+ ++-+++-+---+-+--+++-- +++-++--+---+-+--+++- ++++-+---+---+-+--+++ +++++-----++--+-+--++ +--++-+-++-+-----++++ +---++-+-++-+---+-+++ ++---++-+-+--+--++-++ +++---++-+----+-+++-+ +-++---++-+----+++++- +-+--+--++-+----+---- ++-+-----++-+----+--- +-+-+-+---+--+----+-- +--+-+++------+----+- ++--+--++------+----+ +)"; + +constexpr std::string_view h28 = R"( ++------++----++-+--+-+--++-- +-+-----+++-----+-+--+-+--++- +--+-----+++---+-+-+----+--++ +---+-----+++---+-+-+-+--+--+ +----+-----+++---+-+-+++--+-- +-----+-----++++--+-+--++--+- +------++----++-+--+-+--++--+ +--++++-+-------++--+++-+--+- +---++++-+-----+-++--+-+-+--+ ++---+++--+----++-++--+-+-+-- +++---++---+----++-++--+-+-+- ++++---+----+----++-++--+-+-+ +++++--------+-+--++-++--+-+- +-++++--------+++--++--+--+-+ +-+-++-++--++--+--------++++- ++-+-++--+--++--+--------++++ +-+-+-++--+--++--+----+---+++ ++-+-+-++--+--+---+---++---++ +++-+-+-++--+------+--+++---+ +-++-+-+-++--+------+-++++--- ++-++-+---++--+------+-++++-- +-++--++-+-++-+++----++------ ++-++--++-+-++-+++-----+----- +++-++---+-+-++-+++-----+---- +-++-++-+-+-+-+--+++-----+--- +--++-++++-+-+----+++-----+-- ++--++-+-++-+-+----+++-----+- +++--++-+-++-+-+----++------+ +)"; + +inline const std::map hadamard_matrices() { + return {{12, h12}, {20, h20}, {28, h28}}; +} + +inline std::pair decompose_hadamard(int n) { + // n = m*2^k + int m = 1; + if (!is_power_of_2(n)) { + auto h_matrices = hadamard_matrices(); + for (auto [factor, _] : h_matrices) { + if (n % factor == 0) { + m = factor; + n /= factor; + break; + } + } + if (m == 1) { + throw std::invalid_argument( + "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28)."); + } + } + if (n > (1 << 26)) { + throw std::invalid_argument( + "[hadamard] Only supports n = m*2^k where k <= 26"); + } + return {n, m}; +} + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/common/matmul.h b/Source/Cxxmlx/include/mlx/backend/common/matmul.h new file mode 100644 index 00000000..2545c4fd --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/common/matmul.h @@ -0,0 +1,67 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/common/utils.h" +#include "mlx/utils.h" + +#include + +namespace mlx::core { + +inline std::tuple collapse_batches( + const array& a, + const array& b) { + if (a.ndim() == 2) { + return {Shape{1}, Strides{0}, Strides{0}}; + } + + Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; + Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; + Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; + + auto [batch_shape, batch_strides] = + collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride}); + + auto a_batch_strides = batch_strides[0]; + auto b_batch_strides = batch_strides[1]; + + if (batch_shape.empty()) { + batch_shape.push_back(1); + a_batch_strides.push_back(0); + b_batch_strides.push_back(0); + } + + return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides); +} + +inline std::tuple +collapse_batches(const array& a, const array& b, const array& c) { + if (a.ndim() == 2) { + return {Shape{1}, Strides{0}, Strides{0}, Strides{0}}; + } + + Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; + Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; + Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; + Strides C_bstride{c.strides().begin(), c.strides().end() - 2}; + + auto [batch_shape, batch_strides] = collapse_contiguous_dims( + A_bshape, std::vector{A_bstride, B_bstride, C_bstride}); + + auto A_batch_stride = batch_strides[0]; + auto B_batch_stride = batch_strides[1]; + auto C_batch_stride = batch_strides[2]; + + if (batch_shape.empty()) { + batch_shape.push_back(1); + A_batch_stride.push_back(0); + B_batch_stride.push_back(0); + C_batch_stride.push_back(0); + } + + return std::make_tuple( + batch_shape, A_batch_stride, B_batch_stride, C_batch_stride); +} + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/common/quantized.h b/Source/Cxxmlx/include/mlx/backend/common/quantized.h new file mode 100644 index 00000000..7656da8b --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/common/quantized.h @@ -0,0 +1,14 @@ +// Copyright © 2026 Apple Inc. + +namespace mlx::core { + +inline constexpr short get_pack_factor(int bits, int wsize = 8) { + return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); +} + +inline constexpr short get_bytes_per_pack(int bits, int wsize = 8) { + bool power_of_2_bits = (bits & (bits - 1)) == 0; + return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3); +} + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/common/reduce.h b/Source/Cxxmlx/include/mlx/backend/common/reduce.h new file mode 100644 index 00000000..8b24f4f5 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/common/reduce.h @@ -0,0 +1,59 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +enum ReductionOpType { + // Self-explanatory. Read everything and produce 1 output. + ContiguousAllReduce, + + // The input is contiguous and the last axis is reduced + // N1xR1xN2xR2x...xNnxRn + ContiguousReduce, + + // The input is contiguous and the last axis is not reduced + // R1xN1xR2xN2x...xRnxNn + ContiguousStridedReduce, + + // The input is not contiguous but the last axis is and it is reduced so we + // need to figure out the offsets but we can call the contiguous reduce after + // that. + // N3xR1xN1xR4x...xRn + GeneralContiguousReduce, + + // The input is not contiguous but the last reduction axis and the last axis + // are so we need to figure out the offset but we can call the strided reduce + // after that. + GeneralStridedReduce, + + // The input is not contiguous after the reduction axis and it may contain + // 0-stride axes or transpositions. We could copy the strides and produce a + // transposed outcome or we can read the input out of order and write the + // output in order. + GeneralReduce +}; + +struct ReductionPlan { + ReductionOpType type; + Shape shape; + Strides strides; + + ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_) + : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {} + ReductionPlan(ReductionOpType type_) : type(type_) {} +}; + +ReductionPlan get_reduction_plan(const array& x, const std::vector& axes); + +std::pair shapes_without_reduction_axes( + const array& x, + const std::vector& axes); +std::pair shapes_without_reduction_axes( + Shape shape, + Strides strides, + const std::vector& axes); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/common/slicing.h b/Source/Cxxmlx/include/mlx/backend/common/slicing.h new file mode 100644 index 00000000..b667d261 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/common/slicing.h @@ -0,0 +1,20 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +std::tuple prepare_slice( + const array& in, + const Shape& start_indices, + const Shape& strides); + +void slice( + const array& in, + array& out, + const Shape& start_indices, + const Shape& strides); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/common/ternary.h b/Source/Cxxmlx/include/mlx/backend/common/ternary.h new file mode 100644 index 00000000..c63a5726 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/common/ternary.h @@ -0,0 +1,85 @@ +// Copyright © 2023 Apple Inc. + +#pragma once +#include "mlx/allocator.h" +#include "mlx/array.h" +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +// TODO: Add support for more combinations of input types. +enum class TernaryOpType { + ScalarScalarScalar, + VectorVectorVector, + VectorVectorScalar, + VectorScalarVector, + General, +}; + +inline TernaryOpType +get_ternary_op_type(const array& a, const array& b, const array& c) { + TernaryOpType topt; + if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) { + topt = TernaryOpType::ScalarScalarScalar; + } else if ( + (a.flags().row_contiguous && b.flags().row_contiguous && + c.flags().row_contiguous) || + (a.flags().col_contiguous && b.flags().col_contiguous && + c.flags().col_contiguous)) { + topt = TernaryOpType::VectorVectorVector; + } else if ( + b.data_size() == 1 && a.flags().row_contiguous && + c.flags().row_contiguous) { + topt = TernaryOpType::VectorScalarVector; + } else if ( + c.data_size() == 1 && a.flags().row_contiguous && + b.flags().row_contiguous) { + topt = TernaryOpType::VectorVectorScalar; + } else { + topt = TernaryOpType::General; + } + return topt; +} + +inline void set_ternary_op_output_data( + const array& a, + const array& b, + const array& c, + array& out, + TernaryOpType topt, + std::function mallocfn = allocator::malloc) { + auto maybe_donate = [&out](const array& x) { + if (is_donatable(x, out)) { + out.copy_shared_buffer(x); + return true; + } + return false; + }; + + switch (topt) { + case TernaryOpType::ScalarScalarScalar: + out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags()); + break; + case TernaryOpType::VectorVectorVector: + if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) { + out.set_data( + mallocfn(out.itemsize() * b.data_size()), + b.data_size(), + b.strides(), + b.flags()); + } + break; + case TernaryOpType::VectorVectorScalar: + case TernaryOpType::VectorScalarVector: + case TernaryOpType::General: + // Try to donate an input which is row_contiguous + if (!((a.flags().row_contiguous && maybe_donate(a)) || + (b.flags().row_contiguous && maybe_donate(b)) || + (c.flags().row_contiguous && maybe_donate(c)))) { + out.set_data(mallocfn(out.nbytes())); + } + break; + } +} + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/common/unary.h b/Source/Cxxmlx/include/mlx/backend/common/unary.h new file mode 100644 index 00000000..b19fc98e --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/common/unary.h @@ -0,0 +1,29 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/allocator.h" +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +inline void set_unary_output_data( + const array& in, + array& out, + std::function mallocfn = allocator::malloc) { + if (in.flags().contiguous) { + if (is_donatable(in, out)) { + out.copy_shared_buffer(in); + } else { + out.set_data( + mallocfn(in.data_size() * out.itemsize()), + in.data_size(), + in.strides(), + in.flags()); + } + } else { + out.set_data(mallocfn(out.nbytes())); + } +} + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/common/utils.h b/Source/Cxxmlx/include/mlx/backend/common/utils.h new file mode 100644 index 00000000..1b6902ff --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/common/utils.h @@ -0,0 +1,205 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include +#include + +#include "mlx/array.h" + +namespace mlx::core { + +// Return the directory that contains current shared library. +std::filesystem::path current_binary_dir(); + +inline int64_t +elem_to_loc(int elem, const Shape& shape, const Strides& strides) { + int64_t loc = 0; + for (int i = shape.size() - 1; i >= 0; --i) { + auto q_and_r = ldiv(elem, shape[i]); + loc += q_and_r.rem * strides[i]; + elem = q_and_r.quot; + } + return loc; +} + +inline int64_t elem_to_loc(int elem, const array& a) { + if (a.flags().row_contiguous) { + return elem; + } + return elem_to_loc(elem, a.shape(), a.strides()); +} + +inline Strides make_contiguous_strides(const Shape& shape) { + Strides strides(shape.size(), 1); + for (int i = shape.size() - 1; i > 0; i--) { + strides[i - 1] = strides[i] * shape[i]; + } + return strides; +} + +// Collapse dims that are contiguous to possibly route to a better kernel +// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1}) +// should return {{2, 4}, {{1, 2}}}. +// +// When multiple arrays are passed they should all have the same shape. The +// collapsed axes are also the same so one shape is returned. +std::tuple> collapse_contiguous_dims( + const Shape& shape, + const std::vector& strides, + int64_t size_cap = std::numeric_limits::max()); + +inline std::tuple> collapse_contiguous_dims( + const std::vector& xs, + size_t size_cap = std::numeric_limits::max()) { + std::vector strides; + for (auto& x : xs) { + strides.emplace_back(x.strides()); + } + return collapse_contiguous_dims(xs[0].shape(), strides, size_cap); +} + +template > +inline auto collapse_contiguous_dims(Arrays&&... xs) { + return collapse_contiguous_dims( + std::vector{std::forward(xs)...}); +} + +// The single array version of the above. +std::pair collapse_contiguous_dims( + const Shape& shape, + const Strides& strides, + int64_t size_cap = std::numeric_limits::max()); +std::pair collapse_contiguous_dims( + const array& a, + int64_t size_cap = std::numeric_limits::max()); + +// Compute the thread block dimensions which fit the given +// input dimensions. +// - The thread block dimensions will be powers of two +// - The thread block size will be less than 2^pow2 +using Dims = std::tuple; +Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10); + +// Computes a 2D grid where each element is < UINT_MAX +// Assumes: +// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2 +// - shape and strides correspond to a contiguous (no holes) but +// possibly broadcasted array +Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides); + +// Same as above but we do an implicit division with divisor. +// Basically, equivalent to factorizing +// Prod(s \forall s in shape if strides[s] > 0) / divisor. +Dims get_2d_grid_dims_common( + const Shape& shape, + const Strides& strides, + size_t divisor); + +// Get both the block and a grid of blocks that covers dim0, dim1 and dim2. +std::pair get_grid_and_block_common(int dim0, int dim1, int dim2); + +struct ContiguousIterator { + inline void step() { + int dims = shape_.size(); + if (dims == 0) { + return; + } + int i = dims - 1; + while (pos_[i] == (shape_[i] - 1) && i > 0) { + pos_[i] = 0; + loc -= (shape_[i] - 1) * strides_[i]; + i--; + } + pos_[i]++; + loc += strides_[i]; + } + + void seek(int64_t n) { + loc = 0; + for (int i = shape_.size() - 1; i >= 0; --i) { + auto q_and_r = ldiv(n, shape_[i]); + loc += q_and_r.rem * strides_[i]; + pos_[i] = q_and_r.rem; + n = q_and_r.quot; + } + } + + void reset() { + loc = 0; + std::fill(pos_.begin(), pos_.end(), 0); + } + + ContiguousIterator() {}; + + explicit ContiguousIterator(const array& a) + : shape_(a.shape()), strides_(a.strides()) { + if (!shape_.empty()) { + std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_); + pos_ = Shape(shape_.size(), 0); + } + } + + explicit ContiguousIterator( + const Shape& shape, + const Strides& strides, + int dims) + : shape_(shape.begin(), shape.begin() + dims), + strides_(strides.begin(), strides.begin() + dims) { + if (!shape_.empty()) { + std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_); + pos_ = Shape(shape_.size(), 0); + } + } + + int64_t loc{0}; + + private: + Shape shape_; + Strides strides_; + Shape pos_; +}; + +inline auto check_contiguity(const Shape& shape, const Strides& strides) { + size_t no_broadcast_data_size = 1; + int64_t f_stride = 1; + int64_t b_stride = 1; + bool is_row_contiguous = true; + bool is_col_contiguous = true; + + for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) { + is_col_contiguous &= strides[i] == f_stride || shape[i] == 1; + is_row_contiguous &= strides[ri] == b_stride || shape[ri] == 1; + f_stride *= shape[i]; + b_stride *= shape[ri]; + if (strides[i] > 0) { + no_broadcast_data_size *= shape[i]; + } + } + + return std::make_tuple( + no_broadcast_data_size, is_row_contiguous, is_col_contiguous); +} + +inline bool is_donatable(const array& in, const array& out) { + constexpr size_t donation_extra = 16384; + + return in.is_donatable() && in.itemsize() == out.itemsize() && + in.buffer_size() <= out.nbytes() + donation_extra; +} + +std::pair prepare_reshape(const array& in, const array& out); + +void shared_buffer_reshape( + const array& in, + const Strides& out_strides, + array& out); + +template +inline SmallVector remove_index(SmallVector vec, size_t index) { + vec.erase(std::next(vec.begin(), index)); + return vec; +} + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/arange.h b/Source/Cxxmlx/include/mlx/backend/cpu/arange.h new file mode 100644 index 00000000..9e9b03bd --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/arange.h @@ -0,0 +1,28 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cpu/encoder.h" + +namespace mlx::core { + +namespace { + +template +void arange(T start, T next, array& out, size_t size, Stream stream) { + auto ptr = out.data(); + auto step_size = next - start; + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_output_array(out); + encoder.dispatch([ptr, start, step_size, size]() mutable { + for (int i = 0; i < size; ++i) { + ptr[i] = start; + start += step_size; + } + }); +} + +} // namespace + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/binary.h b/Source/Cxxmlx/include/mlx/backend/cpu/binary.h new file mode 100644 index 00000000..acaca50e --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/binary.h @@ -0,0 +1,517 @@ +// Copyright © 2023 Apple Inc. + +#pragma once +#include + +#include "mlx/array.h" +#include "mlx/backend/common/binary.h" +#include "mlx/backend/common/utils.h" + +#include "mlx/backend/cpu/encoder.h" +#include "mlx/backend/cpu/simd/simd.h" + +namespace mlx::core { + +template +struct VectorScalar { + template + void operator()(const T* a, const T* b, U* dst, int size) { + T scalar = *b; + constexpr int N = simd::max_size; + while (size >= N) { + simd::store(dst, Op{}(simd::load(a), simd::Simd(scalar))); + dst += N; + a += N; + size -= N; + } + while (size-- > 0) { + *dst = Op{}(*a, scalar); + dst++; + a++; + } + } +}; + +template +struct ScalarVector { + template + void operator()(const T* a, const T* b, U* dst, int size) { + T scalar = *a; + constexpr int N = simd::max_size; + while (size >= N) { + simd::store(dst, Op{}(simd::Simd(scalar), simd::load(b))); + dst += N; + b += N; + size -= N; + } + while (size-- > 0) { + *dst = Op{}(scalar, *b); + dst++; + b++; + } + } +}; + +template +struct VectorVector { + template + void operator()(const T* a, const T* b, U* dst, int size) { + constexpr int N = simd::max_size; + while (size >= N) { + simd::store(dst, Op{}(simd::load(a), simd::load(b))); + dst += N; + a += N; + b += N; + size -= N; + } + while (size-- > 0) { + *dst = Op{}(*a, *b); + dst++; + a++; + b++; + } + } +}; + +template +void binary_op_dims( + const T* a, + const T* b, + U* out, + const Shape& shape, + const Strides& a_strides, + const Strides& b_strides, + const Strides& out_strides, + int axis) { + auto stride_a = a_strides[axis]; + auto stride_b = b_strides[axis]; + auto stride_out = out_strides[axis]; + auto N = shape[axis]; + + for (int i = 0; i < N; i++) { + if constexpr (D > 1) { + binary_op_dims( + a, b, out, shape, a_strides, b_strides, out_strides, axis + 1); + } else { + if constexpr (Strided) { + Op{}(a, b, out, stride_out); + } else { + *out = Op{}(*a, *b); + } + } + out += stride_out; + a += stride_a; + b += stride_b; + } +} + +template +void binary_op_dispatch_dims( + const T* a, + const T* b, + U* out, + int dim, + int size, + const Shape& shape, + const Strides& a_strides, + const Strides& b_strides, + const Strides& out_strides) { + switch (dim) { + case 1: + binary_op_dims( + a, b, out, shape, a_strides, b_strides, out_strides, 0); + return; + case 2: + binary_op_dims( + a, b, out, shape, a_strides, b_strides, out_strides, 0); + return; + case 3: + binary_op_dims( + a, b, out, shape, a_strides, b_strides, out_strides, 0); + return; + } + + ContiguousIterator a_it(shape, a_strides, dim - 3); + ContiguousIterator b_it(shape, b_strides, dim - 3); + auto stride = out_strides[dim - 4]; + for (int64_t elem = 0; elem < size; elem += stride) { + binary_op_dims( + a + a_it.loc, + b + b_it.loc, + out + elem, + shape, + a_strides, + b_strides, + out_strides, + dim - 3); + a_it.step(); + b_it.step(); + } +} + +template +void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) { + // The full computation is scalar scalar so call the base op once + auto a_ptr = a.data(); + auto b_ptr = b.data(); + + auto out_ptr = out.data(); + if (bopt == BinaryOpType::ScalarScalar) { + *out_ptr = Op{}(*a_ptr, *b_ptr); + return; + } + + // The full computation is scalar vector so delegate to the op + if (bopt == BinaryOpType::ScalarVector) { + ScalarVector{}(a_ptr, b_ptr, out_ptr, b.data_size()); + return; + } + + // The full computation is vector scalar so delegate to the op + if (bopt == BinaryOpType::VectorScalar) { + VectorScalar{}(a_ptr, b_ptr, out_ptr, a.data_size()); + return; + } + + // The full computation is vector vector so delegate to the op + if (bopt == BinaryOpType::VectorVector) { + VectorVector{}(a_ptr, b_ptr, out_ptr, a.size()); + return; + } + + // General computation so let's try to optimize + auto [new_shape, new_strides] = collapse_contiguous_dims( + a.shape(), {a.strides(), b.strides(), out.strides()}); + auto& a_strides = new_strides[0]; + auto& b_strides = new_strides[1]; + auto& strides = new_strides[2]; + + // Get the left-most dim such that the array is row contiguous after + auto leftmost_rc_dim = [&strides](const auto& arr_strides) { + int d = arr_strides.size() - 1; + for (; d >= 0 && arr_strides[d] == strides[d]; d--) { + } + return d + 1; + }; + auto a_rc_dim = leftmost_rc_dim(a_strides); + auto b_rc_dim = leftmost_rc_dim(b_strides); + + // Get the left-most dim such that the array is a broadcasted "scalar" after + auto leftmost_s_dim = [](const auto& arr_strides) { + int d = arr_strides.size() - 1; + for (; d >= 0 && arr_strides[d] == 0; d--) { + } + return d + 1; + }; + auto a_s_dim = leftmost_s_dim(a_strides); + auto b_s_dim = leftmost_s_dim(b_strides); + + auto ndim = new_shape.size(); + + // Case 1: LxM and FxM where L and F are broadcastable and M is row + // contiguous + int dim = ndim; + if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) { + bopt = BinaryOpType::VectorVector; + dim = d; + // Case 2: LxM and Fx1 where L and F are broadcastable and M is row + // contiguous + } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) { + bopt = BinaryOpType::VectorScalar; + dim = d; + // Case 3: Lx1 and FxM where L and F are broadcastable and M is row + // contiguous + } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) { + bopt = BinaryOpType::ScalarVector; + dim = d; + } + + // Can be sure dim > 0 since otherwise we would have used one of the fully + // contiguous methods above. Except for the case that the flags do not + // correspond to the underlying contiguity. + if (dim == 0 || strides[dim - 1] < 16) { + bopt = BinaryOpType::General; + dim = ndim; + } + + switch (bopt) { + case BinaryOpType::VectorVector: + binary_op_dispatch_dims>( + a_ptr, + b_ptr, + out_ptr, + dim, + a.size(), + new_shape, + a_strides, + b_strides, + strides); + break; + case BinaryOpType::VectorScalar: + binary_op_dispatch_dims>( + a_ptr, + b_ptr, + out_ptr, + dim, + a.size(), + new_shape, + a_strides, + b_strides, + strides); + break; + case BinaryOpType::ScalarVector: + binary_op_dispatch_dims>( + a_ptr, + b_ptr, + out_ptr, + dim, + a.size(), + new_shape, + a_strides, + b_strides, + strides); + break; + default: + binary_op_dispatch_dims( + a_ptr, + b_ptr, + out_ptr, + dim, + a.size(), + new_shape, + a_strides, + b_strides, + strides); + break; + } +} + +template +void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) { + binary_op(a, b, out, bopt); +} + +template +void binary_op_cpu( + const array& a, + const array& b, + array& out, + Op op, + Stream stream) { + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out, bopt); + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + b = array::unsafe_weak_copy(b), + out = array::unsafe_weak_copy(out), + bopt]() mutable { + switch (out.dtype()) { + case bool_: + binary_op(a, b, out, bopt); + break; + case uint8: + binary_op(a, b, out, bopt); + break; + case uint16: + binary_op(a, b, out, bopt); + break; + case uint32: + binary_op(a, b, out, bopt); + break; + case uint64: + binary_op(a, b, out, bopt); + break; + case int8: + binary_op(a, b, out, bopt); + break; + case int16: + binary_op(a, b, out, bopt); + break; + case int32: + binary_op(a, b, out, bopt); + break; + case int64: + binary_op(a, b, out, bopt); + break; + case float16: + binary_op(a, b, out, bopt); + break; + case float32: + binary_op(a, b, out, bopt); + break; + case float64: + binary_op(a, b, out, bopt); + break; + case bfloat16: + binary_op(a, b, out, bopt); + break; + case complex64: + binary_op(a, b, out, bopt); + break; + } + }); +} + +template +void comparison_op_cpu( + const array& a, + const array& b, + array& out, + Op op, + Stream stream) { + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out, bopt); + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + b = array::unsafe_weak_copy(b), + out = array::unsafe_weak_copy(out), + bopt]() mutable { + switch (a.dtype()) { + case bool_: + binary_op(a, b, out, bopt); + break; + case uint8: + binary_op(a, b, out, bopt); + break; + case uint16: + binary_op(a, b, out, bopt); + break; + case uint32: + binary_op(a, b, out, bopt); + break; + case uint64: + binary_op(a, b, out, bopt); + break; + case int8: + binary_op(a, b, out, bopt); + break; + case int16: + binary_op(a, b, out, bopt); + break; + case int32: + binary_op(a, b, out, bopt); + break; + case int64: + binary_op(a, b, out, bopt); + break; + case float16: + binary_op(a, b, out, bopt); + break; + case float32: + binary_op(a, b, out, bopt); + break; + case float64: + binary_op(a, b, out, bopt); + break; + case bfloat16: + binary_op(a, b, out, bopt); + break; + case complex64: + binary_op(a, b, out, bopt); + break; + } + }); +} + +template +void binary_float_op_cpu( + const array& a, + const array& b, + array& out, + Op op, + Stream stream) { + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out, bopt); + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + b = array::unsafe_weak_copy(b), + out = array::unsafe_weak_copy(out), + bopt]() mutable { + switch (out.dtype()) { + case float16: + binary_op(a, b, out, bopt); + break; + case float32: + binary_op(a, b, out, bopt); + break; + case float64: + binary_op(a, b, out, bopt); + break; + case bfloat16: + binary_op(a, b, out, bopt); + break; + case complex64: + binary_op(a, b, out, bopt); + break; + default: + throw std::runtime_error( + "[binary_float] Only supports floating point types."); + } + }); +} + +template +void binary_int_op_cpu( + const array& a, + const array& b, + array& out, + Op op, + Stream stream) { + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out, bopt); + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + b = array::unsafe_weak_copy(b), + out = array::unsafe_weak_copy(out), + bopt]() mutable { + switch (out.dtype()) { + case bool_: + binary_op(a, b, out, bopt); + case uint8: + binary_op(a, b, out, bopt); + break; + case uint16: + binary_op(a, b, out, bopt); + break; + case uint32: + binary_op(a, b, out, bopt); + break; + case uint64: + binary_op(a, b, out, bopt); + break; + case int8: + binary_op(a, b, out, bopt); + break; + case int16: + binary_op(a, b, out, bopt); + break; + case int32: + binary_op(a, b, out, bopt); + break; + case int64: + binary_op(a, b, out, bopt); + break; + default: + throw std::runtime_error("[binary_int] Type not supported"); + break; + } + }); +} + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/binary_ops.h b/Source/Cxxmlx/include/mlx/backend/cpu/binary_ops.h new file mode 100644 index 00000000..d50751ce --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/binary_ops.h @@ -0,0 +1,98 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include "mlx/backend/cpu/simd/simd.h" + +namespace mlx::core::detail { + +using namespace mlx::core::simd; + +#define BINARY_SINGLE() \ + template \ + T operator()(T x, T y) { \ + return (*this)(Simd(x), Simd(y)).value; \ + } + +#define DEFAULT_BINARY_OP(Op, op) \ + struct Op { \ + template \ + Simd operator()(Simd x, Simd y) { \ + return op(x, y); \ + } \ + BINARY_SINGLE() \ + }; + +DEFAULT_BINARY_OP(Add, operator+) +DEFAULT_BINARY_OP(ArcTan2, atan2) +DEFAULT_BINARY_OP(Divide, operator/) +DEFAULT_BINARY_OP(Multiply, operator*) +DEFAULT_BINARY_OP(Subtract, operator-) +DEFAULT_BINARY_OP(LogicalAnd, operator&&) +DEFAULT_BINARY_OP(LogicalOr, operator||) +DEFAULT_BINARY_OP(BitwiseAnd, operator&) +DEFAULT_BINARY_OP(BitwiseOr, operator|) +DEFAULT_BINARY_OP(BitwiseXor, operator^) +DEFAULT_BINARY_OP(LeftShift, operator<<) +DEFAULT_BINARY_OP(RightShift, operator>>) +DEFAULT_BINARY_OP(Remainder, remainder) +DEFAULT_BINARY_OP(Maximum, maximum) +DEFAULT_BINARY_OP(Minimum, minimum) +DEFAULT_BINARY_OP(Power, pow) + +#define DEFAULT_BOOL_OP(Op, op) \ + struct Op { \ + template \ + Simd operator()(Simd x, Simd y) { \ + return op(x, y); \ + } \ + template \ + bool operator()(T x, T y) { \ + return (*this)(Simd(x), Simd(y)).value; \ + } \ + }; + +DEFAULT_BOOL_OP(Equal, operator==) +DEFAULT_BOOL_OP(Greater, operator>) +DEFAULT_BOOL_OP(GreaterEqual, operator>=) +DEFAULT_BOOL_OP(Less, operator<) +DEFAULT_BOOL_OP(LessEqual, operator<=) +DEFAULT_BOOL_OP(NotEqual, operator!=) + +struct NaNEqual { + template + Simd operator()(Simd x, Simd y) { + return x == y || (isnan(x) && isnan(y)); + } + template + bool operator()(T x, T y) { + return (*this)(Simd(x), Simd(y)).value; + } +}; + +struct LogAddExp { + template + Simd operator()(Simd x, Simd y) { + auto maxval = maximum(x, y); + auto minval = minimum(x, y); + auto mask = minval == -inf || maxval == inf; + auto out = maxval + log1p(exp(minval - maxval)); + return select(mask, Simd(maxval), Simd(out)); + } + BINARY_SINGLE() +}; + +struct Select { + template + T operator()(bool condition, T x, T y) { + return (*this)(Simd(condition), Simd(x), Simd(y)) + .value; + } + + template + Simd operator()(Simd condition, Simd x, Simd y) { + return select(condition, x, y); + } +}; + +} // namespace mlx::core::detail diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/binary_two.h b/Source/Cxxmlx/include/mlx/backend/cpu/binary_two.h new file mode 100644 index 00000000..fa0ca799 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/binary_two.h @@ -0,0 +1,166 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/binary.h" + +namespace mlx::core { + +namespace { + +template +void binary_op_dims( + const T* a, + const T* b, + U* out_a, + U* out_b, + Op op, + const Shape& shape, + const Strides& a_strides, + const Strides& b_strides, + const Strides& out_strides, + int axis) { + auto stride_a = a_strides[axis]; + auto stride_b = b_strides[axis]; + auto stride_out = out_strides[axis]; + auto N = shape[axis]; + + for (int i = 0; i < N; i++) { + if constexpr (D > 1) { + binary_op_dims( + a, + b, + out_a, + out_b, + op, + shape, + a_strides, + b_strides, + out_strides, + axis + 1); + } else { + std::tie(*out_a, *out_b) = op(*a, *b); + } + a += stride_a; + b += stride_b; + out_a += stride_out; + out_b += stride_out; + } +} + +template +void binary_op_dispatch_dims( + const array& a, + const array& b, + array& out_a, + array& out_b, + Op op) { + auto [shape, strides] = collapse_contiguous_dims( + a.shape(), {a.strides(), b.strides(), out_a.strides()}); + const T* a_ptr = a.data(); + const T* b_ptr = b.data(); + U* out_a_ptr = out_a.data(); + U* out_b_ptr = out_b.data(); + + const auto& a_strides = strides[0]; + const auto& b_strides = strides[1]; + const auto& out_strides = strides[2]; + int ndim = shape.size(); + switch (ndim) { + case 1: + binary_op_dims( + a_ptr, + b_ptr, + out_a_ptr, + out_b_ptr, + op, + shape, + a_strides, + b_strides, + out_strides, + 0); + return; + case 2: + binary_op_dims( + a_ptr, + b_ptr, + out_a_ptr, + out_b_ptr, + op, + shape, + a_strides, + b_strides, + out_strides, + 0); + return; + } + + ContiguousIterator a_it(shape, a_strides, ndim - 2); + ContiguousIterator b_it(shape, b_strides, ndim - 2); + auto stride = out_strides[ndim - 3]; + for (size_t elem = 0; elem < a.size(); elem += stride) { + binary_op_dims( + a_ptr + a_it.loc, + b_ptr + b_it.loc, + out_a_ptr + elem, + out_b_ptr + elem, + op, + shape, + a_strides, + b_strides, + out_strides, + ndim - 2); + a_it.step(); + b_it.step(); + } +} + +template +void binary_op( + const array& a, + const array& b, + array& out_a, + array& out_b, + Op op, + BinaryOpType bopt) { + // The full computation is scalar scalar so call the base op once + if (bopt == BinaryOpType::General) { + binary_op_dispatch_dims(a, b, out_a, out_b, op); + return; + } + + auto a_ptr = a.data(); + auto b_ptr = b.data(); + auto out_a_ptr = out_a.data(); + auto out_b_ptr = out_b.data(); + if (bopt == BinaryOpType::ScalarScalar) { + std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); + } else if (bopt == BinaryOpType::ScalarVector) { + for (size_t i = 0; i < b.data_size(); ++i) { + std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); + out_a_ptr++; + out_b_ptr++; + b_ptr++; + } + } else if (bopt == BinaryOpType::VectorScalar) { + for (size_t i = 0; i < a.data_size(); ++i) { + std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); + out_a_ptr++; + out_b_ptr++; + a_ptr++; + } + } else { // VectorVector + for (size_t i = 0; i < a.size(); ++i) { + std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); + out_a_ptr++; + out_b_ptr++; + a_ptr++; + b_ptr++; + } + } +} + +} // namespace + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/compiled_preamble.h b/Source/Cxxmlx/include/mlx/backend/cpu/compiled_preamble.h new file mode 100644 index 00000000..31ca1b46 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/compiled_preamble.h @@ -0,0 +1,12 @@ +// Copyright © 2023-24 Apple Inc. + +#pragma once + +// clang-format off +#include "mlx/types/half_types.h" +#include "mlx/types/complex.h" +#include "mlx/backend/cpu/unary_ops.h" +#include "mlx/backend/cpu/binary_ops.h" +// clang-format on + +const char* get_kernel_preamble(); diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/copy.h b/Source/Cxxmlx/include/mlx/backend/cpu/copy.h new file mode 100644 index 00000000..00729136 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/copy.h @@ -0,0 +1,36 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include + +#include "mlx/array.h" +#include "mlx/backend/common/copy.h" +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream); +void copy_cpu_inplace( + const array& src, + array& dst, + CopyType ctype, + Stream stream); + +void copy_cpu_inplace( + const array& src, + array& dst, + const Shape& data_shape, + const Strides& i_strides, + const Strides& o_strides, + int64_t i_offset, + int64_t o_offset, + CopyType ctype, + Stream stream, + const std::optional& dynamic_i_offset = std::nullopt, + const std::optional& dynamic_o_offset = std::nullopt); + +// Return a contiguous array with same shape that copies the data of |arr|. +array contiguous_copy_cpu(const array& arr, Stream stream); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/device_info.h b/Source/Cxxmlx/include/mlx/backend/cpu/device_info.h new file mode 100644 index 00000000..1e232334 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/device_info.h @@ -0,0 +1,28 @@ +// Copyright © 2026 Apple Inc. + +#pragma once + +#include +#include +#include + +namespace mlx::core::cpu { + +bool is_available(); + +/** + * Get the number of available CPU devices. + * + * For CPU, always returns 1. + */ +int device_count(); + +/** + * Get CPU device information. + * + * Returns a map with basic CPU device properties. + */ +const std::unordered_map>& +device_info(int device_index = 0); + +} // namespace mlx::core::cpu diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/encoder.h b/Source/Cxxmlx/include/mlx/backend/cpu/encoder.h new file mode 100644 index 00000000..e04179e5 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/encoder.h @@ -0,0 +1,67 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +#include "mlx/array.h" +#include "mlx/scheduler.h" + +namespace mlx::core::cpu { + +// Number of dispatches per scheduler task +constexpr int DISPATCHES_PER_TASK = 10; + +struct MLX_API CommandEncoder { + CommandEncoder(Stream stream) : stream_(stream) {} + + CommandEncoder(const CommandEncoder&) = delete; + CommandEncoder& operator=(const CommandEncoder&) = delete; + CommandEncoder(CommandEncoder&&) = delete; + CommandEncoder& operator=(CommandEncoder&&) = delete; + + void set_input_array(const array& a) {} + void set_output_array(array& a) {} + + // Hold onto a temporary until any already scheduled tasks which use it as + // an input are complete. + void add_temporary(array arr) { + temporaries_.push_back(std::move(arr)); + } + + void add_temporaries(std::vector arrays) { + temporaries_.insert( + temporaries_.end(), + std::make_move_iterator(arrays.begin()), + std::make_move_iterator(arrays.end())); + } + + std::vector& temporaries() { + return temporaries_; + } + + template + void dispatch(F&& f, Args&&... args) { + num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK; + auto task = std::bind(std::forward(f), std::forward(args)...); + if (num_ops_ == 0) { + scheduler::notify_new_task(stream_); + auto task_wrap = [s = stream_, task = std::move(task)]() mutable { + task(); + scheduler::notify_task_completion(s); + }; + scheduler::enqueue(stream_, std::move(task_wrap)); + } else { + scheduler::enqueue(stream_, std::move(task)); + } + } + + private: + Stream stream_; + std::vector temporaries_; + int num_ops_{0}; +}; + +MLX_API CommandEncoder& get_command_encoder(Stream stream); + +} // namespace mlx::core::cpu diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/eval.h b/Source/Cxxmlx/include/mlx/backend/cpu/eval.h new file mode 100644 index 00000000..20156d61 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/eval.h @@ -0,0 +1,12 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/stream.h" + +namespace mlx::core::cpu { + +void eval(array& arr); + +} // namespace mlx::core::cpu diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/gemm.h b/Source/Cxxmlx/include/mlx/backend/cpu/gemm.h new file mode 100644 index 00000000..d665cb91 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/gemm.h @@ -0,0 +1,26 @@ +// Copyright © 2025 Apple Inc. + +#pragma once +#include "mlx/array.h" + +namespace mlx::core { + +template +void matmul( + const T* a, + const T* b, + T* out, + bool a_transposed, + bool b_transposed, + size_t lda, + size_t ldb, + size_t ldc, + float alpha, + float beta, + size_t batch_size, + const Shape& a_shape, + const Strides& a_strides, + const Shape& b_shape, + const Strides& b_strides); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/gemms/simd_gemm.h b/Source/Cxxmlx/include/mlx/backend/cpu/gemms/simd_gemm.h new file mode 100644 index 00000000..a23c7dea --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/gemms/simd_gemm.h @@ -0,0 +1,139 @@ +// Copyright © 2025 Apple Inc. +#pragma once + +#include "mlx/backend/cpu/simd/simd.h" + +namespace mlx::core { + +inline int ceildiv(int a, int b) { + return (a + b - 1) / b; +} + +template +void load_block( + const T* in, + AccT* out, + int M, + int N, + int i, + int j, + bool transpose) { + if (transpose) { + for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) { + for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) { + out[jj * block_size + ii] = + in[(i * block_size + ii) * N + j * block_size + jj]; + } + } + } else { + for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) { + for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) { + out[ii * block_size + jj] = + in[(i * block_size + ii) * N + j * block_size + jj]; + } + } + } +} + +template +void simd_gemm( + const T* a, + const T* b, + T* c, + bool a_trans, + bool b_trans, + int M, + int N, + int K, + float alpha, + float beta) { + constexpr int block_size = 16; + constexpr int simd_size = simd::max_size; + static_assert( + (block_size % simd_size) == 0, + "Block size must be divisible by SIMD size"); + + int last_k_block_size = K - block_size * (K / block_size); + int last_k_simd_block = (last_k_block_size / simd_size) * simd_size; + for (int i = 0; i < ceildiv(M, block_size); i++) { + for (int j = 0; j < ceildiv(N, block_size); j++) { + AccT c_block[block_size * block_size] = {0.0}; + AccT a_block[block_size * block_size]; + AccT b_block[block_size * block_size]; + + int k = 0; + for (; k < K / block_size; k++) { + // Load a and b blocks + if (a_trans) { + load_block(a, a_block, K, M, k, i, true); + } else { + load_block(a, a_block, M, K, i, k, false); + } + if (b_trans) { + load_block(b, b_block, N, K, j, k, false); + } else { + load_block(b, b_block, K, N, k, j, true); + } + + // Multiply and accumulate + for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) { + for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) { + for (int kk = 0; kk < block_size; kk += simd_size) { + auto av = + simd::load(a_block + ii * block_size + kk); + auto bv = + simd::load(b_block + jj * block_size + kk); + c_block[ii * block_size + jj] += simd::sum(av * bv); + } + } + } + } + if (last_k_block_size) { + // Load a and b blocks + if (a_trans) { + load_block(a, a_block, K, M, k, i, true); + } else { + load_block(a, a_block, M, K, i, k, false); + } + if (b_trans) { + load_block(b, b_block, N, K, j, k, false); + } else { + load_block(b, b_block, K, N, k, j, true); + } + + // Multiply and accumulate + for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) { + for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) { + int kk = 0; + for (; kk < last_k_simd_block; kk += simd_size) { + auto av = + simd::load(a_block + ii * block_size + kk); + auto bv = + simd::load(b_block + jj * block_size + kk); + c_block[ii * block_size + jj] += simd::sum(av * bv); + } + for (; kk < last_k_block_size; ++kk) { + c_block[ii * block_size + jj] += + a_block[ii * block_size + kk] * b_block[jj * block_size + kk]; + } + } + } + } + + // Store + for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) { + for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) { + auto c_idx = (i * block_size + ii) * N + j * block_size + jj; + if (beta != 0) { + c[c_idx] = static_cast( + alpha * c_block[ii * block_size + jj] + beta * c[c_idx]); + } else { + c[c_idx] = static_cast(alpha * c_block[ii * block_size + jj]); + } + } + } + } + } +} + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/jit_compiler.h b/Source/Cxxmlx/include/mlx/backend/cpu/jit_compiler.h new file mode 100644 index 00000000..3a9e988d --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/jit_compiler.h @@ -0,0 +1,20 @@ +// Copyright © 2024 Apple Inc. +#pragma once + +#include + +namespace mlx::core { + +class JitCompiler { + public: + // Build a shell command that compiles a source code file to a shared library. + static std::string build_command( + const std::filesystem::path& dir, + const std::string& source_file_name, + const std::string& shared_lib_name); + + // Run a command and get its output. + static std::string exec(const std::string& cmd); +}; + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/lapack.h b/Source/Cxxmlx/include/mlx/backend/cpu/lapack.h new file mode 100644 index 00000000..1c3ba1a8 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/lapack.h @@ -0,0 +1,80 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#define LAPACK_COMPLEX_CUSTOM +#define lapack_complex_float std::complex +#define lapack_complex_double std::complex +#define lapack_complex_float_real(z) ((z).real()) +#define lapack_complex_float_imag(z) ((z).imag()) +#define lapack_complex_double_real(z) ((z).real()) +#define lapack_complex_double_imag(z) ((z).imag()) + +#ifdef MLX_USE_ACCELERATE +#include +#else +#include +#include +#endif + +#if defined(LAPACK_GLOBAL) || defined(LAPACK_NAME) + +// This is to work around a change in the function signatures of lapack >= 3.9.1 +// where functions taking char* also include a strlen argument, see a similar +// change in OpenCV: +// https://github.com/opencv/opencv/blob/1eb061f89de0fb85c4c75a2deeb0f61a961a63ad/cmake/OpenCVFindLAPACK.cmake#L57 +#define MLX_LAPACK_FUNC(f) LAPACK_##f + +#else + +#define MLX_LAPACK_FUNC(f) f##_ + +#endif + +#define INSTANTIATE_LAPACK_REAL(FUNC) \ + template \ + void FUNC(Args... args) { \ + if constexpr (std::is_same_v) { \ + MLX_LAPACK_FUNC(s##FUNC)(std::forward(args)...); \ + } else if constexpr (std::is_same_v) { \ + MLX_LAPACK_FUNC(d##FUNC)(std::forward(args)...); \ + } \ + } + +INSTANTIATE_LAPACK_REAL(geqrf) +INSTANTIATE_LAPACK_REAL(orgqr) +INSTANTIATE_LAPACK_REAL(syevd) +INSTANTIATE_LAPACK_REAL(potrf) +INSTANTIATE_LAPACK_REAL(getrf) +INSTANTIATE_LAPACK_REAL(getri) +INSTANTIATE_LAPACK_REAL(trtri) + +#define INSTANTIATE_LAPACK_COMPLEX(FUNC) \ + template \ + void FUNC(Args... args) { \ + if constexpr (std::is_same_v>) { \ + MLX_LAPACK_FUNC(c##FUNC)(std::forward(args)...); \ + } else if constexpr (std::is_same_v>) { \ + MLX_LAPACK_FUNC(z##FUNC)(std::forward(args)...); \ + } \ + } + +INSTANTIATE_LAPACK_COMPLEX(heevd) + +#define INSTANTIATE_LAPACK_ALL(FUNC) \ + template \ + void FUNC(Args... args) { \ + if constexpr (std::is_same_v) { \ + MLX_LAPACK_FUNC(s##FUNC)(std::forward(args)...); \ + } else if constexpr (std::is_same_v) { \ + MLX_LAPACK_FUNC(d##FUNC)(std::forward(args)...); \ + } else if constexpr (std::is_same_v>) { \ + MLX_LAPACK_FUNC(c##FUNC)(std::forward(args)...); \ + } else if constexpr (std::is_same_v>) { \ + MLX_LAPACK_FUNC(z##FUNC)(std::forward(args)...); \ + } \ + } + +INSTANTIATE_LAPACK_ALL(geev) +INSTANTIATE_LAPACK_ALL(gesdd) diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h b/Source/Cxxmlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h new file mode 100644 index 00000000..95054489 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h @@ -0,0 +1,56 @@ +#pragma once + +#include "mlx/backend/cpu/simd/base_simd.h" + +#if MLX_SIMD_LIBRARY_VERSION < 6 +#include "mlx/backend/cpu/simd/neon_fp16_simd.h" +#endif + +namespace mlx::core::simd { + +#if MLX_SIMD_LIBRARY_VERSION >= 6 +constexpr int N = 8; +template +struct ScalarT { + using v = _Float16; +}; +#endif + +template <> +inline constexpr int max_size = N; + +#define SIMD_FP16_DEFAULT_UNARY(op) \ + template <> \ + inline Simd op(Simd v) { \ + Simd in = v; \ + return op(in); \ + } + +SIMD_FP16_DEFAULT_UNARY(acos) +SIMD_FP16_DEFAULT_UNARY(acosh) +SIMD_FP16_DEFAULT_UNARY(asin) +SIMD_FP16_DEFAULT_UNARY(asinh) +SIMD_FP16_DEFAULT_UNARY(atan) +SIMD_FP16_DEFAULT_UNARY(atanh) +SIMD_FP16_DEFAULT_UNARY(cosh) +SIMD_FP16_DEFAULT_UNARY(expm1) +SIMD_FP16_DEFAULT_UNARY(log) +SIMD_FP16_DEFAULT_UNARY(log2) +SIMD_FP16_DEFAULT_UNARY(log10) +SIMD_FP16_DEFAULT_UNARY(log1p) +SIMD_FP16_DEFAULT_UNARY(sinh) +SIMD_FP16_DEFAULT_UNARY(tan) +SIMD_FP16_DEFAULT_UNARY(tanh) + +#define SIMD_FP16_DEFAULT_BINARY(op) \ + template <> \ + inline Simd op(Simd x, Simd y) { \ + Simd a = x; \ + Simd b = y; \ + return op(a, b); \ + } +SIMD_FP16_DEFAULT_BINARY(atan2) +SIMD_FP16_DEFAULT_BINARY(remainder) +SIMD_FP16_DEFAULT_BINARY(pow) + +} // namespace mlx::core::simd diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/simd/accelerate_simd.h b/Source/Cxxmlx/include/mlx/backend/cpu/simd/accelerate_simd.h new file mode 100644 index 00000000..f62c67d3 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/simd/accelerate_simd.h @@ -0,0 +1,329 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include + +#include "mlx/backend/cpu/simd/base_simd.h" + +// There seems to be a bug in simd/base_simd.h +// __XROS_2_0 is not defined, the expression evaluates +// to true instead of false setting the SIMD library +// higher than it should be even on macOS < 15 +#if __MAC_OS_X_VERSION_MIN_REQUIRED >= 150000 || \ + __IPHONE_OS_VERSION_MIN_REQUIRED >= 180000 || \ + __WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \ + __WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \ + __TV_OS_VERSION_MIN_REQUIRED >= 180000 +#define MLX_SIMD_LIBRARY_VERSION 6 +#else +#define MLX_SIMD_LIBRARY_VERSION 5 +#endif + +namespace mlx::core::simd { + +// Apple simd namespace +namespace asd = ::simd; + +// This indirection is needed to remap certain types to ones that accelerate +// SIMD can handle +template +struct ScalarT { + using v = T; +}; +template +struct ScalarT { + using v = char; +}; +template +struct ScalarT { + using v = char; +}; +template +struct ScalarT { + using v = unsigned long; +}; +template +struct ScalarT { + using v = long; +}; + +template +struct Simd { + static constexpr int size = N; + using scalar_t = typename ScalarT::v; + + Simd() {} + + template + Simd(Simd other) : value(asd::convert(other.value)) {} + + template + Simd(U v) : value(v){}; + + Simd(Simd x, Simd y) { + value = asd::make::packed_t>( + x.value, y.value); + }; + + T operator[](int idx) const { + return reinterpret_cast(&value)[idx]; + } + + T& operator[](int idx) { + return reinterpret_cast(&value)[idx]; + } + + typename asd::Vector::packed_t value; +}; + +// Values chosen based on benchmarks on M3 Max +// TODO: consider choosing these more optimally +template <> +inline constexpr int max_size = 16; +template <> +inline constexpr int max_size = 16; +template <> +inline constexpr int max_size = 8; +template <> +inline constexpr int max_size = 4; +template <> +inline constexpr int max_size = 16; +template <> +inline constexpr int max_size = 16; +template <> +inline constexpr int max_size = 8; +template <> +inline constexpr int max_size = 4; +template <> +inline constexpr int max_size = 8; +template <> +inline constexpr int max_size = 4; + +#define SIMD_DEFAULT_UNARY(name, op) \ + template \ + Simd name(Simd v) { \ + return op(v.value); \ + } + +SIMD_DEFAULT_UNARY(abs, asd::abs) +SIMD_DEFAULT_UNARY(floor, asd::floor) +SIMD_DEFAULT_UNARY(acos, asd::acos) +SIMD_DEFAULT_UNARY(acosh, asd::acosh) +SIMD_DEFAULT_UNARY(asin, asd::asin) +SIMD_DEFAULT_UNARY(asinh, asd::asinh) +SIMD_DEFAULT_UNARY(atan, asd::atan) +SIMD_DEFAULT_UNARY(atanh, asd::atanh) +SIMD_DEFAULT_UNARY(ceil, asd::ceil) +SIMD_DEFAULT_UNARY(cosh, asd::cosh) +SIMD_DEFAULT_UNARY(expm1, asd::expm1) +SIMD_DEFAULT_UNARY(log, asd::log) +SIMD_DEFAULT_UNARY(log2, asd::log2) +SIMD_DEFAULT_UNARY(log10, asd::log10) +SIMD_DEFAULT_UNARY(log1p, asd::log1p) +SIMD_DEFAULT_UNARY(rint, asd::rint) +SIMD_DEFAULT_UNARY(sinh, asd::sinh) +SIMD_DEFAULT_UNARY(sqrt, asd::sqrt) +SIMD_DEFAULT_UNARY(rsqrt, asd::rsqrt) +SIMD_DEFAULT_UNARY(recip, asd::recip) +SIMD_DEFAULT_UNARY(tan, asd::tan) +SIMD_DEFAULT_UNARY(tanh, asd::tanh) + +template +Simd operator-(Simd v) { + return -v.value; +} + +template +Simd operator~(Simd v) { + return ~v.value; +} + +template +Simd isnan(Simd v) { + return asd::convert(v.value != v.value); +} + +// No simd_boolN in accelerate, use int8_t instead +template +Simd operator!(Simd v) { + return asd::convert(!v.value); +} + +#define SIMD_DEFAULT_BINARY(OP) \ + template \ + Simd operator OP(Simd x, U y) { \ + return asd::convert::scalar_t>(x.value OP y); \ + } \ + template \ + Simd operator OP(T1 x, Simd y) { \ + return asd::convert::scalar_t>(x OP y.value); \ + } \ + template \ + Simd operator OP(Simd x, Simd y) { \ + return asd::convert::scalar_t>(x.value OP y.value); \ + } + +SIMD_DEFAULT_BINARY(+) +SIMD_DEFAULT_BINARY(-) +SIMD_DEFAULT_BINARY(/) +SIMD_DEFAULT_BINARY(*) +SIMD_DEFAULT_BINARY(<<) +SIMD_DEFAULT_BINARY(>>) +SIMD_DEFAULT_BINARY(|) +SIMD_DEFAULT_BINARY(^) +SIMD_DEFAULT_BINARY(&) +SIMD_DEFAULT_BINARY(&&) +SIMD_DEFAULT_BINARY(||) + +#define SIMD_DEFAULT_COMPARISONS(OP) \ + template \ + Simd operator OP(Simd a, U b) { \ + return asd::convert(a.value OP b); \ + } \ + template \ + Simd operator OP(T a, Simd b) { \ + return asd::convert(a OP b.value); \ + } \ + template \ + Simd operator OP(Simd a, Simd b) { \ + return asd::convert(a.value OP b.value); \ + } + +SIMD_DEFAULT_COMPARISONS(>) +SIMD_DEFAULT_COMPARISONS(<) +SIMD_DEFAULT_COMPARISONS(>=) +SIMD_DEFAULT_COMPARISONS(<=) +SIMD_DEFAULT_COMPARISONS(==) +SIMD_DEFAULT_COMPARISONS(!=) + +template +Simd clz(Simd x) { + auto a = *(uint32x4_t*)(&x); + auto b = *((uint32x4_t*)(&x) + 1); + a = vclzq_u32(a); + b = vclzq_u32(b); + return asd::make_uint8(a, b); +} + +template +Simd atan2(Simd a, Simd b) { + return asd::atan2(a.value, b.value); +} + +template +Simd maximum(Simd a, Simd b) { + auto out = Simd(asd::max(a.value, b.value)); + if constexpr (!std::is_integral_v) { + out = select(isnan(b), b, select(isnan(a), a, out)); + } + return out; +} + +template +Simd minimum(Simd a, Simd b) { + auto out = Simd(asd::min(a.value, b.value)); + if constexpr (!std::is_integral_v) { + out = select(isnan(b), b, select(isnan(a), a, out)); + } + return out; +} + +template +Simd remainder(Simd a, Simd b) { + Simd r; + if constexpr (!std::is_integral_v) { + r = asd::remainder(a.value, b.value); + } else { + r = a - b * (a / b); + } + if constexpr (std::is_signed_v) { + auto mask = r != 0 && (r < 0 != b < 0); + r = select(mask, r + b, r); + } + return r; +} + +template +Simd select(Simd mask, Simd x, Simd y) { + static_assert(std::is_same_v); + if constexpr (sizeof(T1) == 1) { + return asd::bitselect(y.value, x.value, asd::convert(mask.value)); + } else if constexpr (sizeof(T1) == 2) { + return asd::bitselect(y.value, x.value, asd::convert(mask.value)); + } else if constexpr (sizeof(T1) == 4) { + return asd::bitselect(y.value, x.value, asd::convert(mask.value)); + } else { + return asd::bitselect(y.value, x.value, asd::convert(mask.value)); + } +} + +template +Simd pow(Simd base, Simd exp) { + if constexpr (!std::is_integral_v) { + return asd::pow(base.value, exp.value); + } else { + Simd res = 1; + // Raising an integer to a negative power is undefined + if (any(exp < 0)) { + return 0; + } + while (any(exp > 0)) { + res = select((exp & 1) != 0, res * base, res); + base = select(exp > 0, base * base, base); + exp = exp >> 1; + } + return res; + } +} + +template +Simd clamp(Simd v, Simd min, Simd max) { + return asd::clamp(v.value, min.value, max.value); +} + +template +Simd fma(Simd x, Simd y, U z) { + return asd::muladd(x.value, y.value, Simd(z).value); +} + +// Reductions + +template +bool all(Simd x) { + return asd::all(x.value); +} +template +bool any(Simd x) { + return asd::any(x.value); +} +template +T sum(Simd x) { + return asd::reduce_add(x.value); +} +template +T max(Simd x) { + return asd::reduce_max(x.value); +} +template +T min(Simd x) { + return asd::reduce_min(x.value); +} + +template +T prod(Simd x) { + auto ptr = (T*)&x; + auto lhs = load(ptr); + auto rhs = load(ptr + N / 2); + return prod(lhs * rhs); +} + +} // namespace mlx::core::simd + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#include "mlx/backend/cpu/simd/accelerate_fp16_simd.h" +#endif diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/simd/base_simd.h b/Source/Cxxmlx/include/mlx/backend/cpu/simd/base_simd.h new file mode 100644 index 00000000..775f5dfd --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/simd/base_simd.h @@ -0,0 +1,319 @@ +#pragma once + +// Required for using M_LN2 in MSVC. +#define _USE_MATH_DEFINES + +#include +#include +#include +#include +#include + +#ifdef _MSC_VER +#include // For _BitScanReverse +#endif + +namespace mlx::core::simd { +template +struct Simd; + +template +static constexpr int max_size = 1; + +template +struct Simd { + static constexpr int size = 1; + T value; + Simd() {} + template + Simd(Simd v) : value(v.value) {} + template + Simd(U v) : value(v) {} + + T operator[](int) const { + return value; + } + + T& operator[](int) { + return value; + } +}; + +template +Simd load(const T* x) { + return *(Simd*)x; +} + +template +void store(T* dst, Simd x) { + // Maintain invariant that bool is either 0 or 1 as + // simd comparison ops set all bits in the result to 1 + if constexpr (std::is_same_v && N > 1) { + x = x & 1; + } + *(Simd*)dst = x; +} + +template +constexpr bool is_complex = false; + +template +constexpr bool is_complex().real())>> = + true; + +template +Simd rint(Simd in) { + if constexpr (is_complex) { + return Simd{ + T{std::rint(in.value.real()), std::rint(in.value.imag())}}; + } else { + return Simd{std::rint(in.value)}; + } +} + +template +Simd rsqrt(Simd in) { + return T(1.0) / sqrt(in); +} + +template +Simd recip(Simd in) { + return T(1.0) / in; +} + +#define DEFAULT_UNARY(name, op) \ + template \ + Simd name(Simd in) { \ + return op(in.value); \ + } + +DEFAULT_UNARY(operator-, std::negate{}) +DEFAULT_UNARY(operator!, std::logical_not{}) +DEFAULT_UNARY(abs, std::abs) +DEFAULT_UNARY(acos, std::acos) +DEFAULT_UNARY(acosh, std::acosh) +DEFAULT_UNARY(asin, std::asin) +DEFAULT_UNARY(asinh, std::asinh) +DEFAULT_UNARY(atan, std::atan) +DEFAULT_UNARY(atanh, std::atanh) +DEFAULT_UNARY(ceil, std::ceil) +DEFAULT_UNARY(conj, std::conj) +DEFAULT_UNARY(cosh, std::cosh) +DEFAULT_UNARY(expm1, std::expm1) +DEFAULT_UNARY(floor, std::floor) +DEFAULT_UNARY(log, std::log) +DEFAULT_UNARY(log10, std::log10) +DEFAULT_UNARY(sinh, std::sinh) +DEFAULT_UNARY(sqrt, std::sqrt) +DEFAULT_UNARY(tan, std::tan) +DEFAULT_UNARY(tanh, std::tanh) + +template +Simd log1p(Simd in) { + if constexpr (is_complex) { + auto x = in.value.real(); + auto y = in.value.imag(); + auto zabs = std::abs(in.value); + auto theta = std::atan2(y, x + 1); + if (zabs < 0.5) { + auto r = x * (2 + x) + y * y; + if (r == 0) { // handle underflow + return Simd{T{x, theta}}; + } + return Simd{T{((decltype(x))(0.5)) * std::log1p(r), theta}}; + } else { + auto z0 = std::hypot(x + 1, y); + return Simd{T{std::log(z0), theta}}; + } + } else { + return Simd{std::log1p(in.value)}; + } +} + +template +Simd log2(Simd in) { + if constexpr (is_complex) { + auto out = std::log(in.value); + auto scale = decltype(out.real())(M_LN2); + return Simd{T{out.real() / scale, out.imag() / scale}}; + } else { + return Simd{std::log2(in.value)}; + } +} + +template +Simd operator~(Simd in) { + return ~in.value; +} + +template +auto real(Simd in) -> Simd { + return std::real(in.value); +} +template +auto imag(Simd in) -> Simd { + return std::imag(in.value); +} +template +Simd isnan(Simd in) { + return std::isnan(in.value); +} + +#define DEFAULT_BINARY(OP) \ + template \ + auto operator OP(Simd a, Simd b) \ + ->Simd { \ + return a.value OP b.value; \ + } \ + template \ + auto operator OP(T1 a, Simd b)->Simd { \ + return a OP b.value; \ + } \ + template \ + auto operator OP(Simd a, T2 b)->Simd { \ + return a.value OP b; \ + } + +DEFAULT_BINARY(+) +DEFAULT_BINARY(-) +DEFAULT_BINARY(*) +DEFAULT_BINARY(/) +DEFAULT_BINARY(<<) +DEFAULT_BINARY(>>) +DEFAULT_BINARY(|) +DEFAULT_BINARY(^) +DEFAULT_BINARY(&) +DEFAULT_BINARY(&&) +DEFAULT_BINARY(||) + +template +Simd clz(Simd x_) { +#ifdef _MSC_VER + // MSVC doesn't have __builtin_clz, use _BitScanReverse instead + unsigned long index; + if (_BitScanReverse(&index, static_cast(x_.value))) { + return static_cast(31 - index); + } + return static_cast(32); // All zeros case +#else + return __builtin_clz(x_.value); +#endif +} + +template +Simd remainder(Simd a_, Simd b_) { + T a = a_.value; + T b = b_.value; + T r; + if constexpr (std::is_integral_v) { + r = a % b; + } else { + r = std::remainder(a, b); + } + if constexpr (std::is_signed_v) { + if (r != 0 && (r < 0 != b < 0)) { + r += b; + } + } + return r; +} + +template +Simd maximum(Simd a_, Simd b_) { + T a = a_.value; + T b = b_.value; + if constexpr (!std::is_integral_v) { + if (std::isnan(a)) { + return a; + } + } + return (a > b) ? a : b; +} + +template +Simd minimum(Simd a_, Simd b_) { + T a = a_.value; + T b = b_.value; + if constexpr (!std::is_integral_v) { + if (std::isnan(a)) { + return a; + } + } + return (a < b) ? a : b; +} + +template +Simd pow(Simd a, Simd b) { + T base = a.value; + T exp = b.value; + if constexpr (!std::is_integral_v) { + return std::pow(base, exp); + } else { + T res = 1; + while (exp) { + if (exp & 1) { + res *= base; + } + exp >>= 1; + base *= base; + } + return res; + } +} + +template +Simd atan2(Simd a, Simd b) { + return std::atan2(a.value, b.value); +} + +#define DEFAULT_COMPARISONS(OP) \ + template \ + Simd operator OP(Simd a, Simd b) { \ + return a.value OP b.value; \ + } \ + template \ + Simd operator OP(T1 a, Simd b) { \ + return a OP b.value; \ + } \ + template \ + Simd operator OP(Simd a, T2 b) { \ + return a.value OP b; \ + } + +DEFAULT_COMPARISONS(>) +DEFAULT_COMPARISONS(<) +DEFAULT_COMPARISONS(>=) +DEFAULT_COMPARISONS(<=) +DEFAULT_COMPARISONS(==) +DEFAULT_COMPARISONS(!=) + +template +Simd select(Simd mask, Simd x, Simd y) { + return mask.value ? x.value : y.value; +} + +template +Simd clamp(Simd v, Simd min, Simd max) { + return std::clamp(v.value, min.value, max.value); +} + +template +Simd fma(Simd x, Simd y, U z) { + return std::fma(x.value, y.value, Simd(z).value); +} + +// Reductions +#define DEFAULT_REDUCTION(name, type) \ + template \ + type name(Simd x) { \ + return x.value; \ + } + +DEFAULT_REDUCTION(max, T) +DEFAULT_REDUCTION(min, T) +DEFAULT_REDUCTION(sum, T) +DEFAULT_REDUCTION(prod, T) +DEFAULT_REDUCTION(any, bool) +DEFAULT_REDUCTION(all, bool) + +} // namespace mlx::core::simd diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/simd/math.h b/Source/Cxxmlx/include/mlx/backend/cpu/simd/math.h new file mode 100644 index 00000000..f9fc8317 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/simd/math.h @@ -0,0 +1,193 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/cpu/simd/type.h" + +namespace mlx::core::simd { + +constexpr float inf = std::numeric_limits::infinity(); + +/** + * Compute exp(x) in an optimizer friendly way as follows: + * + * First change the problem to computing 2**y where y = x / ln(2). + * + * Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part + * `ipart` and y2 is fractional part. For the integer part we perform bit + * shifting and for the fractional part we use a polynomial approximation. + * + * The algorithm and constants of the polynomial taken from + * https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them + * from Cephes math library. + * + * Note: The implementation below is a general fast exp. There could be faster + * implementations for numbers strictly < 0. + */ +template +Simd exp(Simd in) { + if constexpr (is_complex) { + return Simd{std::exp(in.value)}; + } else { + Simd x_init = in; + auto x = x_init * 1.442695f; // multiply with log_2(e) + Simd ipart, fpart; + ipart = floor(x + 0.5); + fpart = x - ipart; + + x = 1.535336188319500e-4f; + x = fma(x, fpart, 1.339887440266574e-3f); + x = fma(x, fpart, 9.618437357674640e-3f); + x = fma(x, fpart, 5.550332471162809e-2f); + x = fma(x, fpart, 2.402264791363012e-1f); + x = fma(x, fpart, 6.931472028550421e-1f); + x = fma(x, fpart, 1.000000000000000f); + + // generate 2**ipart in the floating point representation using integer + // bitshifting + Simd epart = (Simd(ipart) + 127) << 23; + + // Deal with NaN and Inf + auto result = select(isnan(x_init), x_init, (*(Simd*)&epart) * x); + result = select(x_init > 88.0f, Simd(inf), result); + result = select(x_init < -88.0f, Simd(0), result); + return Simd(result); + } +} + +/* Implementation from: + * https://github.com/JishinMaster/simd_utils/blob/3c1433a86fb38edcc9b02039f3c9a65b16640976/neon_mathfun.h#L357 + * which originally came from the Cephes math library. + */ +template +Simd sincos(Simd in) { + auto sign_mask_sin = in < 0; + in = abs(in); + Simd x = in; + + // scale by 4/Pi + auto y = x * 1.27323954473516f; + + // store the integer part of y in mm0 + Simd emm2 = y; + + // j=(j+1) & (~1) (see the cephes sources) + emm2 = emm2 + 1; + emm2 = emm2 & ~1; + + y = emm2; + + // Get the polynom selection mask. There is one polynom for 0 <= x <= Pi/4 + // and another one for Pi/4(-0.78515625f), x); + x = fma(y, Simd(-2.4187564849853515625e-4f), x); + x = fma(y, Simd(-3.77489497744594108e-8f), x); + + sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != 0); + auto sign_mask_cos = ((emm2 - 2) & 4) != 0; + + // Evaluate the first polynom (0 <= x <= Pi/4) in y1, + // and the second polynom (Pi/4 <= x <= 0) in y2 + auto z = x * x; + + auto y1 = + fma(z, Simd(2.443315711809948e-5f), -1.388731625493765e-3f); + auto y2 = fma(z, Simd(-1.9515295891e-4f), 8.3321608736e-3f); + y1 = fma(y1, z, 4.166664568298827e-2f); + y2 = fma(y2, z, -1.6666654611e-1f); + y1 = y1 * z; + y2 = y2 * z; + y1 = y1 * z; + y2 = fma(x, y2, x); + y1 = fma(z, Simd(-0.5f), y1); + y1 = y1 + 1.0f; + + if constexpr (Sine) { + auto ys = select(poly_mask, y1, y2); + return select(sign_mask_sin, -ys, ys); + } else { + auto yc = select(poly_mask, y2, y1); + return select(sign_mask_cos, yc, -yc); + } +} + +template +Simd sin(Simd x) { + if constexpr (is_complex) { + return std::sin(x.value); + } else { + return sincos(x); + } +} + +template +Simd cos(Simd x) { + if constexpr (is_complex) { + return std::cos(x.value); + } else { + return sincos(x); + } +} + +template +Simd erf(Simd x) { + // https://github.com/pytorch/pytorch/blob/abf28982a8cb43342e7669d859de9543fd804cc9/aten/src/ATen/cpu/vec/vec256/vec256_float.h#L175 + Simd v = x; + auto t = recip(fma(Simd(0.3275911f), abs(v), 1.0f)); + auto r = fma(Simd(1.061405429f), t, -1.453152027f); + r = fma(r, t, 1.421413741f); + r = fma(r, t, -0.284496736f); + r = fma(r, t, 0.254829592f); + auto e = -exp(-v * v); + auto result = Simd(fma(e * t, r, 1.0f)); + return select(x > 0, result, -result); +} + +template +Simd erfinv(Simd a_) { + Simd a = a_; + auto t = fma(a, 0.0f - a, 1.0f); + t = log(t); + auto lhs = [](auto t) { + Simd p; + p = 3.03697567e-10f; // 0x1.4deb44p-32 + p = fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26 + p = fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20 + p = fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16 + p = fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12 + p = fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9 + p = fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8 + p = fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2 + return fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1 + }; + auto rhs = [](auto t) { + Simd p; + p = 5.43877832e-9f; // 0x1.75c000p-28 + p = fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23 + p = fma(p, t, 1.22774793e-6f); // 0x1.499232p-20 + p = fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24 + p = fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15 + p = fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13 + p = fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9 + p = fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7 + p = fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3 + return fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1 + }; + auto thresh = 6.125f; + // Compute both branches and select if N > 1 + if constexpr (N == 1) { + if ((abs(t) > thresh).value) { // maximum ulp error = 2.35793 + return a * lhs(t); + } else { // maximum ulp error = 2.35002 + return a * rhs(t); + } + } else { + return a * select(abs(t) > thresh, lhs(t), rhs(t)); + } +} + +} // namespace mlx::core::simd diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h b/Source/Cxxmlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h new file mode 100644 index 00000000..5d32042c --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h @@ -0,0 +1,212 @@ +#pragma once + +#include + +#include "mlx/backend/cpu/simd/base_simd.h" + +namespace mlx::core::simd { + +constexpr int N = 8; + +template <> +struct Simd { + static constexpr int size = N; + using scalar_t = float16_t; + + Simd() {} + + template + Simd(U v) : value(vdupq_n_f16(v)){}; + + Simd(float16x8_t v) : value(v){}; + + Simd(Simd other) { + auto f32x4_a = *(float32x4_t*)(&other); + auto f32x4_b = *((float32x4_t*)(&other) + 1); + value = vcvt_high_f16_f32(vcvt_f16_f32(f32x4_a), f32x4_b); + }; + + Simd(Simd other) { + value = vcvtq_f16_u16(*(uint16x8_t*)(&other.value)); + }; + + operator Simd() { + auto v = vcvtq_s16_f16(value); + return load((int16_t*)&v); + }; + + operator Simd() { + float32x4x2_t v; + v.val[0] = vcvt_f32_f16(*(float16x4_t*)(&value)); + v.val[1] = vcvt_high_f32_f16(value); + return load((float*)&v); + } + float16_t operator[](int idx) const { + return reinterpret_cast(&value)[idx]; + } + + float16_t& operator[](int idx) { + return reinterpret_cast(&value)[idx]; + } + + float16x8_t value; +}; + +#define DEFINE_NEON_UNARY_OP(name, op) \ + inline Simd name(Simd a) { \ + return Simd{op(a.value)}; \ + } + +DEFINE_NEON_UNARY_OP(abs, vabsq_f16) +DEFINE_NEON_UNARY_OP(ceil, vrndpq_f16) +DEFINE_NEON_UNARY_OP(floor, vrndmq_f16) +DEFINE_NEON_UNARY_OP(sqrt, vsqrtq_f16) +DEFINE_NEON_UNARY_OP(rsqrt, vrsqrteq_f16) +DEFINE_NEON_UNARY_OP(recip, vrecpeq_f16) +DEFINE_NEON_UNARY_OP(rint, vrndnq_f16) + +#define DEFINE_NEON_BINARY_OP(name, op) \ + inline Simd name(Simd a, Simd b) { \ + return op(a.value, b.value); \ + } \ + template \ + Simd name(Simd a, T b) { \ + return op(a.value, Simd(b).value); \ + } \ + template \ + Simd name(T a, Simd b) { \ + return op(Simd(a).value, b.value); \ + } + +inline Simd operator!(Simd v) { + auto out = vceqzq_f16(v.value); + return Simd(*(uint16_t*)&out); +} + +inline Simd operator-(Simd v) { + return vnegq_f16(v.value); +} + +DEFINE_NEON_BINARY_OP(maximum, vmaxq_f16) +DEFINE_NEON_BINARY_OP(minimum, vminq_f16) +DEFINE_NEON_BINARY_OP(operator+, vaddq_f16) +DEFINE_NEON_BINARY_OP(operator-, vsubq_f16) +DEFINE_NEON_BINARY_OP(operator*, vmulq_f16) +DEFINE_NEON_BINARY_OP(operator/, vdivq_f16) + +#define DEFINE_NEON_COMPARISON(Op, op) \ + template \ + Simd operator Op(Simd a, T b) { \ + auto out = op(a.value, Simd(b).value); \ + return Simd(*(uint16_t*)(&out)); \ + } \ + template \ + Simd operator Op(T a, Simd b) { \ + auto out = op(Simd(a).value, b.value); \ + return Simd(*(uint16_t*)(&out)); \ + } \ + inline Simd operator Op( \ + Simd a, Simd b) { \ + auto out = op(a.value, b.value); \ + return Simd(*(uint16_t*)(&out)); \ + } + +DEFINE_NEON_COMPARISON(==, vceqq_f16) +DEFINE_NEON_COMPARISON(>=, vcgeq_f16) +DEFINE_NEON_COMPARISON(<=, vcleq_f16) +DEFINE_NEON_COMPARISON(>, vcgtq_f16) +DEFINE_NEON_COMPARISON(<, vcltq_f16) + +template +Simd operator!=(Simd a, T b) { + return !(a == b); +} +template +Simd operator!=(T a, Simd b) { + return !(a == b); +} +inline Simd operator!=(Simd a, Simd b) { + return !(a == b); +} + +inline Simd operator||( + Simd a, + Simd b) { + return Simd((a != 0) || (b != 0)); +} +template +Simd operator||(Simd a, T b) { + return Simd((a != 0) || (b != 0)); +} +template +Simd operator||(T a, Simd b) { + return Simd((a != 0) || (b != 0)); +} +inline Simd operator&&( + Simd a, + Simd b) { + return Simd((a != 0) && (b != 0)); +} +template +Simd operator&&(Simd a, T b) { + return Simd((a != 0) && (b != 0)); +} +template +Simd operator&&(T a, Simd b) { + return Simd((a != 0) && (b != 0)); +} + +template <> +inline Simd isnan(Simd v) { + return v != v; +} + +template <> +inline Simd +clamp(Simd v, Simd min, Simd max) { + return minimum(maximum(v, min), max); +} + +template +Simd fma(Simd x, Simd y, T z) { + return vfmaq_f16(x.value, y.value, Simd(z).value); +} + +template +Simd +select(Simd mask, Simd x, Simd y) { + return vbslq_f16(Simd(mask).value, x.value, y.value); +} + +// Reductions +inline float16_t max(Simd x) { + float16x4_t y; + y = vpmax_f16(vget_low_f16(x.value), vget_high_f16(x.value)); + y = vpmax_f16(y, y); + y = vpmax_f16(y, y); + return vget_lane_f16(y, 0); +} +inline float16_t min(Simd x) { + float16x4_t y; + y = vpmin_f16(vget_low_f16(x.value), vget_high_f16(x.value)); + y = vpmin_f16(y, y); + y = vpmin_f16(y, y); + return vget_lane_f16(y, 0); +} +inline float16_t sum(Simd x) { + float16x4_t y; + y = vpadd_f16(vget_low_f16(x.value), vget_high_f16(x.value)); + y = vpadd_f16(y, y); + y = vpadd_f16(y, y); + return vget_lane_f16(y, 0); +} +inline float16_t prod(Simd x) { + auto hx = vmul_f16(vget_low_f16(x.value), vget_high_f16(x.value)); + auto out = hx[0]; + hx[0] *= hx[1]; + hx[0] *= hx[2]; + hx[0] *= hx[3]; + return hx[0]; +} + +} // namespace mlx::core::simd diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/simd/simd.h b/Source/Cxxmlx/include/mlx/backend/cpu/simd/simd.h new file mode 100644 index 00000000..8700f24c --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/simd/simd.h @@ -0,0 +1,4 @@ +#pragma once + +#include "mlx/backend/cpu/simd/math.h" +#include "mlx/backend/cpu/simd/type.h" diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/simd/type.h b/Source/Cxxmlx/include/mlx/backend/cpu/simd/type.h new file mode 100644 index 00000000..59b6ecca --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/simd/type.h @@ -0,0 +1,11 @@ +#pragma once + +#include "mlx/backend/cpu/simd/base_simd.h" + +#ifdef MLX_USE_ACCELERATE +#if defined(__x86_64__) +// the accelerate_simd implementation require neon -- use base implementation +#else +#include "mlx/backend/cpu/simd/accelerate_simd.h" +#endif +#endif diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/slicing.h b/Source/Cxxmlx/include/mlx/backend/cpu/slicing.h new file mode 100644 index 00000000..eda37320 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/slicing.h @@ -0,0 +1,21 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +std::tuple prepare_slice( + const array& in, + const Shape& start_indices, + const Shape& strides); + +void shared_buffer_slice( + const array& in, + const Strides& out_strides, + size_t data_offset, + size_t data_size, + array& out); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/ternary.h b/Source/Cxxmlx/include/mlx/backend/cpu/ternary.h new file mode 100644 index 00000000..a27a7f2a --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/ternary.h @@ -0,0 +1,154 @@ +// Copyright © 2023 Apple Inc. + +#pragma once +#include "mlx/array.h" +#include "mlx/backend/common/ternary.h" +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/encoder.h" + +namespace mlx::core { + +template +void ternary_op_dims( + const T1* a, + const T2* b, + const T3* c, + U* out, + Op op, + const Shape& shape, + const Strides& a_strides, + const Strides& b_strides, + const Strides& c_strides, + const Strides& out_strides, + int axis) { + auto stride_a = a_strides[axis]; + auto stride_b = b_strides[axis]; + auto stride_c = c_strides[axis]; + auto stride_out = out_strides[axis]; + auto N = shape[axis]; + + for (int i = 0; i < N; i++) { + if constexpr (D > 1) { + ternary_op_dims( + a, + b, + c, + out, + op, + shape, + a_strides, + b_strides, + c_strides, + out_strides, + axis + 1); + } else { + *out = op(*a, *b, *c); + } + a += stride_a; + b += stride_b; + c += stride_c; + out += stride_out; + } +} + +template +void ternary_op_dispatch_dims( + const T1* a_ptr, + const T2* b_ptr, + const T3* c_ptr, + U* out_ptr, + Op op, + size_t size, + Shape& shape, + std::vector& strides) { + const auto& a_strides = strides[0]; + const auto& b_strides = strides[1]; + const auto& c_strides = strides[2]; + const auto& out_strides = strides[3]; + int ndim = shape.size(); + switch (ndim) { + case 1: + ternary_op_dims( + a_ptr, + b_ptr, + c_ptr, + out_ptr, + op, + shape, + a_strides, + b_strides, + c_strides, + out_strides, + 0); + return; + case 2: + ternary_op_dims( + a_ptr, + b_ptr, + c_ptr, + out_ptr, + op, + shape, + a_strides, + b_strides, + c_strides, + out_strides, + 0); + return; + } + + ContiguousIterator a_it(shape, a_strides, ndim - 2); + ContiguousIterator b_it(shape, b_strides, ndim - 2); + ContiguousIterator c_it(shape, c_strides, ndim - 2); + auto stride = out_strides[ndim - 3]; + for (size_t elem = 0; elem < size; elem += stride) { + ternary_op_dims( + a_ptr + a_it.loc, + b_ptr + b_it.loc, + c_ptr + c_it.loc, + out_ptr + elem, + op, + shape, + a_strides, + b_strides, + c_strides, + out_strides, + ndim - 2); + a_it.step(); + b_it.step(); + c_it.step(); + } +} + +template +void ternary_op( + const array& a, + const array& b, + const array& c, + array& out, + Op op, + TernaryOpType topt) { + const T1* a_ptr = a.data(); + const T2* b_ptr = b.data(); + const T3* c_ptr = c.data(); + U* out_ptr = out.data(); + + if (topt == TernaryOpType::ScalarScalarScalar) { + *out_ptr = op(*a_ptr, *b_ptr, *c_ptr); + } else if (topt == TernaryOpType::VectorVectorVector) { + for (size_t i = 0; i < out.size(); ++i) { + *out_ptr = op(*a_ptr, *b_ptr, *c_ptr); + a_ptr++; + b_ptr++; + c_ptr++; + out_ptr++; + } + } else { + auto [shape, strides] = collapse_contiguous_dims( + a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()}); + ternary_op_dispatch_dims( + a_ptr, b_ptr, c_ptr, out_ptr, op, out.size(), shape, strides); + } +} + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/threefry.h b/Source/Cxxmlx/include/mlx/backend/cpu/threefry.h new file mode 100644 index 00000000..0fc485fc --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/threefry.h @@ -0,0 +1,21 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include +#include + +namespace mlx::core::random { + +/** Applies the Threefry 2x32 hash function. + * This code is based on the Jax counter-based and splittable PRNG + * https://github.com/google/jax/blob/main/docs/jep/263-prng.md + * + * Original Threefry reference: + * http://www.thesalmons.org/john/random123/papers/random123sc11.pdf + */ +std::pair threefry2x32_hash( + const std::pair& key, + std::pair count); + +} // namespace mlx::core::random diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/unary.h b/Source/Cxxmlx/include/mlx/backend/cpu/unary.h new file mode 100644 index 00000000..4fab6a75 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/unary.h @@ -0,0 +1,281 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include "mlx/backend/common/unary.h" +#include "mlx/backend/cpu/encoder.h" +#include "mlx/backend/cpu/simd/simd.h" +#include "mlx/utils.h" + +namespace mlx::core { + +template +void unary_op(const T* a, U* out, size_t shape, size_t stride) { + for (size_t i = 0; i < shape; i += 1) { + out[i] = Op{}(*a); + a += stride; + } +} + +template +void unary_op(const array& a, array& out, Op) { + const T* src = a.data(); + U* dst = out.data(); + auto ndim = a.ndim(); + if (a.flags().contiguous) { + auto size = a.data_size(); + constexpr int N = std::min(simd::max_size, simd::max_size); + while (size >= N) { + simd::store(dst, simd::Simd(Op{}(simd::load(src)))); + size -= N; + src += N; + dst += N; + } + while (size > 0) { + *dst = Op{}(*src); + size--; + dst++; + src++; + } + } else { + size_t shape = ndim > 0 ? a.shape().back() : 1; + size_t stride = ndim > 0 ? a.strides().back() : 1; + if (ndim <= 1) { + unary_op(src, dst, shape, stride); + return; + } + auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1); + for (size_t elem = 0; elem < a.size(); elem += shape) { + unary_op(src + it.loc, dst + elem, shape, stride); + it.step(); + } + } +} + +template +void unary(const array& a, array& out, Op op, Stream stream) { + set_unary_output_data(a, out); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + out = array::unsafe_weak_copy(out), + op = op]() mutable { + switch (out.dtype()) { + case bool_: + unary_op(a, out, op); + break; + case uint8: + unary_op(a, out, op); + break; + case uint16: + unary_op(a, out, op); + break; + case uint32: + unary_op(a, out, op); + break; + case uint64: + unary_op(a, out, op); + break; + case int8: + unary_op(a, out, op); + break; + case int16: + unary_op(a, out, op); + break; + case int32: + unary_op(a, out, op); + break; + case int64: + unary_op(a, out, op); + break; + case float16: + unary_op(a, out, op); + break; + case float32: + unary_op(a, out, op); + break; + case float64: + unary_op(a, out, op); + break; + case bfloat16: + unary_op(a, out, op); + break; + case complex64: + unary_op(a, out, op); + break; + } + }); +} + +template +void unary_real_fp(const array& a, array& out, Op op, Stream stream) { + set_unary_output_data(a, out); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + out = array::unsafe_weak_copy(out), + op = op]() mutable { + switch (out.dtype()) { + case bfloat16: + unary_op(a, out, op); + break; + case float16: + unary_op(a, out, op); + break; + case float32: + unary_op(a, out, op); + break; + case float64: + unary_op(a, out, op); + break; + default: + std::ostringstream err; + err << "[unary_real] Does not support " << out.dtype(); + throw std::runtime_error(err.str()); + } + }); +} +template +void unary_fp(const array& a, array& out, Op op, Stream stream) { + set_unary_output_data(a, out); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + out = array::unsafe_weak_copy(out), + op = op]() mutable { + switch (out.dtype()) { + case bfloat16: + unary_op(a, out, op); + break; + case float16: + unary_op(a, out, op); + break; + case float32: + unary_op(a, out, op); + break; + case float64: + unary_op(a, out, op); + break; + case complex64: + unary_op(a, out, op); + break; + default: + std::ostringstream err; + err << "[unary_fp] Does not support " << out.dtype(); + throw std::runtime_error(err.str()); + } + }); +} + +template +void unary_signed(const array& a, array& out, Op op, Stream stream) { + set_unary_output_data(a, out); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + out = array::unsafe_weak_copy(out), + op = op]() mutable { + switch (out.dtype()) { + case int8: + unary_op(a, out, op); + break; + case int16: + unary_op(a, out, op); + break; + case int32: + unary_op(a, out, op); + break; + case int64: + unary_op(a, out, op); + break; + case float16: + unary_op(a, out, op); + break; + case float32: + unary_op(a, out, op); + break; + case float64: + unary_op(a, out, op); + break; + case bfloat16: + unary_op(a, out, op); + break; + case complex64: + unary_op(a, out, op); + break; + default: + throw std::runtime_error("[Abs] Called on unsigned type"); + } + }); +} + +template +void unary_complex(const array& a, array& out, Op op, Stream stream) { + set_unary_output_data(a, out); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + out = array::unsafe_weak_copy(out), + op = op]() mutable { unary_op(a, out, op); }); +} + +template +void unary_complex_to_float(const array& a, array& out, Op op, Stream stream) { + set_unary_output_data(a, out); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(out); + encoder.dispatch( + [a = array::unsafe_weak_copy(a), + out = array::unsafe_weak_copy(out), + op = op]() mutable { unary_op(a, out, op); }); +} + +template +void unary_int(const array& a, array& out, Op op, Stream stream) { + set_unary_output_data(a, out); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + out = array::unsafe_weak_copy(out), + op = op]() mutable { + switch (out.dtype()) { + case uint8: + unary_op(a, out, op); + break; + case uint16: + unary_op(a, out, op); + break; + case uint32: + unary_op(a, out, op); + break; + case uint64: + unary_op(a, out, op); + break; + case int8: + unary_op(a, out, op); + break; + case int16: + unary_op(a, out, op); + break; + case int32: + unary_op(a, out, op); + break; + case int64: + unary_op(a, out, op); + break; + default: + std::ostringstream err; + err << "[unary_int] Does not support " << out.dtype(); + throw std::runtime_error(err.str()); + } + }); +} + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cpu/unary_ops.h b/Source/Cxxmlx/include/mlx/backend/cpu/unary_ops.h new file mode 100644 index 00000000..f441e88b --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cpu/unary_ops.h @@ -0,0 +1,175 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include +#include +#include + +#include "mlx/backend/cpu/simd/simd.h" + +namespace mlx::core::detail { + +using namespace mlx::core::simd; + +#define SINGLE() \ + template \ + T operator()(T x) { \ + return (*this)(Simd(x)).value; \ + } + +#define DEFAULT_OP(Op, op) \ + struct Op { \ + template \ + Simd operator()(Simd x) { \ + return simd::op(x); \ + } \ + SINGLE() \ + }; + +DEFAULT_OP(Abs, abs) +DEFAULT_OP(ArcCos, acos) +DEFAULT_OP(ArcCosh, acosh) +DEFAULT_OP(ArcSin, asin) +DEFAULT_OP(ArcSinh, asinh) +DEFAULT_OP(ArcTan, atan) +DEFAULT_OP(ArcTanh, atanh) +DEFAULT_OP(BitwiseInvert, operator~) +DEFAULT_OP(Ceil, ceil) +DEFAULT_OP(Conjugate, conj) +DEFAULT_OP(Cos, cos) +DEFAULT_OP(Cosh, cosh) +DEFAULT_OP(Erf, erf) +DEFAULT_OP(ErfInv, erfinv) +DEFAULT_OP(Exp, exp) +DEFAULT_OP(Expm1, expm1) +DEFAULT_OP(Floor, floor); +DEFAULT_OP(Log, log); +DEFAULT_OP(Log2, log2); +DEFAULT_OP(Log10, log10); +DEFAULT_OP(Log1p, log1p); +DEFAULT_OP(LogicalNot, operator!) +DEFAULT_OP(Negative, operator-) +DEFAULT_OP(Round, rint); +DEFAULT_OP(Sin, sin) +DEFAULT_OP(Sinh, sinh) +DEFAULT_OP(Sqrt, sqrt) +DEFAULT_OP(Rsqrt, rsqrt) +DEFAULT_OP(Tan, tan) +DEFAULT_OP(Tanh, tanh) + +struct Imag { + template + Simd operator()(Simd x) { + return simd::imag(x); + } + SINGLE() +}; + +struct Real { + template + Simd operator()(Simd x) { + return simd::real(x); + } + SINGLE() +}; + +struct Sigmoid { + template + Simd operator()(Simd x) { + auto y = 1.0f / (1.0f + simd::exp(simd::abs(x))); + return simd::select(x < Simd{0}, y, Simd{1} - y); + } + SINGLE() +}; + +struct Sign { + template + Simd operator()(Simd x) { + auto z = Simd{0}; + auto o = Simd{1}; + auto m = Simd{-1}; + if constexpr (std::is_unsigned_v) { + return simd::select(x == z, z, o); + } else if constexpr (std::is_same_v) { + return simd::select(x == z, x, Simd(x / simd::abs(x))); + } else { + return simd::select(x < z, m, simd::select(x > z, o, z)); + } + } + SINGLE() +}; + +struct Square { + template + Simd operator()(Simd x) { + return x * x; + } + SINGLE() +}; + +template +Simd fp32_from_bits(Simd x) { + return *(Simd*)(&x); +} +template +Simd fp32_to_bits(Simd x) { + return *(Simd*)(&x); +} + +struct ToFP8 { + template + Simd operator()(Simd f) { + uint32_t fp8_max = 543 << 21; + auto denorm_mask = Simd(141 << 23); + Simd f_bits; + Simd f32 = f; + f_bits = fp32_to_bits(f32); + Simd result = 0u; + auto sign = f_bits & 0x80000000; + f_bits = f_bits ^ sign; + + auto f_bits_low = + fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + auto result_low = Simd(f_bits_low - denorm_mask); + + auto mant_odd = Simd((f_bits >> 20) & 1); + auto f_bits_high = f_bits + (((uint32_t)(7 - 127) << 23) + 0x7FFFF); + f_bits_high = f_bits_high + Simd(mant_odd); + + auto result_high = Simd(f_bits_high >> 20); + result = select(f_bits < (121 << 23), result_low, result_high); + + auto result_sat = Simd(0x7E); + result = select(f_bits >= fp8_max, result_sat, result); + return result | Simd(sign >> 24); + } + + template + uint8_t operator()(T x) { + return (*this)(Simd(x)).value; + } +}; + +struct FromFP8 { + template + Simd operator()(Simd x) { + auto v = Simd(x & 127) << 7; + Simd out; + if constexpr (simd::max_size >= N) { + auto converted = *(Simd*)(&v); + out = converted * 256.0; + } else { + for (int i = 0; i < N; ++i) { + auto converted = *(float16_t*)(&v[i]); + out[i] = converted * 256.0; + } + } + auto sign = Simd(x & 128); + return select(sign, -out, out); + } + float operator()(uint8_t x) { + return (*this)(Simd(x)).value; + } +}; +} // namespace mlx::core::detail diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/allocator.h b/Source/Cxxmlx/include/mlx/backend/cuda/allocator.h new file mode 100644 index 00000000..af76ad90 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/allocator.h @@ -0,0 +1,94 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/allocator.h" +#include "mlx/backend/common/buffer_cache.h" +#include "mlx/backend/cuda/cuda_utils.h" + +#include +#include +#include +#include + +namespace mlx::core::cu { + +class CommandEncoder; + +using allocator::Buffer; + +// Stores cuda-managed unified memory. +struct CudaBuffer { + void* data; + size_t size; + int device; // -1 for managed +}; + +class SmallSizePool { + private: + union Block { + Block* next; + CudaBuffer buf; + }; + + Block* buffer_{nullptr}; + void* data_{nullptr}; + Block* next_free_{nullptr}; + + public: + SmallSizePool(); + ~SmallSizePool(); + + SmallSizePool(const SmallSizePool&) = delete; + SmallSizePool& operator=(const SmallSizePool&) = delete; + + CudaBuffer* malloc(); + void free(CudaBuffer* buf); + bool in_pool(CudaBuffer* buf); +}; + +class CudaAllocator : public allocator::Allocator { + public: + Buffer malloc(size_t size) override; + Buffer malloc_async(size_t size, int device, cudaStream_t stream); + void free(Buffer buffer) override; + size_t size(Buffer buffer) const override; + + // Replace the memory of |buf| with unified memory (managed memory or pinned + // host memory), and copy the data over. Pass |stream| to copy asynchronously. + void move_to_unified_memory(CudaBuffer& buf, cudaStream_t stream = nullptr); + + size_t get_active_memory() const; + size_t get_peak_memory() const; + void reset_peak_memory(); + size_t get_memory_limit(); + size_t set_memory_limit(size_t limit); + size_t get_cache_memory() const; + size_t set_cache_limit(size_t limit); + void clear_cache(); + + private: + void free_cuda_buffer(CudaBuffer* buf); + void free_async(CudaBuffer& buf, cudaStream_t stream = nullptr); + + CudaAllocator(); + friend CudaAllocator& allocator(); + + std::mutex mutex_; + size_t memory_limit_; + size_t free_limit_; + size_t total_memory_; + size_t max_pool_size_; + BufferCache buffer_cache_; + size_t active_memory_{0}; + size_t peak_memory_{0}; + std::vector free_streams_; + std::vector mem_pools_; + SmallSizePool scalar_pool_; +}; + +CudaAllocator& allocator(); + +Buffer malloc_async(size_t size, CommandEncoder& encoder); + +} // namespace mlx::core::cu diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/conv/conv.h b/Source/Cxxmlx/include/mlx/backend/cuda/conv/conv.h new file mode 100644 index 00000000..62dc9343 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/conv/conv.h @@ -0,0 +1,126 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/gpu/copy.h" + +namespace mlx::core { + +template +struct ConvParams { + int N; // Batch size + int C; // In channels + int O; // Out channels + int strides[NDIM]; + int padding[NDIM]; + int kernel_dilation[NDIM]; + int input_dilation[NDIM]; + int groups; + bool flip; + int in_spatial_dims[NDIM]; + int wt_spatial_dims[NDIM]; + int out_spatial_dims[NDIM]; + int64_t in_strides[NDIM + 2]; + + ConvParams( + const array& in, + const array& wt, + const array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + int groups, + bool flip) + : N(in.shape(0)), + C(in.shape(-1)), + O(wt.shape(0)), + groups(groups), + flip(flip) { + std::copy_n(strides.begin(), NDIM, this->strides); + std::copy_n(padding.begin(), NDIM, this->padding); + std::copy_n(kernel_dilation.begin(), NDIM, this->kernel_dilation); + std::copy_n(input_dilation.begin(), NDIM, this->input_dilation); + std::copy_n(in.shape().begin() + 1, NDIM, this->in_spatial_dims); + std::copy_n(wt.shape().begin() + 1, NDIM, this->wt_spatial_dims); + std::copy_n(out.shape().begin() + 1, NDIM, this->out_spatial_dims); + std::copy_n(in.strides().begin(), NDIM + 2, this->in_strides); + } +}; + +void gemm_grouped_conv( + cu::CommandEncoder& encoder, + const array& in, + const array& wt, + array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + int groups, + bool flip, + Stream s); + +void gemm_conv( + cu::CommandEncoder& encoder, + const array& in, + const array& wt, + array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + bool flip, + Stream s); + +inline void gemm_conv( + cu::CommandEncoder& encoder, + array in, + array wt, + array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + int groups, + bool flip, + Stream s) { + if (!in.flags().row_contiguous) { + in = contiguous_copy_gpu(in, s); + encoder.add_temporary(in); + } + if (!wt.flags().row_contiguous) { + wt = contiguous_copy_gpu(wt, s); + encoder.add_temporary(wt); + } + + if (groups == 1) { + gemm_conv( + encoder, + in, + wt, + out, + strides, + padding, + kernel_dilation, + input_dilation, + flip, + s); + } else { + gemm_grouped_conv( + encoder, + in, + wt, + out, + strides, + padding, + kernel_dilation, + input_dilation, + groups, + flip, + s); + } +} + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/cublas_utils.h b/Source/Cxxmlx/include/mlx/backend/cuda/cublas_utils.h new file mode 100644 index 00000000..56702b12 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/cublas_utils.h @@ -0,0 +1,95 @@ +// Copyright © 2025 Apple Inc. +#pragma once + +#include +#include "mlx/array.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/dtype_utils.h" + +namespace mlx::core { +namespace cublas_utils { + +// Get the shared cublas preference for a device +cublasLtMatmulPreference_t get_preference(cu::Device& device); + +cublasLtMatrixLayout_t create_matrix_layout( + cudaDataType_t type, + uint64_t rows, + uint64_t cols, + bool transposed, + int64_t ld, + int32_t batch_count, + int64_t batch_stride); + +inline cudaDataType_t dtype_to_cublas_type(Dtype dtype, std::string_view tag) { + switch (dtype) { + case float16: + return CUDA_R_16F; + case bfloat16: + return CUDA_R_16BF; + case float32: + return CUDA_R_32F; + case float64: + return CUDA_R_64F; + case complex64: + return CUDA_C_32F; + default: + throw std::runtime_error( + fmt::format( + "Unsupported dtype in {}: {}.", tag, dtype_to_string(dtype))); + } +} + +} // namespace cublas_utils + +class CublasMatmulBase { + public: + virtual ~CublasMatmulBase(); + + void set_bias(cu::CommandEncoder& encoder, const array& bias); + + protected: + CublasMatmulBase() = default; + + // Common member variables shared by all matmul types + uint64_t M_; + uint64_t N_; + cudaDataType_t scale_type_; + cublasLtMatmulPreference_t pref_{nullptr}; + cublasLtHandle_t handle_{nullptr}; + cublasLtMatmulDesc_t matmul_desc_{nullptr}; + cublasLtMatrixLayout_t a_desc_{nullptr}; + cublasLtMatrixLayout_t b_desc_{nullptr}; + cublasLtMatrixLayout_t c_desc_{nullptr}; + cublasLtMatrixLayout_t out_desc_{nullptr}; + cublasLtMatmulHeuristicResult_t heuristic_; + + void init_base( + cu::Device& device, + cudaDataType_t scale_type, + cublasComputeType_t compute_type, + cudaDataType_t data_type, + cudaDataType_t output_type, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride); + + void execute_matmul( + cu::CommandEncoder& encoder, + void* out, + const void* a, + const void* b, + const void* c, + const void* alpha_ptr, + const void* beta_ptr); +}; + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/cuda.h b/Source/Cxxmlx/include/mlx/backend/cuda/cuda.h new file mode 100644 index 00000000..410bae4c --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/cuda.h @@ -0,0 +1,21 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +#include "mlx/api.h" + +namespace mlx::core::cu { + +/* Check if the CUDA backend is available. */ +MLX_API bool is_available(); + +/* Get information about a CUDA device. */ +MLX_API const + std::unordered_map>& + device_info(int device_index = 0); + +} // namespace mlx::core::cu diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/cuda_utils.h b/Source/Cxxmlx/include/mlx/backend/cuda/cuda_utils.h new file mode 100644 index 00000000..4c60fec2 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/cuda_utils.h @@ -0,0 +1,90 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include +#include + +namespace mlx::core { + +// Throw exception if the cuda API does not succeed. +void check_cublas_error(const char* name, cublasStatus_t err); +void check_cuda_error(const char* name, cudaError_t err); +void check_cuda_error(const char* name, CUresult err); +void check_cudnn_error(const char* name, cudnnStatus_t err); + +// The macro version that prints the command that failed. +#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd)) +#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) +#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd)) + +// Base class for RAII managed CUDA resources. +template +class CudaHandle { + public: + CudaHandle(Handle handle = nullptr) : handle_(handle) {} + + CudaHandle(CudaHandle&& other) : handle_(other.handle_) { + assert(this != &other); + other.handle_ = nullptr; + } + + ~CudaHandle() { + // Skip if there was an error to avoid throwing in the destructors + if (cudaPeekAtLastError() != cudaSuccess) { + return; + } + reset(); + } + + CudaHandle(const CudaHandle&) = delete; + CudaHandle& operator=(const CudaHandle&) = delete; + + CudaHandle& operator=(CudaHandle&& other) { + assert(this != &other); + reset(); + std::swap(handle_, other.handle_); + return *this; + } + + void reset() { + if (handle_ != nullptr) { + CHECK_CUDA_ERROR(Destroy(handle_)); + handle_ = nullptr; + } + } + + operator Handle() const { + return handle_; + } + + protected: + Handle handle_; +}; + +namespace cu { +class Device; +}; // namespace cu + +// Wrappers of CUDA resources. +class CudaGraph : public CudaHandle { + public: + using CudaHandle::CudaHandle; + explicit CudaGraph(cu::Device& device); + void end_capture(cudaStream_t stream); +}; + +class CudaGraphExec : public CudaHandle { + public: + void instantiate(cudaGraph_t graph); +}; + +class CudaStream : public CudaHandle { + public: + using CudaHandle::CudaHandle; + explicit CudaStream(cu::Device& device); +}; + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/cudnn_utils.h b/Source/Cxxmlx/include/mlx/backend/cuda/cudnn_utils.h new file mode 100644 index 00000000..5e8235f1 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/cudnn_utils.h @@ -0,0 +1,187 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/device/config.h" +#include "mlx/backend/cuda/utils.h" +#include "mlx/dtype_utils.h" + +#include +#include + +namespace mlx::core { + +namespace cu { +class CommandEncoder; +} + +namespace fe = cudnn_frontend; + +#define CHECK_CUDNN_FE_ERROR(cmd) \ + do { \ + auto error = cmd; \ + if (!error.is_good()) { \ + throw std::runtime_error( \ + fmt::format("{} failed: {}.", #cmd, error.get_message())); \ + } \ + } while (0) + +// Return pointer alignment of |x|'s data. +inline uint8_t get_alignment(const array& x) { + uint8_t alignment = 1; + uintptr_t address = reinterpret_cast(gpu_ptr(x)); + for (; alignment < 32; alignment *= 2) { + if (address % (alignment * 2)) { + return alignment; + } + } + return alignment; +} + +// Convert the type of elements in |vec| to |T|. +template +inline std::vector convert_vector(const Vec& vec) { + return std::vector(vec.begin(), vec.end()); +} + +// Map dtype to cudnn data type. +inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) { + switch (dtype) { + case int8: + return fe::DataType_t::INT8; + case int32: + return fe::DataType_t::INT32; + case uint8: + return fe::DataType_t::UINT8; + case float16: + return fe::DataType_t::HALF; + case bfloat16: + return fe::DataType_t::BFLOAT16; + case float32: + return fe::DataType_t::FLOAT; + case float64: + return fe::DataType_t::DOUBLE; + default: + throw std::runtime_error( + fmt::format( + "Unsupported dtype in cuDNN: {}.", dtype_to_string(dtype))); + } +} + +// Return an array that can be used as map key for |vec| with size <= MAX_NDIM. +// +// There are 2 differences from the const_param util from kernel_utils.cuh: +// 1. The rest of array is filled with 0. +// 2. This util can be used in .cpp files. +template +inline std::array vector_key(const Vec& vec) { + if (vec.size() > NDIM) { + throw std::runtime_error( + fmt::format("ndim can not be larger than {}.", NDIM)); + } + std::array result = {}; + std::copy_n(vec.begin(), vec.size(), result.begin()); + return result; +} + +// Extends cuDNN graph with helpers. +class DnnGraph : public fe::graph::Graph { + public: + DnnGraph(cudnnHandle_t handle, Dtype io_dtype, Dtype compute_dtype = float32) + : handle_(handle) { + set_io_data_type(dtype_to_cudnn_type(io_dtype)); + set_intermediate_data_type(dtype_to_cudnn_type(compute_dtype)); + set_compute_data_type(dtype_to_cudnn_type(compute_dtype)); + } + + // Create a cuDNN tensor description from MLX array |x|. + auto& tensor( + std::shared_ptr& attrs, + int64_t uid, + const array& x) { + set_tensor_attrs(attrs, uid, x); + return attrs; + } + auto tensor(const char* name, int64_t uid, const array& x) { + auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name)); + tensor(attrs, uid, x); + return attrs; + } + + // Create a cuDNN tensor description from MLX array |x|, and transpose it from + // NHWC layout to NCHW. + auto& tensor_nchw( + std::shared_ptr& attrs, + int64_t uid, + const array& x) { + set_tensor_attrs_nchw(attrs, uid, x); + return attrs; + } + auto tensor_nchw(const char* name, int64_t uid, const array& x) { + auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name)); + tensor_nchw(attrs, uid, x); + return attrs; + } + + // Create a 4D cuDNN tensor from 1D array, with |axis| being contiguous dim. + auto tensor_4d(const char* name, int64_t uid, const array& x, int axis) { + assert(x.ndim() == 1); + auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name)); + std::vector shape(4, 1); + std::vector strides(4, 1); + shape.at(axis) = x.size(); + if (axis > 0) { + strides.at(axis - 1) = x.size(); + } + set_tensor_attrs(attrs, uid, x, shape, strides); + return attrs; + } + + // Create a cuDNN tensor for scalar. + auto scalar(const char* name, int64_t uid, Dtype dtype) { + return Graph::tensor( + fe::graph::Tensor_attributes() + .set_name(name) + .set_uid(uid) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(dtype_to_cudnn_type(dtype))); + } + + // Call this before setting notes. + fe::error_t prepare(); + // Call this after setting notes. + fe::error_t build(); + + // Add cuDNN graph to CUDA graph, using native CUDA graph API. + fe::error_t encode_graph( + cu::CommandEncoder& encoder, + std::unordered_map variant_pack); + // Add cuDNN graph to CUDA graph, using stream capture. + fe::error_t encode_capturing( + cu::CommandEncoder& encoder, + std::unordered_map variant_pack); + + private: + void* prepare_workspace(cu::CommandEncoder& encoder); + + void set_tensor_attrs( + std::shared_ptr& tensor, + int64_t uid, + const array& x, + const std::vector& shape, + const std::vector& strides); + void set_tensor_attrs( + std::shared_ptr& tensor, + int64_t uid, + const array& x); + void set_tensor_attrs_nchw( + std::shared_ptr& tensor, + int64_t uid, + const array& x); + + cudnnHandle_t handle_; +}; + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/device.h b/Source/Cxxmlx/include/mlx/backend/cuda/device.h new file mode 100644 index 00000000..39b886a5 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/device.h @@ -0,0 +1,210 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/lru_cache.h" +#include "mlx/backend/cuda/worker.h" +#include "mlx/stream.h" + +#include +#include +#include + +#include + +namespace mlx::core::cu { + +class CommandEncoder { + public: + struct CaptureContext { + CaptureContext(CommandEncoder& enc); + ~CaptureContext(); + CudaGraph graph; + CommandEncoder& enc; + bool discard{false}; + }; + struct ConcurrentContext { + ConcurrentContext(CommandEncoder& enc); + ~ConcurrentContext(); + CommandEncoder& enc; + }; + + explicit CommandEncoder(Device& d); + + CommandEncoder(const CommandEncoder&) = delete; + CommandEncoder& operator=(const CommandEncoder&) = delete; + + CaptureContext capture_context() { + return CaptureContext{*this}; + } + ConcurrentContext concurrent_context() { + return ConcurrentContext{*this}; + } + + void set_input_array(const array& arr); + void set_output_array(const array& arr); + + template + void + add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... params) { + add_kernel_node_ex(func, grid_dim, block_dim, {}, 0, params...); + } + + template + void add_kernel_node_ex( + F* func, + dim3 grid_dim, + dim3 block_dim, + dim3 cluster_dim, + uint32_t smem_bytes, + Params&&... params) { + constexpr size_t num = sizeof...(Params); + void* ptrs[num]; + size_t i = 0; + ([&](auto&& p) { ptrs[i++] = static_cast(&p); }( + std::forward(params)), + ...); + add_kernel_node_raw( + reinterpret_cast(func), + grid_dim, + block_dim, + cluster_dim, + smem_bytes, + ptrs); + } + + void add_kernel_node_raw( + void* func, + dim3 grid_dim, + dim3 block_dim, + dim3 cluster_dim, + uint32_t smem_bytes, + void** params); + + void add_kernel_node_raw( + CUfunction func, + dim3 grid_dim, + dim3 block_dim, + dim3 cluster_dim, + uint32_t smem_bytes, + void** params); + + void add_graph_node(cudaGraph_t child); + + void add_temporary(const array& arr) { + temporaries_.push_back(arr.data_shared_ptr()); + } + + void add_completed_handler(std::function task); + bool needs_commit(); + void commit(); + + Device& device() { + return device_; + } + + CudaStream& stream() { + return stream_; + } + + // Wait until kernels and completion handlers are finished + void synchronize(); + + private: + cudaGraphNode_t add_kernel_node_raw(const cudaKernelNodeParams& params); + CUgraphNode add_kernel_node_raw(const CUDA_KERNEL_NODE_PARAMS& params); + + struct GraphNode { + cudaGraphNode_t node; + // K = kernel + // E = empty + // () = subgraph (with metadata) + // Symbols ':', '-' are reserved as separators + std::string node_type; + std::string id; + }; + + void insert_graph_dependencies(GraphNode node); + void insert_graph_dependencies(std::vector nodes); + + Device& device_; + CudaStream stream_; + CudaGraph graph_; + Worker worker_; + int node_count_{0}; + bool in_concurrent_{false}; + std::vector from_nodes_; + std::vector to_nodes_; + std::string graph_nodes_key_; + std::string graph_deps_key_; + std::vector concurrent_nodes_; + std::vector> temporaries_; + LRUCache graph_cache_; + std::vector active_deps_; + std::vector active_outputs_; + std::unordered_map node_map_; + size_t bytes_in_graph_{0}; + bool is_graph_updatable_{true}; + int max_ops_per_graph_; + int max_mb_per_graph_; +}; + +class Device { + public: + explicit Device(int device); + ~Device(); + + Device(Device&&) = default; + Device(const Device&) = delete; + Device& operator=(const Device&) = delete; + + // Make this device the current cuda device, this method is thread-safe. + void make_current(); + + CommandEncoder& get_command_encoder(Stream s); + cublasLtHandle_t get_cublaslt_handle(); + cudnnHandle_t get_cudnn_handle(); + + int cuda_device() const { + return device_; + } + int compute_capability_major() const { + return compute_capability_major_; + } + int compute_capability_minor() const { + return compute_capability_minor_; + } + bool concurrent_managed_access() const { + return concurrent_managed_access_ == 1; + } + bool host_native_atomic() const { + return host_native_atomic_ == 1; + } + bool managed_memory() const { + return managed_memory_ == 1; + } + bool memory_pools() const { + return memory_pools_ == 1; + } + + private: + int device_; + int compute_capability_major_; + int compute_capability_minor_; + int concurrent_managed_access_; + int host_native_atomic_; + int managed_memory_; + int memory_pools_; + std::string device_name_; + cublasLtHandle_t cublaslt_handle_{nullptr}; + cudnnHandle_t cudnn_handle_{nullptr}; + std::unordered_map encoders_; +}; + +Device& device(int cuda_device); +Device& device(mlx::core::Device d); +CommandEncoder& get_command_encoder(Stream s); + +} // namespace mlx::core::cu diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/device/config.h b/Source/Cxxmlx/include/mlx/backend/cuda/device/config.h new file mode 100644 index 00000000..5a340290 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/device/config.h @@ -0,0 +1,12 @@ +// Copyright © 2025 Apple Inc. + +// This file is used by both CUDA kernel code and host-only C++ code. + +#pragma once + +// The maximum dimensions of shape/strides passed as kernel parameters. +#define MAX_NDIM 10 + +// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in +// warpSize variable exists, using it would prevent compile-time optimizations. +#define WARP_SIZE 32 diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/event.h b/Source/Cxxmlx/include/mlx/backend/cuda/event.h new file mode 100644 index 00000000..53afeb01 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/event.h @@ -0,0 +1,79 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/allocator.h" +#include "mlx/backend/cuda/utils.h" +#include "mlx/stream.h" + +#include + +#include +#include + +namespace mlx::core::cu { + +class Device; + +// RAII-managed move-only wrapper of cudaEvent_t. +struct CudaEventHandle : public CudaHandle { + CudaEventHandle(Device& d, int flags); + Device& device; + int flags; +}; + +// Wrapper of native cuda event. It can synchronize between GPU streams, or wait +// on GPU stream in CPU stream, but can not wait on CPU stream. +class CudaEvent { + public: + CudaEvent(Device& d, int flags); + ~CudaEvent(); + + CudaEvent(CudaEvent&&) = default; + CudaEvent& operator=(CudaEvent&&) = default; + + CudaEvent(const CudaEvent&) = delete; + CudaEvent& operator=(const CudaEvent&) = delete; + + void wait(); + void wait(cudaStream_t stream); + void record(cudaStream_t stream); + + // Return whether the recorded kernels have completed. Note that this method + // returns true if record() has not been called. + bool completed() const; + + // Internal: make sure event pool is initialized. + static void init_pool(); + + private: + CudaEventHandle event_; +}; + +// Event that can synchronize between CPU and GPU. It is much slower than +// CudaEvent so the latter should always be preferred when possible. +class AtomicEvent { + public: + AtomicEvent(Device& d); + + void wait(uint32_t value); + void wait(cudaStream_t stream, uint32_t value); + void wait(Stream s, uint32_t value); + void signal(uint32_t value); + void signal(cudaStream_t stream, uint32_t value); + void signal(Stream s, uint32_t value); + bool is_signaled(uint32_t value) const; + uint32_t value() const; + + private: + const CudaStream& signal_stream(); + + uint32_t* ptr() const { + return static_cast(buf_.get()); + } + + bool coherent_; + std::shared_ptr buf_; +}; + +} // namespace mlx::core::cu diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/gemms/cublas_gemm.h b/Source/Cxxmlx/include/mlx/backend/cuda/gemms/cublas_gemm.h new file mode 100644 index 00000000..1fad45ed --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/gemms/cublas_gemm.h @@ -0,0 +1,114 @@ +// Copyright © 2025 Apple Inc. +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/cublas_utils.h" +#include "mlx/backend/cuda/device.h" + +#include + +namespace mlx::core { + +class CublasGemm : public CublasMatmulBase { + public: + CublasGemm( + cu::Device& device, + Dtype dtype, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride); + + CublasGemm( + cu::Device& device, + Dtype dtype, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int64_t ldc, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride, + int64_t c_batch_stride); + + // The output's descriptor is inferred from inputs by default, use this method + // for unusual output. + void set_out( + Dtype dtype, + bool transposed, + uint64_t rows, + uint64_t cols, + int64_t ld, + int32_t batch_count, + int64_t batch_stride); + + void run( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides, + float alpha = 1.0f); + + void run( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const array& c, + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides, + const Strides& c_batch_strides, + float alpha, + float beta); + + private: + void run_batched( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides, + float alpha); + + void run_batched( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const array& c, + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides, + const Strides& c_batch_strides, + float alpha, + float beta); + + void execute( + cu::CommandEncoder& encoder, + void* out, + const void* a, + const void* b, + const void* c, + float alpha = 1, + float beta = 0); +}; + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/gemms/gemv.h b/Source/Cxxmlx/include/mlx/backend/cuda/gemms/gemv.h new file mode 100644 index 00000000..b0dd49d4 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/gemms/gemv.h @@ -0,0 +1,34 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/device.h" + +namespace mlx::core::cu { + +bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed); + +void gemv( + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + uint32_t batch_count, + const mlx::core::Shape& batch_shape, + const mlx::core::Strides& a_batch_strides, + const mlx::core::Strides& b_batch_strides, + CommandEncoder& encoder); + +void gather_mv( + const array& mat, + const array& vec, + const array& mat_indices, + const array& vec_indices, + array& out, + int N, + int K, + CommandEncoder& encoder); + +} // namespace mlx::core::cu diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/gemms/grouped_gemm.h b/Source/Cxxmlx/include/mlx/backend/cuda/gemms/grouped_gemm.h new file mode 100644 index 00000000..844b8f44 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/gemms/grouped_gemm.h @@ -0,0 +1,39 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +namespace mlx::core { + +namespace cu { +class CommandEncoder; +} + +class array; + +void cutlass_grouped_gemm_unaligned( + bool a_transposed, + int lda, + bool b_transposed, + int ldb, + int group_count, + const array& a, + const array& b, + const array& indices, + array& out, + cu::CommandEncoder& encoder); + +void cutlass_segmented_mm( + bool a_transposed, + int lda, + bool b_transposed, + int ldb, + int num_segments, + int M, + int N, + const array& a, + const array& b, + const array& segments, + array& out, + cu::CommandEncoder& encoder); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/jit_module.h b/Source/Cxxmlx/include/mlx/backend/cuda/jit_module.h new file mode 100644 index 00000000..3849d8d0 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/jit_module.h @@ -0,0 +1,120 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/config.h" + +#include +#include +#include +#include + +#include +#include + +namespace mlx::core::cu { + +class Device; + +using KernelBuilderResult = std::tuple< + /* precompiled */ bool, + /* source code */ std::string, + /* kernel names */ std::vector>; +using KernelBuilder = std::function; + +struct KernelArgs { + void** args() { + return args_.data(); + } + + void append(const array& a) { + append(reinterpret_cast(gpu_ptr(a))); + } + + template + void append(T val) { + storage_.emplace_back(val); + append_ptr(&storage_.back()); + } + + template + void append(SmallVector vec) { + storage_.emplace_back(std::move(vec)); + append_ptr(std::get>(storage_.back()).data()); + } + + template + void append(const std::vector& vec) { + append(SmallVector(vec.begin(), vec.end())); + } + + // Make sure the arg is copied to an array with size of NDIM. + template + void append_ndim(SmallVector vec) { + if (vec.size() > NDIM) { + throw std::runtime_error( + fmt::format("ndim can not be larger than {}.", NDIM)); + } + vec.resize(NDIM); + append(std::move(vec)); + } + + void append_ptr(const void* v) { + args_.push_back(const_cast(v)); + } + + private: + std::vector args_; + + // The cuGraphAddKernelNode API requires passing pointers to arguments so + // store temporary values until the node is created. + using Arg = std::variant< + std::monostate, + CUdeviceptr, + bool, + int32_t, + uint32_t, + int64_t, + float, + SmallVector, + SmallVector, + SmallVector>; + std::deque storage_; +}; + +class JitModule { + public: + JitModule( + Device& device, + const std::string& module_name, + const KernelBuilder& builder, + bool cache); + ~JitModule(); + + JitModule(const JitModule&) = delete; + JitModule& operator=(const JitModule&) = delete; + CUfunction get_kernel( + const std::string& kernel_name, + std::function configure_kernel = nullptr); + std::pair get_kernel_and_dims( + const std::string& kernel_name, + std::function configure_kernel = nullptr); + + private: + CUmodule module_{nullptr}; + std::unordered_map> + kernels_; +}; + +std::unordered_map& get_jit_module_cache(); + +JitModule& get_jit_module( + const mlx::core::Device& device, + const std::string& name, + const KernelBuilder& builder, + bool use_disk_cache = true); + +} // namespace mlx::core::cu diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/lru_cache.h b/Source/Cxxmlx/include/mlx/backend/cuda/lru_cache.h new file mode 100644 index 00000000..f23f3ff0 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/lru_cache.h @@ -0,0 +1,190 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/utils.h" + +#include +#include +#include +#include + +#include + +namespace mlx::core { + +template < + typename K, + typename V, + template typename M = std::unordered_map> +class LRUCache { + public: + using value_type = std::pair; + using list_type = std::list; + using iterator = typename list_type::iterator; + using const_iterator = typename list_type::const_iterator; + using map_type = M; + + explicit LRUCache(size_t capacity) : capacity_(capacity) { + if (capacity == 0) { + throw std::runtime_error("LRUCache requires capacity > 0."); + } + } + + // Initialize with capacity read from |env_name|. + LRUCache(const char* env_name, int default_capacity) + : LRUCache(env::get_var(env_name, default_capacity)) { + if (env::get_var("MLX_ENABLE_CACHE_THRASHING_CHECK", 1)) { + env_name_ = env_name; + } + } + + size_t size() const { + return map_.size(); + } + size_t capacity() const { + return capacity_; + } + bool empty() const { + return vlist_.empty(); + } + + void resize(size_t new_capacity) { + capacity_ = new_capacity; + trim(); + } + + iterator begin() { + return vlist_.begin(); + } + const_iterator begin() const { + return vlist_.begin(); + } + iterator end() { + return vlist_.end(); + } + const_iterator end() const { + return vlist_.end(); + } + + void clear() { + map_.clear(); + vlist_.clear(); + } + + iterator find(const K& key) { + auto it = map_.find(key); + if (it == map_.end()) + return end(); + vlist_.splice(vlist_.begin(), vlist_, it->second); + return it->second; + } + + template + std::pair emplace(const K& key, U&& value) { + auto it = map_.find(key); + if (it != map_.end()) { + vlist_.splice(vlist_.begin(), vlist_, it->second); + return {it->second, false}; + } + + if (env_name_ && ++cache_misses_ > 2 * capacity_) { + throw std::runtime_error( + fmt::format( + "Cache thrashing is happening, please set the environment variable " + "{} to a larger value than {} to fix degraded performance.", + env_name_, + capacity_)); + } + + vlist_.emplace_front(key, std::forward(value)); + map_[key] = vlist_.begin(); + + trim(); + + return {vlist_.begin(), true}; + } + + iterator erase(iterator pos) { + map_.erase(pos->first); + return vlist_.erase(pos); + } + + V& operator[](const K& key) { + auto it = find(key); + if (it == end()) { + it = emplace(key, V{}).first; + } + return it->second; + } + + private: + void trim() { + while (map_.size() > capacity_) { + auto last = std::prev(vlist_.end()); + map_.erase(last->first); + vlist_.pop_back(); + } + } + + const char* env_name_{nullptr}; + size_t cache_misses_{0}; + + list_type vlist_; + map_type map_; + size_t capacity_; +}; + +// Turn a POD struct into a container key by doing bytes compare. +// +// Usage: +// BytesKey key; +// key.pod = { ... }; +template +struct BytesKey { + T pod; + static_assert(std::is_standard_layout_v, "T is not POD"); + + BytesKey() { + // Make sure the paddings between members are filled with 0. + memset(&pod, 0, sizeof(T)); + } + + BytesKey(const BytesKey& other) { + memcpy(&pod, &other.pod, sizeof(T)); + } + + BytesKey(BytesKey&& other) { + memcpy(&pod, &other.pod, sizeof(T)); + } + + bool operator==(const BytesKey& other) const { + auto* ptr1 = reinterpret_cast(&pod); + auto* ptr2 = reinterpret_cast(&other.pod); + return memcmp(ptr1, ptr2, sizeof(T)) == 0; + } +}; + +// Compute hash according to the bytes value of T. +template +struct BytesHash { + static_assert(std::is_standard_layout_v, "T is not POD"); + + size_t operator()(const T& pod) const { + auto* ptr = reinterpret_cast(&pod); + uint32_t value = 0x811C9DC5; + for (int i = 0; i < sizeof(T); ++i) { + value ^= ptr[i]; + value *= 0x01000193; + } + return value; + } +}; + +template +using BytesKeyHashMap = std::unordered_map>; + +template +using LRUBytesKeyCache = LRUCache, V, BytesKeyHashMap>; + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h b/Source/Cxxmlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h new file mode 100644 index 00000000..a9095012 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h @@ -0,0 +1,100 @@ +// Copyright © 2025 Apple Inc. +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/cublas_utils.h" +#include "mlx/backend/cuda/device.h" + +#include + +namespace mlx::core { + +class CublasQQMM : public CublasMatmulBase { + public: + CublasQQMM( + cu::Device& device, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride, + Dtype out_dtype, + const std::string& quantization_mode); + + CublasQQMM( + cu::Device& device, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int64_t ldc, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride, + int64_t c_batch_stride, + Dtype out_dtype, + const std::string& quantization_mode); + + void run( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const array& a_scale, + const array& b_scale, + const array& alpha, + const array& beta); + + void run( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const array& a_scale, + const array& b_scale); + + private: + void set_scales_ptrs( + cu::CommandEncoder& encoder, + const void* a_scale, + const void* b_scale); + + void execute( + cu::CommandEncoder& encoder, + void* out, + const void* a, + const void* b, + const void* a_scale, + const void* b_scale, + const void* c, + const void* alpha, + const void* beta); + + void execute( + cu::CommandEncoder& encoder, + void* out, + const void* a, + const void* b, + const void* a_scale, + const void* b_scale, + const void* c, + const float alpha = 1.0f, + const float beta = 0.0f); + + cublasLtMatmulMatrixScale_t a_scale_mode_; + cublasLtMatmulMatrixScale_t b_scale_mode_; + cublasLtMatmulMatrixScale_t c_scale_mode_; + cublasLtMatmulMatrixScale_t out_scale_mode_; +}; + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/quantized/qmm/qmm.h b/Source/Cxxmlx/include/mlx/backend/cuda/quantized/qmm/qmm.h new file mode 100644 index 00000000..efcd8eaf --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/quantized/qmm/qmm.h @@ -0,0 +1,80 @@ +// Copyright © 2026 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/device.h" +#include "mlx/primitives.h" + +#include + +namespace mlx::core { + +bool supports_qmm_sm90( + const array& x, + const array& w, + const array& scales, + const std::optional& biases, + const array& out, + bool transpose, + int bits, + int group_size, + QuantizationMode mode, + cu::Device& device); + +void qmm_sm90( + const array& x, + const array& w, + const array& scales, + const array& biases, + array& out, + int bits, + int group_size, + cu::CommandEncoder& encoder, + Stream s); + +bool supports_fp_qmv( + const array& x, + const array& w, + const array& scales, + const std::optional& biases, + const array& out, + bool transpose, + int bits, + int group_size, + QuantizationMode mode, + cu::Device& device); + +void fp_qmv( + const array& x, + const array& w, + const array& scales, + array& out, + int bits, + int group_size, + cu::CommandEncoder& encoder, + Stream s); + +bool supports_qmv( + const array& x, + const array& w, + const array& scales, + const std::optional& biases, + const array& out, + bool transpose, + int bits, + int group_size, + QuantizationMode mode, + cu::Device& device); + +void qmv( + const array& x, + const array& w, + const array& scales, + const std::optional& biases, + array& out, + int bits, + int group_size, + QuantizationMode mode, + cu::CommandEncoder& encoder); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/quantized/qqmm_impl.h b/Source/Cxxmlx/include/mlx/backend/cuda/quantized/qqmm_impl.h new file mode 100644 index 00000000..ab2b74c1 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/quantized/qqmm_impl.h @@ -0,0 +1,37 @@ +// Copyright © 2025 Apple Inc. +#pragma once + +#include "mlx/backend/cuda/device.h" +#include "mlx/primitives.h" + +#include + +namespace mlx::core { + +struct GemmScalars { + std::optional alpha_device; + std::optional beta_device; + + bool has_values() const { + return alpha_device.has_value(); + } +}; + +void qqmm_impl( + cu::CommandEncoder& encoder, + int M, + int N, + int K, + bool a_transposed, + int64_t lda, + bool b_transposed, + int64_t ldb, + array& out, + const array& a, + const array& b, + const array& a_scale, + const array& b_scale, + QuantizationMode mode, + const GemmScalars& scalars = {}); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/quantized/qqmm_utils.h b/Source/Cxxmlx/include/mlx/backend/cuda/quantized/qqmm_utils.h new file mode 100644 index 00000000..fba9ac9d --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/quantized/qqmm_utils.h @@ -0,0 +1,62 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/device.h" + +namespace mlx::core { + +// Compute padded dimensions for tiled layout +// Tiles are 128 rows × 4 columns, must allocate full tiles +inline std::pair get_padded_scale_dims(int num_rows, int num_cols) { + constexpr int rows_per_tile = 128; + constexpr int cols_per_tile = 4; + + int padded_rows = + ((num_rows + rows_per_tile - 1) / rows_per_tile) * rows_per_tile; + int padded_cols = + ((num_cols + cols_per_tile - 1) / cols_per_tile) * cols_per_tile; + + return {padded_rows, padded_cols}; +} + +void swizzle_scales( + const array& scales, + array& scales_tiled, + cu::CommandEncoder& enc, + const Stream& s); + +inline array pad_and_swizzle_scales( + const array& scale, + cu::CommandEncoder& encoder, + const Stream& s) { + // Compute padded dimensions for full tiles (128 rows × 4 cols) + auto [pad_outer, pad_inner] = + get_padded_scale_dims(scale.shape(-2), scale.shape(-1)); + // cuBLAS requirements for scale factor layout: + // 1. Dimensions must be padded to full tiles (128 rows × 4 cols) + // 2. Out-of-bounds values must be filled with zeros + // 3. Starting addresses must be 16-byte aligned + // https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + // Note: cu::malloc_async already provides 256-byte alignment + array scale_tiled( + cu::malloc_async(pad_outer * pad_inner, encoder), + Shape{pad_outer, pad_inner}, + scale.dtype()); + swizzle_scales(scale, scale_tiled, encoder, s); + + encoder.add_temporary(scale_tiled); + return scale_tiled; +} + +// Compute alpha = tensor_amax_x * tensor_amax_w / (448 * 6)^2 +// Allocate beta zero on device as well +void compute_qqmm_pointers( + array& alpha_out, + array& beta_out, + const array& tensor_amax_x, + const array& tensor_amax_w, + cu::CommandEncoder& enc); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/quantized/quantized.h b/Source/Cxxmlx/include/mlx/backend/cuda/quantized/quantized.h new file mode 100644 index 00000000..f15c0f76 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/quantized/quantized.h @@ -0,0 +1,57 @@ +// Copyright © 2025 Apple Inc. + +#include +#include "mlx/backend/cuda/device.h" + +namespace mlx::core { + +void affine_quantize( + const array& w, + array& wq, + array& scales, + array& biases, + int group_size_, + int bits_, + cu::CommandEncoder& enc, + const Stream& s); + +void affine_dequantize( + const array& wq, + const array& scales, + const array& biases, + array& w, + int group_size_, + int bits_, + cu::CommandEncoder& enc, + const Stream& s); + +void fp_quantize( + const array& w, + array& wq, + array& scales, + int group_size, + int bits, + const std::optional& global_scale, + cu::CommandEncoder& enc, + const Stream& s); + +void fp_dequantize( + const array& wq, + const array& scales, + array& w, + int group_size, + int bits, + const std::optional& global_scale, + cu::CommandEncoder& enc, + const Stream& s); + +void fp_quantize_dequantize( + const array& w, + array& what, + int group_size, + int bits, + const std::optional& global_scale, + cu::CommandEncoder& enc, + const Stream& s); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/quantized/quantized_utils.h b/Source/Cxxmlx/include/mlx/backend/cuda/quantized/quantized_utils.h new file mode 100644 index 00000000..66be7686 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/quantized/quantized_utils.h @@ -0,0 +1,50 @@ +// Copyright © 2026 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/gpu/copy.h" + +namespace mlx::core { +inline array ensure_row_contiguous( + const array& x, + cu::CommandEncoder& enc, + const Stream& s) { + if (!x.flags().row_contiguous) { + array x_copy = contiguous_copy_gpu(x, s); + enc.add_temporary(x_copy); + return x_copy; + } else { + return x; + } +} + +inline array ensure_row_contiguous_matrix( + const array& x, + cu::CommandEncoder& enc, + const Stream& s) { + if (x.ndim() < 2) { + if (x.strides()[0] == 1) { + return x; + } + } else { + auto stride_0 = x.strides()[x.ndim() - 2]; + auto stride_1 = x.strides()[x.ndim() - 1]; + if (stride_0 == x.shape(-1) && stride_1 == 1) { + return x; + } + } + array x_copy = contiguous_copy_gpu(x, s); + enc.add_temporary(x_copy); + return x_copy; +} + +inline array +ensure_contiguous(const array& x, cu::CommandEncoder& enc, const Stream& s) { + if (x.flags().row_contiguous || x.flags().col_contiguous) { + return x; + } + array x_copy = contiguous_copy_gpu(x, s); + enc.add_temporary(x_copy); + return x_copy; +} + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/utils.h b/Source/Cxxmlx/include/mlx/backend/cuda/utils.h new file mode 100644 index 00000000..387e79ad --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/utils.h @@ -0,0 +1,49 @@ +// Copyright © 2025 Apple Inc. + +// This file include utilities that are used by C++ code (i.e. .cpp files). + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/cuda_utils.h" + +namespace mlx::core { + +template +inline uint32_t max_occupancy_block_dim(T kernel) { + int _, block_dim; + if constexpr (std::is_same_v) { + CHECK_CUDA_ERROR( + cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0)); + } else { + CHECK_CUDA_ERROR( + cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel)); + } + return block_dim; +} + +template +inline T* gpu_ptr(array& arr) { + return reinterpret_cast( + static_cast( + static_cast(arr.buffer().ptr())->data) + + arr.offset()); +} + +// For const array, keep constness in pointer unless it is untyped. +template +inline std::conditional_t, void*, const T*> gpu_ptr( + const array& arr) { + return gpu_ptr(const_cast(arr)); +} + +struct Dtype; + +// Convert Dtype to CUDA C++ types. +const char* dtype_to_cuda_type(const Dtype& dtype); + +// Allocate an empty array and add it as temporary. +void* allocate_workspace(cu::CommandEncoder& encoder, size_t workspace_size); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/cuda/worker.h b/Source/Cxxmlx/include/mlx/backend/cuda/worker.h new file mode 100644 index 00000000..8f05e7b9 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/cuda/worker.h @@ -0,0 +1,55 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/event.h" + +#include +#include +#include +#include +#include + +namespace mlx::core::cu { + +// Run tasks in worker thread, synchronized with cuda stream. +class Worker { + public: + explicit Worker(Device& d); + ~Worker(); + + Worker(const Worker&) = delete; + Worker& operator=(const Worker&) = delete; + + // Add a pending |task| that will run when consumed or commited. + void add_task(std::function task); + + // Inform worker thread to run current batches after kernels in |stream| + // finish running. + void commit(cudaStream_t stream); + + private: + static void signal(void*); + + void thread_fn(); + std::mutex mtx_; + std::condition_variable cond_; + + uint64_t committed_batch_{0}; + uint64_t signaled_batch_{0}; + + // Cuda stream and event for signaling kernel completion. + CudaStream signal_stream_; + CudaEvent signal_event_; + + bool stop_{false}; + + // Tasks are put in |pending_tasks_| first, and then moved to + // |worker_tasks_| when end_batch() is called. + using Tasks = std::vector>; + Tasks pending_tasks_; + std::map worker_tasks_; + std::thread worker_; +}; + +} // namespace mlx::core::cu diff --git a/Source/Cxxmlx/include/mlx/backend/gpu/copy.h b/Source/Cxxmlx/include/mlx/backend/gpu/copy.h new file mode 100644 index 00000000..6e6bc797 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/gpu/copy.h @@ -0,0 +1,57 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include "mlx/backend/common/copy.h" +#include "mlx/stream.h" + +#include + +namespace mlx::core { + +// Generic copy inplace +void copy_gpu_inplace( + const array& in, + array& out, + const Shape& data_shape, + const Strides& i_strides, + const Strides& o_strides, + int64_t i_offset, + int64_t o_offset, + CopyType ctype, + const Stream& s, + std::optional dynamic_i_offset = std::nullopt, + std::optional dynamic_o_offset = std::nullopt); + +void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s); +void copy_gpu(const array& src, array& out, CopyType ctype); + +void copy_gpu_inplace( + const array& in, + array& out, + CopyType ctype, + const Stream& s); + +void copy_gpu_inplace( + const array& in, + array& out, + const Strides& i_strides, + int64_t i_offset, + CopyType ctype, + const Stream& s); + +// Fill the output with the scalar val +void fill_gpu(const array& val, array& out, const Stream& s); + +// Return a contiguous array with same shape that copies the data of |arr|. +array contiguous_copy_gpu(const array& arr, const Stream& s); + +// Copy data from |in| and transpose to |out|'s shape. +void reshape_gpu(const array& in, array& out, Stream s); + +// Like the normal ops but safe to call in eval_gpu. +array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s); +array reshape_in_eval(const array& x, Shape shape, Stream s); +array swapaxes_in_eval(const array& x, int axis1, int axis2); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/gpu/device_info.h b/Source/Cxxmlx/include/mlx/backend/gpu/device_info.h new file mode 100644 index 00000000..7adb7f0b --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/gpu/device_info.h @@ -0,0 +1,36 @@ +// Copyright © 2026 Apple Inc. + +#pragma once + +#include +#include +#include + +#include "mlx/api.h" + +namespace mlx::core::gpu { + +MLX_API bool is_available(); + +/** + * Get the number of available GPU devices. + */ +MLX_API int device_count(); + +/** + * Get information about a GPU device. + * + * Returns a map of device properties. Keys vary by backend: + * - device_name (string): Device name + * - architecture (string): Architecture identifier + * - total_memory/memory_size (size_t): Total device memory + * - free_memory (size_t): Available memory (CUDA only) + * - uuid (string): Device UUID (CUDA only) + * - pci_bus_id (string): PCI bus ID (CUDA only) + * - compute_capability_major/minor (size_t): Compute capability (CUDA only) + */ +MLX_API const + std::unordered_map>& + device_info(int device_index = 0); + +} // namespace mlx::core::gpu diff --git a/Source/Cxxmlx/include/mlx/backend/gpu/eval.h b/Source/Cxxmlx/include/mlx/backend/gpu/eval.h new file mode 100644 index 00000000..f646c2ec --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/gpu/eval.h @@ -0,0 +1,18 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/array.h" +#include "mlx/stream.h" + +namespace mlx::core::gpu { + +void new_stream(Stream stream); +void eval(array& arr); +void finalize(Stream s); +void synchronize(Stream s); + +} // namespace mlx::core::gpu diff --git a/Source/Cxxmlx/include/mlx/backend/gpu/slicing.h b/Source/Cxxmlx/include/mlx/backend/gpu/slicing.h new file mode 100644 index 00000000..596f7afa --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/gpu/slicing.h @@ -0,0 +1,36 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +void slice_gpu( + const array& in, + array& out, + const Shape& start_indices, + const Shape& strides, + const Stream& s); + +void concatenate_gpu( + const std::vector& inputs, + array& out, + int axis, + const Stream& s); + +void pad_gpu( + const array& in, + const array& val, + array& out, + const std::vector& axes, + const Shape& low_pad_size, + const Stream& s); + +array compute_dynamic_offset( + const array& indices, + const Strides& strides, + const std::vector& axes, + const Stream& s); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/metal/allocator.h b/Source/Cxxmlx/include/mlx/backend/metal/allocator.h new file mode 100644 index 00000000..5e177b3d --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/allocator.h @@ -0,0 +1,79 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include +#include + +#include "mlx/allocator.h" +#include "mlx/backend/common/buffer_cache.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/resident.h" + +namespace mlx::core::metal { + +using allocator::Buffer; + +class MetalAllocator : public allocator::Allocator { + /** Allocator for Metal GPUs. */ + public: + virtual Buffer malloc(size_t size) override; + virtual void free(Buffer buffer) override; + virtual size_t size(Buffer buffer) const override; + virtual Buffer make_buffer(void* ptr, size_t size) override; + virtual void release(Buffer buffer) override; + + size_t get_active_memory() { + return active_memory_; + }; + size_t get_peak_memory() { + return peak_memory_; + }; + void reset_peak_memory() { + std::unique_lock lk(mutex_); + peak_memory_ = 0; + }; + size_t get_cache_memory() { + return buffer_cache_.cache_size(); + }; + size_t set_cache_limit(size_t limit); + size_t set_memory_limit(size_t limit); + size_t get_memory_limit(); + size_t set_wired_limit(size_t limit); + void clear_cache(); + + private: + MTL::Device* device_; + + // The size of allocations which go on the heap until it is full. This size + // is chosen because it is the actual minimum size of a buffer allocated from + // the heap, a heap can have at most heap.size() / 256 buffers. + static constexpr int small_size_ = 256; + static constexpr int heap_size_ = 1 << 20; + MTL::Heap* heap_; + MetalAllocator(); + ~MetalAllocator(); + friend MetalAllocator& allocator(); + + // Caching allocator + BufferCache buffer_cache_; + + ResidencySet residency_set_; + + // Allocation stats + size_t block_limit_; + size_t gc_limit_; + size_t active_memory_{0}; + size_t peak_memory_{0}; + size_t max_pool_size_; + size_t wired_limit_{0}; + size_t num_resources_{0}; + size_t resource_limit_{0}; + + std::mutex mutex_; +}; + +MetalAllocator& allocator(); + +} // namespace mlx::core::metal diff --git a/Source/Cxxmlx/include/mlx/backend/metal/binary.h b/Source/Cxxmlx/include/mlx/backend/metal/binary.h new file mode 100644 index 00000000..0341a2f8 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/binary.h @@ -0,0 +1,33 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +void binary_op_gpu( + const std::vector& inputs, + std::vector& outputs, + const char* op, + const Stream& s); + +void binary_op_gpu( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s); + +void binary_op_gpu_inplace( + const std::vector& inputs, + std::vector& outputs, + const char* op, + const Stream& s); + +void binary_op_gpu_inplace( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/metal/device.h b/Source/Cxxmlx/include/mlx/backend/metal/device.h new file mode 100644 index 00000000..e6162d7d --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/device.h @@ -0,0 +1,289 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "mlx/array.h" +#include "mlx/device.h" + +namespace mlx::core::metal { + +using MTLFCList = + std::vector>; + +struct DeviceStream; + +struct MLX_API CommandEncoder { + explicit CommandEncoder(DeviceStream& stream); + CommandEncoder(const CommandEncoder&) = delete; + CommandEncoder& operator=(const CommandEncoder&) = delete; + + struct ConcurrentContext { + ConcurrentContext(CommandEncoder& enc) : enc(enc) { + enc.concurrent_ = true; + } + ~ConcurrentContext() { + enc.concurrent_ = false; + enc.prev_outputs_.insert( + enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end()); + enc.concurrent_outputs_.clear(); + } + + private: + CommandEncoder& enc; + }; + + void set_input_array(const array& a, int idx, int64_t offset = 0); + void set_output_array(array& a, int idx, int64_t offset = 0); + void register_output_array(const array& a); + void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims); + void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims); + void maybeInsertBarrier(); + void set_buffer(const MTL::Buffer* buf, int idx, int64_t offset = 0); + + void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) { + enc_->setComputePipelineState(kernel); + } + + void wait_for_fence(MTL::Fence* fence) { + enc_->waitForFence(fence); + } + + void update_fence(MTL::Fence* fence) { + enc_->updateFence(fence); + } + + template >> + void set_vector_bytes(const Vec& vec, size_t nelems, int idx) { + enc_->setBytes(vec.data(), nelems * sizeof(typename Vec::value_type), idx); + } + template >> + void set_vector_bytes(const Vec& vec, int idx) { + return set_vector_bytes(vec, vec.size(), idx); + } + + template + void set_bytes(const T* v, int n, int idx) { + return enc_->setBytes(v, n * sizeof(T), idx); + } + + template + void set_bytes(const T& v, int idx) { + return enc_->setBytes(&v, sizeof(T), idx); + } + + void set_threadgroup_memory_length(size_t length, int idx) { + enc_->setThreadgroupMemoryLength(length, idx); + } + + ConcurrentContext start_concurrent() { + return ConcurrentContext(*this); + } + ~CommandEncoder(); + + // Inputs to all kernels in the encoder including temporaries + std::unordered_set& inputs() { + return all_inputs_; + }; + + // Outputs of all kernels in the encoder including temporaries + std::unordered_set& outputs() { + return all_outputs_; + }; + + void barrier(); + + private: + DeviceStream& stream_; + MTL::ComputeCommandEncoder* enc_; + bool needs_barrier_{false}; + bool concurrent_{false}; + std::unordered_set prev_outputs_; + std::unordered_set next_outputs_; + std::unordered_set concurrent_outputs_; + std::unordered_set all_inputs_; + std::unordered_set all_outputs_; +}; + +struct Fence { + Fence(MTL::Fence* fence) : fence(fence) {} + ~Fence() { + fence->release(); + } + MTL::Fence* fence; +}; + +struct DeviceStream { + DeviceStream(MTL::CommandQueue* queue) : queue(queue) {}; + ~DeviceStream() { + queue->release(); + if (buffer != nullptr) { + buffer->release(); + } + }; + MTL::CommandQueue* queue; + // A map of prior command encoder outputs to their corresponding fence + std::unordered_map> outputs; + // Used to allow thread-safe access to the outputs map + std::mutex fence_mtx; + + // Data updated between command buffers + MTL::CommandBuffer* buffer{nullptr}; + int buffer_ops{0}; + size_t buffer_sizes{0}; + + // The command encoder, fence, and temporaries are updated between command + // encoders + std::unique_ptr encoder{nullptr}; + std::shared_ptr fence; + std::vector temporaries; +}; + +class MLX_API Device { + public: + Device(); + Device(const Device&) = delete; + Device& operator=(const Device&) = delete; + ~Device(); + + MTL::Device* mtl_device() { + return device_; + }; + + const std::string& get_architecture() { + return arch_; + } + + int get_architecture_gen() const { + return arch_gen_; + } + + void new_queue(int index); + + MTL::CommandQueue* get_queue(Stream stream); + + MTL::CommandBuffer* get_command_buffer(int index); + bool command_buffer_needs_commit(int index); + void commit_command_buffer(int index); + CommandEncoder& get_command_encoder(int index); + void end_encoding(int index); + + MTL::Library* get_library( + const std::string& name, + const std::string& path = ""); + + MTL::Library* get_library( + const std::string& name, + const std::function& builder); + + void clear_library(const std::string& name); + + MTL::ComputePipelineState* get_kernel( + const std::string& base_name, + MTL::Library* mtl_lib, + const std::string& hash_name = "", + const MTLFCList& func_consts = {}, + const std::vector& linked_functions = {}); + + MTL::ComputePipelineState* get_kernel( + const std::string& base_name, + const std::string& hash_name = "", + const MTLFCList& func_consts = {}, + const std::vector& linked_functions = {}); + + MTL::ArgumentEncoder* argument_encoder( + const std::vector& arg_descs) const; + + // Record temporary arrays for the given stream index + void add_temporary(array arr, int index); + void add_temporaries(std::vector arrays, int index); + + void set_residency_set(const MTL::ResidencySet* residency_set); + + private: + DeviceStream& get_stream_(int index) { + return stream_map_.find(index)->second; + } + MTL::Library* get_library_cache_(const std::string& name); + + MTL::Library* get_library_(const std::string& name); + MTL::Library* build_library_(const std::string& source_string); + + MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib); + + MTL::Function* get_function_( + const std::string& name, + const std::string& specialized_name, + const MTLFCList& func_consts, + MTL::Library* mtl_lib); + + MTL::LinkedFunctions* get_linked_functions_( + const std::vector& funcs); + + MTL::ComputePipelineState* get_kernel_( + const std::string& name, + const MTL::Function* mtl_function); + + MTL::ComputePipelineState* get_kernel_( + const std::string& name, + const MTL::Function* mtl_function, + const MTL::LinkedFunctions* linked_functions); + + MTL::ComputePipelineState* get_kernel_( + const std::string& base_name, + MTL::Library* mtl_lib, + const std::string& hash_name, + const MTLFCList& func_consts = {}, + const std::vector& linked_functions = {}); + + MTL::Device* device_; + std::unordered_map stream_map_; + + std::shared_mutex kernel_mtx_; + std::shared_mutex library_mtx_; + std::unordered_map library_map_; + MTL::Library* default_library_; + std::unordered_map< + MTL::Library*, + std::unordered_map> + library_kernels_; + const MTL::ResidencySet* residency_set_{nullptr}; + std::string arch_; + int arch_gen_; + int max_ops_per_buffer_; + int max_mb_per_buffer_; +}; + +MLX_API Device& device(mlx::core::Device); + +std::unique_ptr> new_scoped_memory_pool(); + +inline bool is_nax_available() { +#ifdef MLX_METAL_NO_NAX + return false; +#else + auto _check_nax = []() { + bool can_use_nax = false; + if (__builtin_available( + macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { + can_use_nax = true; + } + auto& d = metal::device(mlx::core::Device::gpu); + auto arch = d.get_architecture().back(); + auto gen = d.get_architecture_gen(); + can_use_nax &= gen >= (arch == 'p' ? 18 : 17); + return can_use_nax; + }; + static bool is_nax_available_ = _check_nax(); + return is_nax_available_; +#endif +} + +} // namespace mlx::core::metal diff --git a/Source/Cxxmlx/include/mlx/backend/metal/jit/includes.h b/Source/Cxxmlx/include/mlx/backend/metal/jit/includes.h new file mode 100644 index 00000000..dcaf09a1 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/jit/includes.h @@ -0,0 +1,59 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +namespace mlx::core::metal { + +const char* utils(); +const char* binary_ops(); +const char* unary_ops(); +const char* ternary_ops(); +const char* reduce_utils(); +const char* gather(); +const char* scatter(); +const char* masked_scatter(); + +const char* arange(); +const char* unary(); +const char* binary(); +const char* binary_two(); +const char* copy(); +const char* fft(); +const char* gather_axis(); +const char* gather_front(); +const char* hadamard(); +const char* logsumexp(); +const char* quantized_utils(); +const char* quantized(); +const char* fp_quantized(); +const char* ternary(); +const char* scan(); +const char* scatter_axis(); +const char* softmax(); +const char* sort(); +const char* reduce(); + +const char* gemm(); +const char* steel_gemm_fused(); +const char* steel_gemm_masked(); +const char* steel_gemm_splitk(); +const char* steel_gemm_gather(); +const char* steel_gemm_segmented(); +const char* conv(); +const char* steel_conv(); +const char* steel_conv_3d(); +const char* steel_conv_general(); +const char* gemv_masked(); +const char* steel_attention(); + +const char* gemm_nax(); +const char* steel_gemm_fused_nax(); +const char* steel_gemm_gather_nax(); +const char* steel_gemm_splitk_nax(); + +const char* quantized_nax(); +const char* fp_quantized_nax(); + +const char* steel_attention_nax(); + +} // namespace mlx::core::metal diff --git a/Source/Cxxmlx/include/mlx/backend/metal/jit/indexing.h b/Source/Cxxmlx/include/mlx/backend/metal/jit/indexing.h new file mode 100644 index 00000000..fa141fcc --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/jit/indexing.h @@ -0,0 +1,76 @@ +// Copyright © 2023-2024 Apple Inc. + +constexpr std::string_view gather_kernels = R"( +[[kernel]] void gather{0}_{3}_{6}_{7}( + const device {1}* src [[buffer(0)]], + device {1}* out [[buffer(1)]], + const constant int* src_shape [[buffer(2)]], + const constant int64_t* src_strides [[buffer(3)]], + const constant size_t& src_ndim [[buffer(4)]], + const constant int* slice_sizes [[buffer(5)]], + const constant int* axes [[buffer(6)]], + const constant int* idx_shapes [[buffer(7)]], + const constant int64_t* idx_strides [[buffer(8)]], + const constant bool* idx_contigs [[buffer(9)]], + const constant int& idx_ndim [[buffer(10)]], + {4} + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) {{ + Indices<{2}, {3}> idxs{{ + {{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}}; + + return gather_impl<{1}, {2}, {3}, {6}, {7}>( + src, + out, + src_shape, + src_strides, + src_ndim, + slice_sizes, + axes, + idxs, + index, + grid_dim); +}} +)"; + +constexpr std::string_view scatter_kernels = R"( +[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}_{9}( + const device {1}* updates [[buffer(1)]], + device mlx_atomic<{1}>* out [[buffer(2)]], + const constant int* upd_shape [[buffer(3)]], + const constant int64_t* upd_strides [[buffer(4)]], + const constant size_t& upd_ndim [[buffer(5)]], + const constant size_t& upd_size [[buffer(6)]], + const constant int* out_shape [[buffer(7)]], + const constant int64_t* out_strides [[buffer(8)]], + const constant size_t& out_ndim [[buffer(9)]], + const constant int* axes [[buffer(10)]], + const constant int* idx_shapes [[buffer(11)]], + const constant int64_t* idx_strides [[buffer(12)]], + const constant bool* idx_contigs [[buffer(13)]], + const constant int& idx_ndim [[buffer(14)]], + const constant size_t& idx_size [[buffer(15)]], + {5} + uint2 gid [[thread_position_in_grid]]) {{ + Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}}; + + return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}, {9}>( + updates, + out, + upd_shape, + upd_strides, + upd_ndim, + upd_size, + out_shape, + out_strides, + out_ndim, + axes, + idx_size, + idxs, + gid); +}} +)"; + +constexpr std::string_view masked_assign_kernel = R"( +template [[host_name("{0}")]] [[kernel]] decltype(masked_assign_impl<{1}, {2}>) masked_assign_impl<{1}, {2}>; +)"; diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels.h new file mode 100644 index 00000000..63fccc59 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels.h @@ -0,0 +1,386 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include + +#include "mlx/array.h" +#include "mlx/backend/metal/device.h" + +namespace mlx::core { + +MTL::ComputePipelineState* get_arange_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out); + +MTL::ComputePipelineState* get_unary_kernel( + metal::Device& d, + const std::string& kernel_name, + Dtype in_type, + Dtype out_type, + const char* op); + +MTL::ComputePipelineState* get_binary_kernel( + metal::Device& d, + const std::string& kernel_name, + Dtype in_type, + Dtype out_type, + const char* op); + +MTL::ComputePipelineState* get_binary_two_kernel( + metal::Device& d, + const std::string& kernel_name, + Dtype in_type, + Dtype out_type, + const char* op); + +MTL::ComputePipelineState* get_ternary_kernel( + metal::Device& d, + const std::string& kernel_name, + Dtype type, + const char* op); + +MTL::ComputePipelineState* get_copy_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out); + +MTL::ComputePipelineState* get_dynamic_copy_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out); + +MTL::ComputePipelineState* get_softmax_kernel( + metal::Device& d, + const std::string& kernel_name, + bool precise, + const array& out); + +MTL::ComputePipelineState* get_logsumexp_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out); + +MTL::ComputePipelineState* get_scan_kernel( + metal::Device& d, + const std::string& kernel_name, + bool reverse, + bool inclusive, + const std::string& reduce_type, + const array& in, + const array& out); + +MTL::ComputePipelineState* get_sort_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out, + int bn, + int tn); + +MTL::ComputePipelineState* get_mb_sort_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& idx, + int bn, + int tn); + +MTL::ComputePipelineState* get_reduce_init_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& func_name, + const std::string& op_name, + const Dtype& out_type); + +MTL::ComputePipelineState* get_reduce_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& func_name, + const std::string& op_name, + const Dtype& in_type, + const Dtype& out_type, + const std::string& idx_t, + int ndim = -1, + int bm = -1, + int bn = -1); + +MTL::ComputePipelineState* get_steel_gemm_fused_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn); + +MTL::ComputePipelineState* get_steel_gemm_splitk_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn, + bool mn_aligned, + bool k_aligned); + +MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out, + bool axbpy); + +MTL::ComputePipelineState* get_steel_gemm_masked_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out, + const std::optional& mask_out, + const std::optional& mask_op, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn, + bool mn_aligned, + bool k_aligned); + +MTL::ComputePipelineState* get_steel_gemm_gather_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn, + bool rhs); + +MTL::ComputePipelineState* get_steel_gemm_segmented_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn); + +MTL::ComputePipelineState* get_steel_conv_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out, + int bm, + int bn, + int bk, + int wm, + int wn, + int n_channel_specialization, + bool small_filter); + +MTL::ComputePipelineState* get_steel_conv_3d_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out, + int bm, + int bn, + int bk, + int wm, + int wn, + bool small_filter); + +MTL::ComputePipelineState* get_gemv_masked_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out, + const std::optional& mask_out, + const std::optional& mask_op, + bool transpose_mat, + int bm, + int bn, + int sm, + int sn, + int tm, + int tn, + bool contiguous); + +MTL::ComputePipelineState* get_steel_conv_general_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + int bm, + int bn, + int bk, + int wm, + int wn); + +MTL::ComputePipelineState* get_fft_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const std::string& template_def); + +MTL::ComputePipelineState* get_quantized_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& template_def, + const std::string& mode); + +MTL::ComputePipelineState* get_gather_qmm_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& x, + int group_size, + int bits, + const std::string& mode, + int bm, + int bn, + int bk, + int wm, + int wn, + bool transpose); + +MTL::ComputePipelineState* get_steel_gemm_fused_nax_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn); + +MTL::ComputePipelineState* get_steel_gemm_gather_nax_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn, + bool rhs); + +MTL::ComputePipelineState* get_steel_gemm_splitk_nax_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn); + +MTL::ComputePipelineState* get_qmm_nax_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& template_def, + const std::string& mode); + +MTL::ComputePipelineState* get_gather_qmm_nax_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& x, + int group_size, + int bits, + const std::string& mode, + int bm, + int bn, + int bk, + int wm, + int wn, + bool transpose); + +MTL::ComputePipelineState* get_steel_attention_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& q, + int bq, + int bk, + int bd, + int wm, + int wn, + const array& m); + +MTL::ComputePipelineState* get_steel_attention_nax_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& q, + int bq, + int bk, + int bd, + int wm, + int wn, + const array& m); + +// Create a GPU kernel template definition for JIT compilation +template +std::string get_template_definition( + std::string_view name, + std::string_view func, + Args... args) { + std::ostringstream s; + s << func << "<"; + bool first = true; + auto add_arg = [&s, &first](const auto& arg) { + if (!first) { + s << ", "; + } + first = false; + s << arg; + }; + (add_arg(args), ...); + s << ">"; + return fmt::format( + "\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n", + name, + s.str()); +} + +} // namespace mlx::core diff --git a/Source/Cmlx/mlx-generated/metal/arange.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/arange.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/arange.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/arange.h diff --git a/Source/Cmlx/mlx-generated/metal/atomic.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/atomic.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/atomic.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/atomic.h diff --git a/Source/Cmlx/mlx-generated/metal/bf16.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/bf16.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/bf16.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/bf16.h diff --git a/Source/Cmlx/mlx-generated/metal/bf16_math.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/bf16_math.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/bf16_math.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/bf16_math.h diff --git a/Source/Cmlx/mlx-generated/metal/binary.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/binary.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/binary.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/binary.h diff --git a/Source/Cmlx/mlx-generated/metal/binary_ops.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/binary_ops.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/binary_ops.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/binary_ops.h diff --git a/Source/Cmlx/mlx-generated/metal/binary_two.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/binary_two.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/binary_two.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/binary_two.h diff --git a/Source/Cmlx/mlx-generated/metal/cexpf.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/cexpf.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/cexpf.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/cexpf.h diff --git a/Source/Cmlx/mlx-generated/metal/complex.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/complex.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/complex.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/complex.h diff --git a/Source/Cmlx/mlx-generated/metal/copy.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/copy.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/copy.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/copy.h diff --git a/Source/Cmlx/mlx-generated/metal/defines.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/defines.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/defines.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/defines.h diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/erf.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/erf.h new file mode 100644 index 00000000..d367eef9 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/erf.h @@ -0,0 +1,69 @@ +// Copyright © 2023 Apple Inc. + +#pragma once +#include +#include "mlx/backend/metal/kernels/expm1f.h" + +/* + * Approximation to the error function. + * Based on code from: + * https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199 + */ +float erf(float a) { + float r, s, t, u; + t = metal::abs(a); + s = a * a; + if (t > 0.927734375f) { + // maximum error 0.99527 ulp + r = metal::fma( + -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12 + u = metal::fma( + -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6 + r = metal::fma(r, s, u); + r = metal::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4 + r = metal::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1 + r = metal::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3 + r = metal::fma(r, t, -t); + r = -expm1f(r); + r = metal::copysign(r, a); + } else { + // maximum error 0.98929 ulp + r = -5.96761703e-4f; // -0x1.38e000p-11 + r = metal::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8 + r = metal::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6 + r = metal::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4 + r = metal::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2 + r = metal::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3 + r = metal::fma(r, a, a); + } + return r; +} + +float erfinv(float a) { + auto t = metal::fma(a, 0.0f - a, 1.0f); + t = metal::log(t); + float p; + if (metal::abs(t) > 6.125f) { // maximum ulp error = 2.35793 + p = 3.03697567e-10f; // 0x1.4deb44p-32 + p = metal::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26 + p = metal::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20 + p = metal::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16 + p = metal::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12 + p = metal::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9 + p = metal::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8 + p = metal::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2 + p = metal::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1 + } else { // maximum ulp error = 2.35002 + p = 5.43877832e-9f; // 0x1.75c000p-28 + p = metal::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23 + p = metal::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20 + p = metal::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24 + p = metal::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15 + p = metal::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13 + p = metal::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9 + p = metal::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7 + p = metal::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3 + p = metal::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1 + } + return a * p; +} diff --git a/Source/Cmlx/mlx-generated/metal/expm1f.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/expm1f.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/expm1f.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/expm1f.h diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/fft.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/fft.h new file mode 100644 index 00000000..e478a85b --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/fft.h @@ -0,0 +1,486 @@ +// Copyright © 2024 Apple Inc. + +// Metal FFT using Stockham's algorithm +// +// References: +// - VkFFT (https://github.com/DTolm/VkFFT) +// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html) + +#include + +#include "mlx/backend/metal/kernels/fft/radix.h" +#include "mlx/backend/metal/kernels/fft/readwrite.h" +#include "mlx/backend/metal/kernels/steel/defines.h" + +using namespace metal; + +#define MAX_RADIX 13 +// Reached when elems_per_thread_ = 6, max_radix = 13 +// and some threads have to do 3 radix 6s requiring 18 float2s. +#define MAX_OUTPUT_SIZE 18 + +// Specialize for a particular value of N at runtime +STEEL_CONST bool inv_ [[function_constant(0)]]; +STEEL_CONST bool is_power_of_2_ [[function_constant(1)]]; +STEEL_CONST int elems_per_thread_ [[function_constant(2)]]; +// rader_m = n / rader_n +STEEL_CONST int rader_m_ [[function_constant(3)]]; +// Stockham steps +STEEL_CONST int radix_13_steps_ [[function_constant(4)]]; +STEEL_CONST int radix_11_steps_ [[function_constant(5)]]; +STEEL_CONST int radix_8_steps_ [[function_constant(6)]]; +STEEL_CONST int radix_7_steps_ [[function_constant(7)]]; +STEEL_CONST int radix_6_steps_ [[function_constant(8)]]; +STEEL_CONST int radix_5_steps_ [[function_constant(9)]]; +STEEL_CONST int radix_4_steps_ [[function_constant(10)]]; +STEEL_CONST int radix_3_steps_ [[function_constant(11)]]; +STEEL_CONST int radix_2_steps_ [[function_constant(12)]]; +// Rader steps +STEEL_CONST int rader_13_steps_ [[function_constant(13)]]; +STEEL_CONST int rader_11_steps_ [[function_constant(14)]]; +STEEL_CONST int rader_8_steps_ [[function_constant(15)]]; +STEEL_CONST int rader_7_steps_ [[function_constant(16)]]; +STEEL_CONST int rader_6_steps_ [[function_constant(17)]]; +STEEL_CONST int rader_5_steps_ [[function_constant(18)]]; +STEEL_CONST int rader_4_steps_ [[function_constant(19)]]; +STEEL_CONST int rader_3_steps_ [[function_constant(20)]]; +STEEL_CONST int rader_2_steps_ [[function_constant(21)]]; + +// See "radix.h" for radix codelets +typedef void (*RadixFunc)(thread float2*, thread float2*); + +// Perform a single radix n butterfly with appropriate twiddles +template +METAL_FUNC void radix_butterfly( + int i, + int p, + thread float2* x, + thread short* indices, + thread float2* y) { + // i: the index in the overall DFT that we're processing. + // p: the size of the DFTs we're merging at this step. + // m: how many threads are working on this DFT. + int k, j; + + // Use faster bitwise operations when working with powers of two + constexpr bool radix_p_2 = (radix & (radix - 1)) == 0; + if (radix_p_2 && is_power_of_2_) { + constexpr short power = __builtin_ctz(radix); + k = i & (p - 1); + j = ((i - k) << power) + k; + } else { + k = i % p; + j = (i / p) * radix * p + k; + } + + // Apply twiddles + if (p > 1) { + float2 twiddle_1 = get_twiddle(k, radix * p); + float2 twiddle = twiddle_1; + x[1] = complex_mul(x[1], twiddle); + + STEEL_PRAGMA_UNROLL + for (int t = 2; t < radix; t++) { + twiddle = complex_mul(twiddle, twiddle_1); + x[t] = complex_mul(x[t], twiddle); + } + } + + radix_func(x, y); + + STEEL_PRAGMA_UNROLL + for (int t = 0; t < radix; t++) { + indices[t] = j + t * p; + } +} + +// Perform all the radix steps required for a +// particular radix size n. +template +METAL_FUNC void radix_n_steps( + int i, + thread int* p, + int m, + int n, + int num_steps, + thread float2* inputs, + thread short* indices, + thread float2* values, + threadgroup float2* buf) { + int m_r = n / radix; + // When combining different sized radices, we have to do + // multiple butterflies in a single thread. + // E.g. n = 28 = 4 * 7 + // 4 threads, 7 elems_per_thread + // All threads do 1 radix7 butterfly. + // 3 threads do 2 radix4 butterflies. + // 1 thread does 1 radix4 butterfly. + int max_radices_per_thread = (elems_per_thread_ + radix - 1) / radix; + + int index = 0; + int r_index = 0; + for (int s = 0; s < num_steps; s++) { + for (int t = 0; t < max_radices_per_thread; t++) { + index = i + t * m; + if (index < m_r) { + for (int r = 0; r < radix; r++) { + inputs[r] = buf[index + r * m_r]; + } + radix_butterfly( + index, *p, inputs, indices + t * radix, values + t * radix); + } + } + + // Wait until all threads have read their inputs into thread local mem + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int t = 0; t < max_radices_per_thread; t++) { + index = i + t * m; + if (index < m_r) { + for (int r = 0; r < radix; r++) { + r_index = t * radix + r; + buf[indices[r_index]] = values[r_index]; + } + } + } + + // Wait until all threads have written back to threadgroup mem + threadgroup_barrier(mem_flags::mem_threadgroup); + *p *= radix; + } +} + +#define RADIX_STEP(radix, radix_func, num_steps) \ + radix_n_steps( \ + fft_idx, p, m, n, num_steps, inputs, indices, values, buf); + +template +METAL_FUNC void +perform_fft(int fft_idx, thread int* p, int m, int n, threadgroup float2* buf) { + float2 inputs[MAX_RADIX]; + short indices[MAX_OUTPUT_SIZE]; + float2 values[MAX_OUTPUT_SIZE]; + + RADIX_STEP(2, radix2, rader ? rader_2_steps_ : radix_2_steps_); + RADIX_STEP(3, radix3, rader ? rader_3_steps_ : radix_3_steps_); + RADIX_STEP(4, radix4, rader ? rader_4_steps_ : radix_4_steps_); + RADIX_STEP(5, radix5, rader ? rader_5_steps_ : radix_5_steps_); + RADIX_STEP(6, radix6, rader ? rader_6_steps_ : radix_6_steps_); + RADIX_STEP(7, radix7, rader ? rader_7_steps_ : radix_7_steps_); + RADIX_STEP(8, radix8, rader ? rader_8_steps_ : radix_8_steps_); + RADIX_STEP(11, radix11, rader ? rader_11_steps_ : radix_11_steps_); + RADIX_STEP(13, radix13, rader ? rader_13_steps_ : radix_13_steps_); +} + +// Each FFT is computed entirely in shared GPU memory. +// +// N is decomposed into radix-n DFTs: +// e.g. 128 = 2 * 4 * 4 * 4 +template +[[kernel]] void fft( + const device in_T* in [[buffer(0)]], + device out_T* out [[buffer(1)]], + constant const int& n, + constant const int& batch_size, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + threadgroup float2 shared_in[tg_mem_size]; + + thread ReadWriter read_writer = ReadWriter( + in, + &shared_in[0], + out, + n, + batch_size, + elems_per_thread_, + elem, + grid, + inv_); + + if (read_writer.out_of_bounds()) { + return; + }; + read_writer.load(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + int p = 1; + int fft_idx = elem.z; // Thread index in DFT + int m = grid.z; // Threads per DFT + int tg_idx = elem.y * n; // Index of this DFT in threadgroup + threadgroup float2* buf = &shared_in[tg_idx]; + + perform_fft(fft_idx, &p, m, n, buf); + + read_writer.write(); +} + +template +[[kernel]] void rader_fft( + const device in_T* in [[buffer(0)]], + device out_T* out [[buffer(1)]], + const device float2* raders_b_q [[buffer(2)]], + const device short* raders_g_q [[buffer(3)]], + const device short* raders_g_minus_q [[buffer(4)]], + constant const int& n, + constant const int& batch_size, + constant const int& rader_n, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + // Use Rader's algorithm to compute fast FFTs + // when a prime factor `p` of `n` is greater than 13 but + // has `p - 1` Stockham decomposable into to prime factors <= 13. + // + // E.g. n = 102 + // = 2 * 3 * 17 + // . = 2 * 3 * RADER(16) + // . = 2 * 3 * RADER(4 * 4) + // + // In numpy: + // x_perm = x[g_q] + // y = np.fft.fft(x_perm) * b_q + // z = np.fft.ifft(y) + x[0] + // out = z[g_minus_q] + // out[0] = x[1:].sum() + // + // Where the g_q and g_minus_q are permutations formed + // by the group under multiplicative modulo N using the + // primitive root of N and b_q is a constant. + // See https://en.wikipedia.org/wiki/Rader%27s_FFT_algorithm + // + // Rader's uses fewer operations than Bluestein's and so + // is more accurate. It's also faster in most cases. + threadgroup float2 shared_in[tg_mem_size]; + + thread ReadWriter read_writer = ReadWriter( + in, + &shared_in[0], + out, + n, + batch_size, + elems_per_thread_, + elem, + grid, + inv_); + + if (read_writer.out_of_bounds()) { + return; + }; + read_writer.load(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // The number of the threads we're using for each DFT + int m = grid.z; + + int fft_idx = elem.z; + int tg_idx = elem.y * n; + threadgroup float2* buf = &shared_in[tg_idx]; + + // rader_m = n / rader_n; + int rader_m = rader_m_; + + // We have to load two x_0s for each thread since sometimes + // elems_per_thread_ crosses a boundary. + // E.g. with n = 34, rader_n = 17, elems_per_thread_ = 4 + // 0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4 5 5 5 5 6 6 6 6 7 7 7 7 8 8 + // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + short x_0_index = + metal::min(fft_idx * elems_per_thread_ / (rader_n - 1), rader_m - 1); + float2 x_0[2] = {buf[x_0_index], buf[x_0_index + 1]}; + + // Do the Rader permutation in shared memory + float2 temp[MAX_RADIX]; + int max_index = n - rader_m - 1; + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); + short g_q = raders_g_q[index / rader_m]; + temp[e] = buf[rader_m + (g_q - 1) * rader_m + index % rader_m]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); + buf[index + rader_m] = temp[e]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Rader FFT on x[rader_m:] + int p = 1; + perform_fft(fft_idx, &p, m, n - rader_m, buf + rader_m); + + // x_1 + ... + x_n is computed for us in the first FFT step so + // we save it in the first rader_m indices of the array for later. + int x_sum_index = metal::min(fft_idx, rader_m - 1); + buf[x_sum_index] = buf[rader_m + x_sum_index * (rader_n - 1)]; + + float2 inv = {1.0f, -1.0f}; + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); + short interleaved_index = + index / rader_m + (index % rader_m) * (rader_n - 1); + temp[e] = complex_mul( + buf[rader_m + interleaved_index], + raders_b_q[interleaved_index % (rader_n - 1)]); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); + buf[rader_m + index] = temp[e] * inv; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Rader IFFT on x[rader_m:] + p = 1; + perform_fft(fft_idx, &p, m, n - rader_m, buf + rader_m); + + float2 rader_inv_factor = {1.0f / (rader_n - 1), -1.0f / (rader_n - 1)}; + + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, n - rader_m - 1); + short diff_index = index / (rader_n - 1) - x_0_index; + temp[e] = buf[rader_m + index] * rader_inv_factor + x_0[diff_index]; + } + + // Use the sum of elements that was computed in the first FFT + float2 x_sum = buf[x_0_index] + x_0[0]; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); + short g_q_index = index % (rader_n - 1); + short g_q = raders_g_minus_q[g_q_index]; + short out_index = index - g_q_index + g_q + (index / (rader_n - 1)); + buf[out_index] = temp[e]; + } + + buf[x_0_index * rader_n] = x_sum; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + p = rader_n; + perform_fft(fft_idx, &p, m, n, buf); + + read_writer.write(); +} + +template +[[kernel]] void bluestein_fft( + const device in_T* in [[buffer(0)]], + device out_T* out [[buffer(1)]], + const device float2* w_q [[buffer(2)]], + const device float2* w_k [[buffer(3)]], + constant const int& length, + constant const int& n, + constant const int& batch_size, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + // Computes arbitrary length FFTs with Bluestein's algorithm + // + // In numpy: + // bluestein_n = next_power_of_2(2*n - 1) + // out = w_k * np.fft.ifft(np.fft.fft(w_k * in, bluestein_n) * w_q) + // + // Where w_k and w_q are precomputed on CPU in high precision as: + // w_k = np.exp(-1j * np.pi / n * (np.arange(-n + 1, n) ** 2)) + // w_q = np.fft.fft(1/w_k[-n:]) + threadgroup float2 shared_in[tg_mem_size]; + + thread ReadWriter read_writer = ReadWriter( + in, + &shared_in[0], + out, + n, + batch_size, + elems_per_thread_, + elem, + grid, + inv_); + + if (read_writer.out_of_bounds()) { + return; + }; + read_writer.load_padded(length, w_k); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + int p = 1; + int fft_idx = elem.z; // Thread index in DFT + int m = grid.z; // Threads per DFT + int tg_idx = elem.y * n; // Index of this DFT in threadgroup + threadgroup float2* buf = &shared_in[tg_idx]; + + // fft + perform_fft(fft_idx, &p, m, n, buf); + + float2 inv = float2(1.0f, -1.0f); + for (int t = 0; t < elems_per_thread_; t++) { + int index = fft_idx + t * m; + buf[index] = complex_mul(buf[index], w_q[index]) * inv; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // ifft + p = 1; + perform_fft(fft_idx, &p, m, n, buf); + + read_writer.write_padded(length, w_k); +} + +template < + int tg_mem_size, + typename in_T, + typename out_T, + int step, + bool real = false> +[[kernel]] void four_step_fft( + const device in_T* in [[buffer(0)]], + device out_T* out [[buffer(1)]], + constant const int& n1, + constant const int& n2, + constant const int& batch_size, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + // Fast four step FFT implementation for powers of 2. + int overall_n = n1 * n2; + int n = step == 0 ? n1 : n2; + int stride = step == 0 ? n2 : n1; + + // The number of the threads we're using for each DFT + int m = grid.z; + int fft_idx = elem.z; + + threadgroup float2 shared_in[tg_mem_size]; + threadgroup float2* buf = &shared_in[elem.y * n]; + + using read_writer_t = ReadWriter; + read_writer_t read_writer = read_writer_t( + in, + &shared_in[0], + out, + n, + batch_size, + elems_per_thread_, + elem, + grid, + inv_); + + if (read_writer.out_of_bounds()) { + return; + }; + read_writer.load_strided(stride, overall_n); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + int p = 1; + perform_fft(fft_idx, &p, m, n, buf); + + read_writer.write_strided(stride, overall_n); +} diff --git a/Source/Cmlx/mlx-generated/metal/fft/radix.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/fft/radix.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/fft/radix.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/fft/radix.h diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/fft/readwrite.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/fft/readwrite.h new file mode 100644 index 00000000..0dc62992 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/fft/readwrite.h @@ -0,0 +1,624 @@ +// Copyright © 2024 Apple Inc. + +#include + +#include "mlx/backend/metal/kernels/fft/radix.h" + +/* FFT helpers for reading and writing from/to device memory. + +For many sizes, GPU FFTs are memory bandwidth bound so +read/write performance is important. + +Where possible, we read 128 bits sequentially in each thread, +coalesced with accesses from adjacent threads for optimal performance. + +We implement specialized reading/writing for: + - FFT + - RFFT + - IRFFT + +Each with support for: + - Contiguous reads + - Padded reads + - Strided reads +*/ + +#define MAX_RADIX 13 + +using namespace metal; + +template < + typename in_T, + typename out_T, + int step = 0, + bool four_step_real = false> +struct ReadWriter { + const device in_T* in; + threadgroup float2* buf; + device out_T* out; + int n; + int batch_size; + int elems_per_thread; + uint3 elem; + uint3 grid; + int threads_per_tg; + bool inv; + + // Used for strided access + int strided_device_idx = 0; + int strided_shared_idx = 0; + + METAL_FUNC ReadWriter( + const device in_T* in_, + threadgroup float2* buf_, + device out_T* out_, + const short n_, + const int batch_size_, + const short elems_per_thread_, + const uint3 elem_, + const uint3 grid_, + const bool inv_) + : in(in_), + buf(buf_), + out(out_), + n(n_), + batch_size(batch_size_), + elems_per_thread(elems_per_thread_), + elem(elem_), + grid(grid_), + inv(inv_) { + // Account for padding on last threadgroup + threads_per_tg = elem.x == grid.x - 1 + ? (batch_size - (grid.x - 1) * grid.y) * grid.z + : grid.y * grid.z; + } + + // ifft(x) = 1/n * conj(fft(conj(x))) + METAL_FUNC float2 post_in(float2 elem) const { + return inv ? float2(elem.x, -elem.y) : elem; + } + + // Handle float case for generic RFFT alg + METAL_FUNC float2 post_in(float elem) const { + return float2(elem, 0); + } + + METAL_FUNC float2 pre_out(float2 elem) const { + return inv ? float2(elem.x / n, -elem.y / n) : elem; + } + + METAL_FUNC float2 pre_out(float2 elem, int length) const { + return inv ? float2(elem.x / length, -elem.y / length) : elem; + } + + METAL_FUNC bool out_of_bounds() const { + // Account for possible extra threadgroups + int grid_index = elem.x * grid.y + elem.y; + return grid_index >= batch_size; + } + + METAL_FUNC void load() const { + size_t batch_idx = size_t(elem.x * grid.y) * n; + short tg_idx = elem.y * grid.z + elem.z; + short max_index = grid.y * n - 2; + + // 2 complex64s = 128 bits + constexpr int read_width = 2; + for (short e = 0; e < (elems_per_thread / read_width); e++) { + short index = read_width * tg_idx + read_width * threads_per_tg * e; + index = metal::min(index, max_index); + // vectorized reads + buf[index] = post_in(in[batch_idx + index]); + buf[index + 1] = post_in(in[batch_idx + index + 1]); + } + max_index += 1; + if (elems_per_thread % 2 != 0) { + short index = tg_idx + + read_width * threads_per_tg * (elems_per_thread / read_width); + index = metal::min(index, max_index); + buf[index] = post_in(in[batch_idx + index]); + } + } + + METAL_FUNC void write() const { + size_t batch_idx = size_t(elem.x * grid.y) * n; + short tg_idx = elem.y * grid.z + elem.z; + short max_index = grid.y * n - 2; + + constexpr int read_width = 2; + for (short e = 0; e < (elems_per_thread / read_width); e++) { + short index = read_width * tg_idx + read_width * threads_per_tg * e; + index = metal::min(index, max_index); + // vectorized reads + out[batch_idx + index] = pre_out(buf[index]); + out[batch_idx + index + 1] = pre_out(buf[index + 1]); + } + max_index += 1; + if (elems_per_thread % 2 != 0) { + short index = tg_idx + + read_width * threads_per_tg * (elems_per_thread / read_width); + index = metal::min(index, max_index); + out[batch_idx + index] = pre_out(buf[index]); + } + } + + // Padded IO for Bluestein's algorithm + METAL_FUNC void load_padded(int length, const device float2* w_k) const { + size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; + int fft_idx = elem.z; + int m = grid.z; + + threadgroup float2* seq_buf = buf + elem.y * n; + for (int e = 0; e < elems_per_thread; e++) { + int index = metal::min(fft_idx + e * m, n - 1); + if (index < length) { + float2 elem = post_in(in[batch_idx + index]); + seq_buf[index] = complex_mul(elem, w_k[index]); + } else { + seq_buf[index] = 0.0; + } + } + } + + METAL_FUNC void write_padded(int length, const device float2* w_k) const { + size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; + int fft_idx = elem.z; + int m = grid.z; + float2 inv_factor = {1.0f / n, -1.0f / n}; + + threadgroup float2* seq_buf = buf + elem.y * n; + for (int e = 0; e < elems_per_thread; e++) { + int index = metal::min(fft_idx + e * m, n - 1); + if (index < length) { + float2 elem = seq_buf[index + length - 1] * inv_factor; + out[batch_idx + index] = pre_out(complex_mul(elem, w_k[index]), length); + } + } + } + + // Strided IO for four step FFT + METAL_FUNC void compute_strided_indices(int stride, int overall_n) { + // Use the batch threadgroup dimension to coalesce memory accesses: + // e.g. stride = 12 + // device | shared mem + // 0 1 2 3 | 0 12 - - + // - - - - | 1 13 - - + // - - - - | 2 14 - - + // 12 13 14 15 | 3 15 - - + int coalesce_width = grid.y; + int tg_idx = elem.y * grid.z + elem.z; + int outer_batch_size = stride / coalesce_width; + + int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width + + overall_n * (elem.x / outer_batch_size); + strided_device_idx = strided_batch_idx + + tg_idx / coalesce_width * elems_per_thread * stride + + tg_idx % coalesce_width; + strided_shared_idx = (tg_idx % coalesce_width) * n + + tg_idx / coalesce_width * elems_per_thread; + } + + // Four Step FFT First Step + METAL_FUNC void load_strided(int stride, int overall_n) { + compute_strided_indices(stride, overall_n); + for (int e = 0; e < elems_per_thread; e++) { + buf[strided_shared_idx + e] = + post_in(in[strided_device_idx + e * stride]); + } + } + + METAL_FUNC void write_strided(int stride, int overall_n) { + for (int e = 0; e < elems_per_thread; e++) { + float2 output = buf[strided_shared_idx + e]; + int combined_idx = (strided_device_idx + e * stride) % overall_n; + int ij = (combined_idx / stride) * (combined_idx % stride); + // Apply four step twiddles at end of first step + float2 twiddle = get_twiddle(ij, overall_n); + out[strided_device_idx + e * stride] = complex_mul(output, twiddle); + } + } +}; + +// Four Step FFT Second Step +template <> +METAL_FUNC void ReadWriter::load_strided( + int stride, + int overall_n) { + // Silence compiler warnings + (void)stride; + (void)overall_n; + // Don't invert between steps + bool default_inv = inv; + inv = false; + load(); + inv = default_inv; +} + +template <> +METAL_FUNC void ReadWriter::write_strided( + int stride, + int overall_n) { + compute_strided_indices(stride, overall_n); + for (int e = 0; e < elems_per_thread; e++) { + float2 output = buf[strided_shared_idx + e]; + out[strided_device_idx + e * stride] = pre_out(output, overall_n); + } +} + +// For RFFT, we interleave batches of two real sequences into one complex one: +// +// z_k = x_k + j.y_k +// X_k = (Z_k + Z_(N-k)*) / 2 +// Y_k = -j * ((Z_k - Z_(N-k)*) / 2) +// +// This roughly doubles the throughput over the regular FFT. +template <> +METAL_FUNC bool ReadWriter::out_of_bounds() const { + int grid_index = elem.x * grid.y + elem.y; + // We pack two sequences into one for RFFTs + return grid_index * 2 >= batch_size; +} + +template <> +METAL_FUNC void ReadWriter::load() const { + size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + // No out of bounds accesses on odd batch sizes + int grid_index = elem.x * grid.y + elem.y; + short next_in = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n; + + short m = grid.z; + short fft_idx = elem.z; + + for (int e = 0; e < elems_per_thread; e++) { + int index = metal::min(fft_idx + e * m, n - 1); + seq_buf[index].x = in[batch_idx + index]; + seq_buf[index].y = in[batch_idx + index + next_in]; + } +} + +template <> +METAL_FUNC void ReadWriter::write() const { + short n_over_2 = (n / 2) + 1; + + size_t batch_idx = + size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + int grid_index = elem.x * grid.y + elem.y; + short next_out = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2; + + float2 conj = {1, -1}; + float2 minus_j = {0, -1}; + + short m = grid.z; + short fft_idx = elem.z; + + for (int e = 0; e < elems_per_thread / 2 + 1; e++) { + int index = metal::min(fft_idx + e * m, n_over_2 - 1); + // x_0 = z_0.real + // y_0 = z_0.imag + if (index == 0) { + out[batch_idx + index] = {seq_buf[index].x, 0}; + out[batch_idx + index + next_out] = {seq_buf[index].y, 0}; + } else { + float2 x_k = seq_buf[index]; + float2 x_n_minus_k = seq_buf[n - index] * conj; + out[batch_idx + index] = (x_k + x_n_minus_k) / 2; + out[batch_idx + index + next_out] = + complex_mul(((x_k - x_n_minus_k) / 2), minus_j); + } + } +} + +template <> +METAL_FUNC void ReadWriter::load_padded( + int length, + const device float2* w_k) const { + size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + // No out of bounds accesses on odd batch sizes + int grid_index = elem.x * grid.y + elem.y; + short next_in = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length; + + short m = grid.z; + short fft_idx = elem.z; + + for (int e = 0; e < elems_per_thread; e++) { + int index = metal::min(fft_idx + e * m, n - 1); + if (index < length) { + float2 elem = + float2(in[batch_idx + index], in[batch_idx + index + next_in]); + seq_buf[index] = complex_mul(elem, w_k[index]); + } else { + seq_buf[index] = 0; + } + } +} + +template <> +METAL_FUNC void ReadWriter::write_padded( + int length, + const device float2* w_k) const { + int length_over_2 = (length / 2) + 1; + size_t batch_idx = + size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; + threadgroup float2* seq_buf = buf + elem.y * n + length - 1; + + int grid_index = elem.x * grid.y + elem.y; + short next_out = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 + ? 0 + : length_over_2; + + float2 conj = {1, -1}; + float2 inv_factor = {1.0f / n, -1.0f / n}; + float2 minus_j = {0, -1}; + + short m = grid.z; + short fft_idx = elem.z; + + for (int e = 0; e < elems_per_thread / 2 + 1; e++) { + int index = metal::min(fft_idx + e * m, length_over_2 - 1); + // x_0 = z_0.real + // y_0 = z_0.imag + if (index == 0) { + float2 elem = complex_mul(w_k[index], seq_buf[index] * inv_factor); + out[batch_idx + index] = float2(elem.x, 0); + out[batch_idx + index + next_out] = float2(elem.y, 0); + } else { + float2 x_k = complex_mul(w_k[index], seq_buf[index] * inv_factor); + float2 x_n_minus_k = complex_mul( + w_k[length - index], seq_buf[length - index] * inv_factor); + x_n_minus_k *= conj; + // w_k should happen before this extraction + out[batch_idx + index] = (x_k + x_n_minus_k) / 2; + out[batch_idx + index + next_out] = + complex_mul(((x_k - x_n_minus_k) / 2), minus_j); + } + } +} + +// For IRFFT, we do the opposite +// +// Z_k = X_k + j.Y_k +// x_k = Re(Z_k) +// Y_k = Imag(Z_k) +template <> +METAL_FUNC bool ReadWriter::out_of_bounds() const { + int grid_index = elem.x * grid.y + elem.y; + // We pack two sequences into one for IRFFTs + return grid_index * 2 >= batch_size; +} + +template <> +METAL_FUNC void ReadWriter::load() const { + short n_over_2 = (n / 2) + 1; + size_t batch_idx = + size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + // No out of bounds accesses on odd batch sizes + int grid_index = elem.x * grid.y + elem.y; + short next_in = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2; + + short m = grid.z; + short fft_idx = elem.z; + + float2 conj = {1, -1}; + float2 plus_j = {0, 1}; + + for (int t = 0; t < elems_per_thread / 2 + 1; t++) { + int index = metal::min(fft_idx + t * m, n_over_2 - 1); + float2 x = in[batch_idx + index]; + float2 y = in[batch_idx + index + next_in]; + // NumPy forces first input to be real + bool first_val = index == 0; + // NumPy forces last input on even irffts to be real + bool last_val = n % 2 == 0 && index == n_over_2 - 1; + if (first_val || last_val) { + x = float2(x.x, 0); + y = float2(y.x, 0); + } + seq_buf[index] = x + complex_mul(y, plus_j); + seq_buf[index].y = -seq_buf[index].y; + if (index > 0 && !last_val) { + seq_buf[n - index] = (x * conj) + complex_mul(y * conj, plus_j); + seq_buf[n - index].y = -seq_buf[n - index].y; + } + } +} + +template <> +METAL_FUNC void ReadWriter::write() const { + int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + int grid_index = elem.x * grid.y + elem.y; + short next_out = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n; + + short m = grid.z; + short fft_idx = elem.z; + + for (int e = 0; e < elems_per_thread; e++) { + int index = metal::min(fft_idx + e * m, n - 1); + out[batch_idx + index] = seq_buf[index].x / n; + out[batch_idx + index + next_out] = seq_buf[index].y / -n; + } +} + +template <> +METAL_FUNC void ReadWriter::load_padded( + int length, + const device float2* w_k) const { + int n_over_2 = (n / 2) + 1; + int length_over_2 = (length / 2) + 1; + + size_t batch_idx = + size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + // No out of bounds accesses on odd batch sizes + int grid_index = elem.x * grid.y + elem.y; + short next_in = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 + ? 0 + : length_over_2; + + short m = grid.z; + short fft_idx = elem.z; + + float2 conj = {1, -1}; + float2 plus_j = {0, 1}; + + for (int t = 0; t < elems_per_thread / 2 + 1; t++) { + int index = metal::min(fft_idx + t * m, n_over_2 - 1); + float2 x = in[batch_idx + index]; + float2 y = in[batch_idx + index + next_in]; + if (index < length_over_2) { + bool last_val = length % 2 == 0 && index == length_over_2 - 1; + if (last_val) { + x = float2(x.x, 0); + y = float2(y.x, 0); + } + float2 elem1 = x + complex_mul(y, plus_j); + seq_buf[index] = complex_mul(elem1 * conj, w_k[index]); + if (index > 0 && !last_val) { + float2 elem2 = (x * conj) + complex_mul(y * conj, plus_j); + seq_buf[length - index] = + complex_mul(elem2 * conj, w_k[length - index]); + } + } else { + short pad_index = metal::min(length + (index - length_over_2) * 2, n - 2); + seq_buf[pad_index] = 0; + seq_buf[pad_index + 1] = 0; + } + } +} + +template <> +METAL_FUNC void ReadWriter::write_padded( + int length, + const device float2* w_k) const { + size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; + threadgroup float2* seq_buf = buf + elem.y * n + length - 1; + + int grid_index = elem.x * grid.y + elem.y; + short next_out = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length; + + short m = grid.z; + short fft_idx = elem.z; + + float2 inv_factor = {1.0f / n, -1.0f / n}; + for (int e = 0; e < elems_per_thread; e++) { + int index = fft_idx + e * m; + if (index < length) { + float2 output = complex_mul(seq_buf[index] * inv_factor, w_k[index]); + out[batch_idx + index] = output.x / length; + out[batch_idx + index + next_out] = output.y / -length; + } + } +} + +// Four Step RFFT +template <> +METAL_FUNC void +ReadWriter::load_strided( + int stride, + int overall_n) { + // Silence compiler warnings + (void)stride; + (void)overall_n; + // Don't invert between steps + bool default_inv = inv; + inv = false; + load(); + inv = default_inv; +} + +template <> +METAL_FUNC void +ReadWriter::write_strided( + int stride, + int overall_n) { + int overall_n_over_2 = overall_n / 2 + 1; + int coalesce_width = grid.y; + int tg_idx = elem.y * grid.z + elem.z; + int outer_batch_size = stride / coalesce_width; + + int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width + + overall_n_over_2 * (elem.x / outer_batch_size); + strided_device_idx = strided_batch_idx + + tg_idx / coalesce_width * elems_per_thread / 2 * stride + + tg_idx % coalesce_width; + strided_shared_idx = (tg_idx % coalesce_width) * n + + tg_idx / coalesce_width * elems_per_thread / 2; + for (int e = 0; e < elems_per_thread / 2; e++) { + float2 output = buf[strided_shared_idx + e]; + out[strided_device_idx + e * stride] = output; + } + + // Add on n/2 + 1 element + if (tg_idx == 0 && elem.x % outer_batch_size == 0) { + out[strided_batch_idx + overall_n / 2] = buf[n / 2]; + } +} + +// Four Step IRFFT +template <> +METAL_FUNC void +ReadWriter::load_strided( + int stride, + int overall_n) { + int overall_n_over_2 = overall_n / 2 + 1; + auto conj = float2(1, -1); + + compute_strided_indices(stride, overall_n); + // Translate indices in terms of N - k + for (int e = 0; e < elems_per_thread; e++) { + int device_idx = strided_device_idx + e * stride; + int overall_batch = device_idx / overall_n; + int overall_index = device_idx % overall_n; + if (overall_index < overall_n_over_2) { + device_idx -= overall_batch * (overall_n - overall_n_over_2); + buf[strided_shared_idx + e] = in[device_idx] * conj; + } else { + int conj_idx = overall_n - overall_index; + device_idx = overall_batch * overall_n_over_2 + conj_idx; + buf[strided_shared_idx + e] = in[device_idx]; + } + } +} + +template <> +METAL_FUNC void +ReadWriter::load_strided( + int stride, + int overall_n) { + // Silence compiler warnings + (void)stride; + (void)overall_n; + bool default_inv = inv; + inv = false; + load(); + inv = default_inv; +} + +template <> +METAL_FUNC void +ReadWriter::write_strided( + int stride, + int overall_n) { + compute_strided_indices(stride, overall_n); + + for (int e = 0; e < elems_per_thread; e++) { + out[strided_device_idx + e * stride] = + pre_out(buf[strided_shared_idx + e], overall_n).x; + } +} diff --git a/Source/Cmlx/mlx-generated/metal/fp4.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/fp4.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/fp4.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/fp4.h diff --git a/Source/Cmlx/mlx-generated/metal/fp8.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/fp8.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/fp8.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/fp8.h diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/fp_quantized.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/fp_quantized.h new file mode 100644 index 00000000..5c5e4b2e --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/fp_quantized.h @@ -0,0 +1,1850 @@ +// Copyright © 2025 Apple Inc. + +#include +#include + +#include "mlx/backend/metal/kernels/fp4.h" +#include "mlx/backend/metal/kernels/fp8.h" + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const + +MLX_MTL_CONST int SIMD_SIZE = 32; +MLX_MTL_CONST int QUAD_SIZE = 4; + +template +inline constexpr short get_pack_factor() { + return wsize / bits; +} + +template +inline constexpr short get_bytes_per_pack() { + return wsize / 8; +} + +template +static inline T dequantize_scale(uint8_t s) { + if constexpr (group_size == 16) { + // Use nv scale + return T(*(thread fp8_e4m3*)(&s)); + } else { + return T(*(thread fp8_e8m0*)(&s)); + } +} + +template +struct Quantize { + uint8_t operator()(float x) { + if (bits == 8) { + return fp8_e4m3(x).bits; + } else { + return fp4_e2m1(x).bits; + } + } +}; + +template +struct Dequantize { + U operator()(uint8_t x) { + if constexpr (bits == 8) { + return U(*(thread fp8_e4m3*)(&x)); + } else { + return U(*(thread fp4_e2m1*)(&x)); + } + } +}; + +template +inline void load_vector(const device T* x, thread U* x_thread) { +#pragma unroll + for (int i = 0; i < values_per_thread; i++) { + x_thread[i] = x[i]; + } +} + +template +inline void load_vector_safe(const device T* x, thread U* x_thread, int N) { + for (int i = 0; i < N; i++) { + x_thread[i] = x[i]; + } + + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; + } +} + +template +inline U qdot(const device uint8_t* w, const thread U* x_thread, U scale) { + U accum = 0; + if constexpr (bits == 4) { + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * Dequantize<4>{}(ws[i]) + + x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) + + x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) + + x_thread[4 * i + 3] * Dequantize<4>{}(ws[i] >> 12)); + } + } else { + for (int i = 0; i < values_per_thread; i++) { + accum += x_thread[i] * Dequantize<8>{}(w[i]); + } + } + + return scale * accum; +} + +template +inline U +qdot_safe(const device uint8_t* w, const thread U* x_thread, U scale, int N) { + U accum = 0; + + if constexpr (bits == 4) { + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * Dequantize<4>{}(ws[i]) + + x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) + + x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) + + x_thread[4 * i + 3] * Dequantize<4>{}(ws[i] >> 12)); + } + } else { + for (int i = 0; i < N; i++) { + accum += x_thread[i] * Dequantize<8>{}(w[i]); + } + } + return scale * accum; +} + +template +inline void qouter(const thread uint8_t* w, U x, U scale, thread U* result) { + if constexpr (bits == 4) { + for (int i = 0; i < (values_per_thread / 2); i++) { + result[2 * i] += x * scale * Dequantize<4>{}(w[i]); + result[2 * i + 1] += x * scale * Dequantize<4>{}(w[i] >> 4); + } + } else { + for (int i = 0; i < values_per_thread; i++) { + result[i] += x * scale * Dequantize<8>{}(w[i]); + } + } +} + +template +inline void dequantize(uint8_t w, U scale, threadgroup U* w_local) { + if constexpr (bits == 4) { + w_local[0] = scale * Dequantize<4, U>{}(w); + w_local[1] = scale * Dequantize<4, U>{}(w >> 4); + } else { + w_local[0] = scale * Dequantize<8, U>{}(w); + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short group_size, + short bits> +struct QuantizedBlockLoader { + MLX_MTL_CONST short pack_factor = get_pack_factor<8, bits>(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); + MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; + MLX_MTL_CONST short n_reads = + (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; + MLX_MTL_CONST short group_steps = group_size < BCOLS ? 1 : group_size / BCOLS; + MLX_MTL_CONST short scale_step = group_size < BCOLS ? BCOLS / group_size : 1; + + static_assert( + (n_reads * pack_factor) <= group_size, + "The number of reads per thread must be less than the group size."); + + const int src_ld; + const int tile_stride; + short group_step_cnt; + const int group_stride; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + const device uint8_t* scales; + + QuantizedBlockLoader( + const device uint8_t* src_, + const device uint8_t* scales_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride( + reduction_dim ? BCOLS_PACKED * bytes_per_pack + : BROWS * src_ld * bytes_per_pack / pack_factor), + group_step_cnt(0), + group_stride(BROWS * src_ld / group_size), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(n_reads * thread_idx / BCOLS_PACKED), + bj((n_reads * thread_idx) % BCOLS_PACKED), + dst(dst_ + bi * dst_ld + bj * pack_factor), + src(src_ + bi * src_ld * bytes_per_pack / pack_factor + + bj * bytes_per_pack), + scales( + scales_ + bi * src_ld / group_size + + (bj * pack_factor) / group_size) {} + + void load_unsafe() const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + T scale = dequantize_scale(*scales); + for (int i = 0; i < n_reads; i++) { + dequantize( + src[i * bytes_per_pack], scale, dst + i * pack_factor); + } + } + + void load_safe(short2 src_tile_dim) const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + if (reduction_dim == 1 && bi >= src_tile_dim.x) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + if (reduction_dim == 0 && bi >= src_tile_dim.y) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + T scale = dequantize_scale(*scales); + for (int i = 0; i < n_reads; i++) { + dequantize( + src[i * bytes_per_pack], scale, dst + i * pack_factor); + } + } + + void next() { + src += tile_stride; + if (reduction_dim == 1) { + if (group_steps > 1) { + group_step_cnt++; + if (group_step_cnt == group_steps) { + group_step_cnt = 0; + scales++; + } + } else { + scales += scale_step; + } + } else { + scales += group_stride; + } + } +}; + +template +METAL_FUNC void fp_qmv_quad_impl( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]]) { + constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE; + constexpr int pack_factor = get_pack_factor<32, bits>(); + constexpr int values_per_thread = D / QUAD_SIZE; + constexpr int steps_per_thread = + values_per_thread < group_size ? 1 : values_per_thread / group_size; + constexpr int values_per_step = values_per_thread / steps_per_thread; + constexpr int packs_per_thread = values_per_thread / pack_factor; + constexpr int packs_per_step = values_per_step / pack_factor; + constexpr int results_per_quadgroup = 8; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_quadgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid; + + w += out_row * in_vec_size_w + quad_lid * packs_per_thread; + scales += + out_row * in_vec_size_g + (quad_lid * values_per_thread) / group_size; + x += tid.x * in_vec_size + quad_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + load_vector(x, x_thread); + + for (int row = 0; row < results_per_quadgroup; row++) { + auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd); + const device uint8_t* sl = scales + row * in_vec_size_g * quads_per_simd; +#pragma unroll + for (int k = 0; k < steps_per_thread; ++k) { + U s = dequantize_scale(sl[0]); + if (row * quads_per_simd + out_row < out_vec_size) { + result[row] += qdot( + wl, x_thread + k * values_per_step, s); + } + sl++; + wl += (sizeof(uint32_t) / sizeof(uint8_t)) * packs_per_step; + } + } + + for (int row = 0; row < results_per_quadgroup; row++) { + result[row] = quad_sum(result[row]); + if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) { + y[row * quads_per_simd] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void fp_qmv_fast_impl( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int packs_per_thread = 2; + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int pack_factor = get_pack_factor<32, bits>(); + constexpr int bytes_per_pack = get_bytes_per_pack<32>(); + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + + const device uint8_t* ws = (const device uint8_t*)w; + + typedef float U; + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + for (int k = 0; k < in_vec_size; k += block_size) { + load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + U s = dequantize_scale(sl[0]); + result[row] += qdot(wl, x_thread, s); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + x += block_size; + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void fp_qmv_impl( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int packs_per_thread = 1; + constexpr int pack_factor = get_pack_factor<32, bits>(); + constexpr int bytes_per_pack = get_bytes_per_pack<32>(); + + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + + const device uint8_t* ws = (const device uint8_t*)w; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); + + if (out_row >= out_vec_size) { + return; + } + + // In this case we need to properly guard all our reads because there isn't + // even 1 tile in the matrix + if (out_vec_size < (num_simdgroups * results_per_simdgroup)) { + ws += + out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + load_vector(x, x_thread); + + for (int row = 0; + row < results_per_simdgroup && out_row + row < out_vec_size; + row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + uint8_t s = sl[0]; + result[row] += qdot(wl, x_thread, s); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + x += block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + if (remaining > 0) { + load_vector_safe(x, x_thread, remaining); + + for (int row = 0; + row < results_per_simdgroup && out_row + row < out_vec_size; + row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + U s = dequantize_scale(sl[0]); + result[row] += qdot(wl, x_thread, s); + } + } + + for (int row = 0; + row < results_per_simdgroup && out_row + row < out_vec_size; + row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } + + // In this case the last tile is moved back to redo some output values + else { + ws += used_out_row * in_vec_size_w + + simd_lid * packs_per_thread * bytes_per_pack; + scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + used_out_row; + + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + U s = dequantize_scale(sl[0]); + result[row] += qdot(wl, x_thread, s); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + x += block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + if (remaining > 0) { + load_vector_safe(x, x_thread, remaining); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + U s = dequantize_scale(sl[0]); + result[row] += + qdot_safe(wl, x_thread, s, remaining); + } + } + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } +} + +template +METAL_FUNC void fp_qvm_impl( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const int in_vec_size, + const int out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int num_simdgroups = 2; + constexpr int pack_factor = get_pack_factor<32, bits>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int tn = group_size / pack_factor; + constexpr int block_size = SIMD_SIZE; + + using W_T = uint32_t; + const device W_T* ws = (const device W_T*)w; + + typedef float U; + typedef struct { + W_T wi[tn * bytes_per_pack]; + } vec_w; + + thread vec_w w_local; + thread U result[tn * pack_factor] = {0}; + thread U scale = 0; + thread U x_local = 0; + + // Adjust positions + const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor; + const int out_vec_size_g = out_vec_size / group_size; + // 32 * (tid.y * 2 + simd_gid) + int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid); + ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w; + scales += out_col / group_size + simd_lid * out_vec_size_g; + x += tid.x * in_vec_size + simd_lid; + y += tid.x * out_vec_size + out_col; + + if (out_col >= out_vec_size) { + return; + } + + // Loop over in_vec in blocks of block_size + int remaining = in_vec_size % block_size; + if (remaining == 0) { + for (int i = 0; i < in_vec_size; i += block_size) { + x_local = *x; + scale = dequantize_scale(*scales); + w_local = *((device vec_w*)ws); + qouter( + (thread uint8_t*)&w_local, x_local, scale, result); + + x += block_size; + scales += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; + } + } else { + for (int i = block_size; i < in_vec_size; i += block_size) { + x_local = *x; + scale = dequantize_scale(*scales); + w_local = *((device vec_w*)ws); + + qouter( + (thread uint8_t*)&w_local, x_local, scale, result); + + x += block_size; + scales += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; + } + if (static_cast(simd_lid) < remaining) { + x_local = *x; + scale = dequantize_scale(*scales); + w_local = *((device vec_w*)ws); + } else { + x_local = 0; + scale = 0; + } + qouter( + (thread uint8_t*)&w_local, x_local, scale, result); + } + +// Accumulate in the simdgroup +#pragma clang loop unroll(full) + for (int k = 0; k < tn * pack_factor; k++) { + result[k] = simd_sum(result[k]); + } + + // Store the result + if (simd_lid == 0) { +#pragma clang loop unroll(full) + for (int k = 0; k < tn * pack_factor; k++) { + y[k] = static_cast(result[k]); + } + } +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 32, + const int BK = 32, + const int BN = 32> +METAL_FUNC void fp_qmm_t_impl( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int pack_factor = get_pack_factor<8, bits>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + // Instantiate the appropriate BlockMMA and Loader + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BN, + BK, + BK_padded, + 1, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + // Set the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = (const device uint8_t*)w; + + x += y_row * static_cast(K); + wl += y_col * K_w; + scales += y_col * K_g; + y += y_row * static_cast(N) + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + const short num_outs = min(BN, N - y_col); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, K, Ws, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM || num_outs < BN) { + mma_op.store_result_safe(y, N, short2(num_outs, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 32, + const int BK = 32, + const int BN = 32> +METAL_FUNC void fp_qmm_n_impl( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int pack_factor = get_pack_factor<8, bits>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + // Instantiate the appropriate BlockMMA and Loader + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = mlx::steel:: + BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BK, + BN, + BN_padded, + 0, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + auto wl = (const device uint8_t*)w; + + // Set the block + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + x += y_row * static_cast(K); + wl += y_col * bytes_per_pack / pack_factor; + scales += y_col / group_size; + y += y_row * static_cast(N) + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, N, Ws, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, num_els)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, BM)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM) { + mma_op.store_result_safe(y, N, short2(BN, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device uint8_t*& scales, + device T*& y, + int output_stride, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx = tid.z; + uint32_t w_idx = tid.z; + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + } else { + ulong2 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + } + y += tid.z * output_stride; +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device uint8_t*& scales, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T*& y, + int output_stride, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx; + uint32_t w_idx; + if (batch_ndims == 1) { + x_idx = lhs_indices[tid.z * lhs_strides[0]]; + w_idx = rhs_indices[tid.z * rhs_strides[0]]; + } else { + ulong2 idx = elem_to_loc_broadcast( + tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); + x_idx = lhs_indices[idx.x]; + w_idx = rhs_indices[idx.y]; + } + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + } else { + ulong2 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + } + y += tid.z * output_stride; +} + +template +[[kernel]] void fp_qmv_quad( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + fp_qmv_quad_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, quad_gid, quad_lid); +} + +template +[[kernel]] void fp_qmv_fast( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + fp_qmv_fast_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void fp_qmv( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + fp_qmv_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void fp_qvm( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + fp_qvm_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void fp_qvm_split_k( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& final_block_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + + // When (in_vec_size % split_k != 0) the final block needs to be smaller + int in_vec_size_adj = + tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size; + + fp_qvm_impl( + w, scales, x, y, in_vec_size_adj, out_vec_size, tid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const bool batched, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void fp_qmm_t( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + fp_qmm_t_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool batched, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void fp_qmm_n( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + + fp_qmm_n_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void fp_gather_qmv_fast( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + fp_qmv_fast_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void fp_gather_qmv( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + fp_qmv_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void fp_gather_qvm( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + fp_qvm_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void fp_gather_qmm_t( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + fp_qmm_t_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void fp_gather_qmm_n( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + fp_qmm_n_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + int group_size, + int bits, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose> +[[kernel]] void fp_gather_qmm_rhs( + const device T* x, + const device uint32_t* w, + const device uint8_t* scales, + const device uint32_t* indices, + device T* y, + const constant int& M, + const constant int& N, + const constant int& K, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + constexpr int pack_factor = get_pack_factor<8, bits>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + using mma_t = mlx::steel::BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + false, + transpose, + BK_padded, + transpose ? BK_padded : BN_padded>; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + transpose ? BN : BK, + transpose ? BK : BN, + transpose ? BK_padded : BN_padded, + transpose, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; + + // Compute the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int N_w = N * bytes_per_pack / pack_factor; + const int N_g = N / group_size; + const int K_it = K / BK; + const size_t stride_w = transpose ? N * K_w : K * N_w; + const size_t stride_s = transpose ? N * K_g : K * N_g; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + const size_t y_row_long = size_t(y_row); + const size_t y_col_long = size_t(y_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); + const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); + + // Calculate the final tiles in the case that K is not aligned + const int k_remain = K - K_it * BK; + const short2 tile_x = short2(k_remain, tgp_bm); + const short2 tile_w = + transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + // Move x and output to the correct block + auto wl = (const device uint8_t*)w; + x += y_row_long * K; + y += y_row_long * N + y_col_long; + wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; + scales += transpose ? y_col_long * K_g : y_col / group_size; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = indices[y_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (indices[y_row + n] != index) { + offset_next = n; + index_next = indices[y_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id); + thread loader_w_t loader_w( + wl + index * stride_w, + scales + index * stride_s, + transpose ? K : N, + Ws, + simd_group_id, + simd_lane_id); + + // Matrices are all aligned check nothing + if (align_M && align_N) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } else { + // Tile aligned so check outside of the hot loop + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + + // Nothing aligned so check both rows and cols + else { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + } + } +} + +template +[[kernel]] void fp_quantize( + const device T* w [[buffer(0)]], + device uint8_t* out [[buffer(1)]], + device uint8_t* scales [[buffer(2)]], + uint2 tidx [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + constexpr bool use_mx_scale = group_size == 32; + size_t index = tidx.x + grid_dim.x * size_t(tidx.y); + + float scale; + float w_thread = w[index]; + if (use_mx_scale) { + scale = simd_max(abs(w_thread)); + } else { + float w_max_l = simd_max(tidx.x < 16 ? abs(w_thread) : 0.0); + float w_max_r = simd_max(tidx.x >= 16 ? abs(w_thread) : 0.0); + scale = tidx.x < 16 ? w_max_l : w_max_r; + } + scale /= bits == 4 ? 6.0f : 448.0f; + + using ScaleType = metal::conditional_t; + auto s = ScaleType(scale); + uint8_t q_scale = s.bits; + scale = float(s); + + size_t gindex = index / group_size; + if (index % group_size == 0) { + scales[gindex] = q_scale; + } + + uint8_t output = Quantize{}(scale == 0 ? 0.0f : w_thread / scale); + if (bits == 4) { + uint8_t sval = simd_shuffle_down(output, 1); + output |= sval << bits; + } + constexpr int pack_factor = bits == 8 ? 1 : 2; + if (index % pack_factor == 0) { + out[index / pack_factor] = output; + } +} + +template +[[kernel]] void fp_dequantize( + const device uint8_t* w [[buffer(0)]], + const device uint8_t* scales [[buffer(1)]], + device T* out [[buffer(3)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + constexpr bool use_mx_scale = group_size == 32; + constexpr int pack_factor = bits == 8 ? 1 : 2; + size_t offset = index.x + grid_dim.x * size_t(index.y); + size_t oindex = offset * pack_factor; + size_t gindex = oindex / group_size; + + out += oindex; + + using ScaleType = metal::conditional_t; + auto q_scale = ((device ScaleType*)(scales))[gindex]; + auto scale = float(q_scale); + + uint val = w[offset]; +#pragma clang loop unroll(full) + for (int i = 0; i < pack_factor; i++) { + uint8_t d; + if (bits == 4) { + d = (val >> (bits * i)) & 0x0f; + } else if (bits == 8) { + d = val; + } + out[i] = static_cast(scale * Dequantize{}(d)); + } +} + +template +[[kernel]] void fp_quantize_dequantize( + const device T* w [[buffer(0)]], + device T* out [[buffer(1)]], + uint2 tidx [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + constexpr bool use_mx_scale = group_size == 32; + size_t index = tidx.x + grid_dim.x * size_t(tidx.y); + + float scale; + float w_thread = w[index]; + if (use_mx_scale) { + scale = simd_max(abs(w_thread)); + } else { + float w_max_l = simd_max(tidx.x < 16 ? abs(w_thread) : 0.0); + float w_max_r = simd_max(tidx.x >= 16 ? abs(w_thread) : 0.0); + scale = tidx.x < 16 ? w_max_l : w_max_r; + } + scale /= bits == 4 ? 6.0f : 448.0f; + + using ScaleType = metal::conditional_t; + auto s = ScaleType(scale); + scale = float(s); + + uint8_t output = Quantize{}(scale == 0 ? 0.0f : w_thread / scale); + + out[index] = static_cast(scale * Dequantize{}(output)); +} diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h new file mode 100644 index 00000000..80e1c4c2 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h @@ -0,0 +1,1044 @@ +// Copyright © 2025 Apple Inc. + +#include +#include + +#include "mlx/backend/metal/kernels/fp4.h" +#include "mlx/backend/metal/kernels/fp8.h" + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const + +MLX_MTL_CONST int SIMD_SIZE = 32; +MLX_MTL_CONST int QUAD_SIZE = 4; + +template +inline constexpr short get_pack_factor() { + return wsize / bits; +} + +template +inline constexpr short get_bytes_per_pack() { + return wsize / 8; +} + +template +static inline T dequantize_scale(uint8_t s) { + if constexpr (group_size == 16) { + // Use nv scale + return T(*(thread fp8_e4m3*)(&s)); + } else { + return T(*(thread fp8_e8m0*)(&s)); + } +} + +template +struct Quantize { + uint8_t operator()(float x) { + if (bits == 8) { + return fp8_e4m3(x).bits; + } else { + return fp4_e2m1(x).bits; + } + } +}; + +template +struct Dequantize { + U operator()(uint8_t x) { + if constexpr (bits == 8) { + return U(*(thread fp8_e4m3*)(&x)); + } else { + return U(*(thread fp4_e2m1*)(&x)); + } + } +}; + +template +inline void dequantize(uint8_t w, U scale, threadgroup U* w_local) { + if constexpr (bits == 4) { + w_local[0] = scale * Dequantize<4, U>{}(w); + w_local[1] = scale * Dequantize<4, U>{}(w >> 4); + } else { + w_local[0] = scale * Dequantize<8, U>{}(w); + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short group_size, + short bits> +struct QuantizedBlockLoader { + MLX_MTL_CONST short pack_factor = get_pack_factor<8, bits>(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); + MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; + MLX_MTL_CONST short n_reads = + (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; + + MLX_MTL_CONST short n_reads_per_scale = (n_reads * pack_factor) <= group_size + ? n_reads + : (group_size / pack_factor); + MLX_MTL_CONST short n_steps_per_read = n_reads / n_reads_per_scale; + + MLX_MTL_CONST short n_groups = BCOLS / group_size; + + const int src_ld; + const int tile_stride; + const int group_stride; + + const short thread_idx; + const short bi; + const short bj; + + const short group_id; + + threadgroup T* dst; + const device uint8_t* src; + const device uint8_t* scales; + + QuantizedBlockLoader( + const device uint8_t* src_, + const device uint8_t* scales_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride( + reduction_dim ? BCOLS_PACKED * bytes_per_pack + : BROWS * src_ld * bytes_per_pack / pack_factor), + group_stride(BROWS * src_ld / group_size), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(n_reads * thread_idx / BCOLS_PACKED), + bj((n_reads * thread_idx) % BCOLS_PACKED), + group_id((bj * pack_factor) / group_size), + dst(dst_ + bi * dst_ld + bj * pack_factor), + src(src_ + bi * src_ld * bytes_per_pack / pack_factor + + bj * bytes_per_pack), + scales(scales_ + bi * src_ld / group_size + group_id) {} + + void load_unsafe() const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + int k = 0; + for (int i = 0; i < n_steps_per_read; i++) { + T scale = dequantize_scale(scales[i]); + for (int j = 0; j < n_reads_per_scale; j++) { + dequantize( + src[k * bytes_per_pack], scale, dst + k * pack_factor); + k++; + } + } + } + + void load_safe(short2 src_tile_dim) const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + if (reduction_dim == 1 && bi >= src_tile_dim.x) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + if (reduction_dim == 0 && bi >= src_tile_dim.y) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + int k = 0; + for (int i = 0; i < n_steps_per_read; i++) { + T scale = dequantize_scale(scales[i]); + for (int j = 0; j < n_reads_per_scale; j++) { + dequantize( + src[k * bytes_per_pack], scale, dst + k * pack_factor); + k++; + } + } + } + + void next() { + src += tile_stride; + if (reduction_dim == 1) { + scales += n_groups; + } else { + scales += n_groups * group_stride; + } + } +}; + +using namespace mlx::steel; + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2, + typename Wtype = bfloat> +METAL_FUNC void fp_qmm_t_impl( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + threadgroup Wtype* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int pack_factor = get_pack_factor<8, bits>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(Wtype)); + + // Instantiate Loader + using loader_w_t = QuantizedBlockLoader< + Wtype, + BN, + BK, + BK_padded, + 1, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + // Set the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = (const device uint8_t*)w; + + x += y_row * static_cast(K); + wl += y_col * K_w; + scales += y_col * K_g; + y += y_row * static_cast(N) + y_col; + + // Make the weight loader + loader_w_t loader_w(wl, scales, K, Ws, simd_gid, simd_lid); + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + const short tm = SM * (simd_gid / WN); + const short tn = SN * (simd_gid % WN); + + constexpr bool transpose_a = false; + constexpr bool transpose_b = true; + + const short sgp_sm = min(SM, short(M - (y_row + tm))); + const bool is_unaligned_sm = (sgp_sm != SM); + + const short sgp_sn = aligned_N ? SN : min(SN, short(N - (y_col + tn))); + + const short tgp_bn = aligned_N ? BN : min(BN, int(N - (y_col))); + const bool is_unaligned_bn = aligned_N ? false : (tgp_bn != BN); + + using AccumType = float; + + using ASubTile = NAXSubTile; + using BSubTile = NAXSubTile; + using DSubTile = NAXSubTile; + + NAXTile Dtile; + + Dtile.clear(); + + x += tm * K; + + dispatch_bool(!is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(aligned_N || !is_unaligned_bn, [&](auto kAlignedN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if constexpr (kAlignedN.value) { + loader_w.load_unsafe(); + } else { + loader_w.load_safe(short2(BK, tgp_bn)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + if constexpr (kAlignedM.value) { + Atile.load(x + kk1, K); + } else { + Atile.load_safe(x + kk1, K, short2(SK, sgp_sm)); + } + + Btile.template load(Ws + tn * BK_padded + kk1); + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + x += BK; + loader_w.next(); + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + + if constexpr (kAlignedM.value && kAlignedN.value) { + Dtile.store(y + tm * N + tn, N); + } else if (kAlignedM.value && sgp_sn == SN) { + Dtile.store(y + tm * N + tn, N); + } else { + Dtile.store_safe(y + tm * N + tn, N, short2(sgp_sn, sgp_sm)); + } + }); + }); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2, + typename Wtype = bfloat> +METAL_FUNC void fp_qmm_n_impl( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + (void)M; + + constexpr int pack_factor = get_pack_factor<8, bits>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + using loader_w_t = QuantizedBlockLoader< + T, + BK, + BN, + BN_padded, + 0, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + // Set the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = (const device uint8_t*)w; + + x += y_row * static_cast(K); + wl += y_col * K_w; + scales += y_col * K_g; + y += y_row * static_cast(N) + y_col; + + // Make the x loader and mma operation + // const short num_els = min(BM, M - y_row); + // const short num_outs = min(BN, N - y_col); + loader_w_t loader_w(wl, scales, K, Ws, simd_gid, simd_lid); + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + const short tm = SM * (simd_gid / WN); + const short tn = SN * (simd_gid % WN); + + const short ldb_tgp = BN_padded; + + constexpr bool transpose_a = false; + constexpr bool transpose_b = false; + + using AccumType = float; + + using ASubTile = NAXSubTile; + using BSubTile = NAXSubTile; + using DSubTile = NAXSubTile; + + NAXTile Dtile; + + Dtile.clear(); + + x += tm * K; + + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + Atile.load(x + kk1, K); + Btile.template load(Ws + tn + kk1 * ldb_tgp); + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + x += BK; + loader_w.next(); + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + + Dtile.store(y + tm * N + tn, N); +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device S*& scales, + device T*& y, + int output_stride, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx = tid.z; + uint32_t w_idx = tid.z; + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + } else { + ulong2 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + } + y += tid.z * output_stride; +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device S*& scales, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T*& y, + int output_stride, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx; + uint32_t w_idx; + if (batch_ndims == 1) { + x_idx = lhs_indices[tid.z * lhs_strides[0]]; + w_idx = rhs_indices[tid.z * rhs_strides[0]]; + } else { + ulong2 idx = elem_to_loc_broadcast( + tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); + x_idx = lhs_indices[idx.x]; + w_idx = rhs_indices[idx.y]; + } + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + } else { + ulong2 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + } + y += tid.z * output_stride; +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const bool batched, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2, + typename Wtype = bfloat> +[[kernel]] void fp_qmm_t_nax( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(Wtype)); + + threadgroup Wtype Ws[BN * BK_padded]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + fp_qmm_t_impl( + w, scales, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool batched, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2, + typename Wtype = bfloat> +[[kernel]] void fp_qmm_n_nax( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + + fp_qmm_n_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2, + typename Wtype = bfloat> +[[kernel]] void fp_gather_qmm_t_nax( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(Wtype)); + + threadgroup Wtype Ws[BN * BK_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + fp_qmm_t_impl( + w, scales, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2, + typename Wtype = bfloat> +[[kernel]] void fp_gather_qmm_n_nax( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + fp_qmm_n_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + int group_size, + const int bits, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose, + typename Wtype = bfloat> +[[kernel]] void fp_gather_qmm_rhs_nax( + const device T* x, + const device uint32_t* w, + const device uint8_t* scales, + const device uint32_t* indices, + device T* y, + const constant int& M, + const constant int& N, + const constant int& K, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + constexpr int pack_factor = get_pack_factor<8, bits>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(Wtype)); + constexpr int BN_padded = (BN + 16 / sizeof(Wtype)); + + using loader_w_t = QuantizedBlockLoader< + Wtype, + transpose ? BN : BK, + transpose ? BK : BN, + transpose ? BK_padded : BN_padded, + transpose, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + threadgroup Wtype Ws[transpose ? BN * BK_padded : BK * BN_padded]; + + // Compute the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int N_w = N * bytes_per_pack / pack_factor; + const int N_g = N / group_size; + const int K_it = K / BK; + const size_t stride_w = transpose ? N * K_w : K * N_w; + const size_t stride_s = transpose ? N * K_g : K * N_g; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + const size_t y_row_long = size_t(y_row); + const size_t y_col_long = size_t(y_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); + const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); + + // Calculate the final tiles in the case that K is not aligned + const int k_remain = K - K_it * BK; + const short2 tile_w = + transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + // Move x and output to the correct block + auto wl = (const device uint8_t*)w; + x += y_row_long * K; + y += y_row_long * N + y_col_long; + wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; + scales += transpose ? y_col_long * K_g : y_col / group_size; + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + const short tm = SM * (simd_group_id / WN); + const short tn = SN * (simd_group_id % WN); + + const short sgp_sm = + align_M ? SM : min(SM, short(max(0, (M - (y_row + tm))))); + const short sgp_sn = + align_N ? SN : min(SN, short(max(0, (N - (y_col + tn))))); + + const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); + const bool is_unaligned_bn = align_N ? false : (tgp_bn != BN); + + constexpr short BR = transpose ? TN : TK; + constexpr short BC = transpose ? TK : TN; + + using AccumType = float; + + using ASubTile = NAXSubTile; + using BSubTile = NAXSubTile; + using DSubTile = NAXSubTile; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = indices[y_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (indices[y_row + n] != index) { + offset_next = n; + index_next = indices[y_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + NAXTile Dtile; + + Dtile.clear(); + + const device T* xn = x + tm * K; + + // Prepare threadgroup loading operations + thread loader_w_t loader_w( + wl + index * stride_w, + scales + index * stride_s, + transpose ? K : N, + Ws, + simd_group_id, + simd_lane_id); + + dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(align_N || !is_unaligned_bn, [&](auto kAlignedN) { + for (int k = 0; k < K_it; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if constexpr (kAlignedN.value) { + loader_w.load_unsafe(); + } else { + loader_w.load_safe( + transpose ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + if constexpr (kAlignedM.value) { + Atile.load(xn + kk1, K); + } else { + Atile.load_safe(xn + kk1, K, short2(SK, sgp_sm)); + } + + if constexpr (transpose) { + Btile.template load( + Ws + tn * BK_padded + kk1); + } else { + Btile.template load( + Ws + tn + kk1 * BN_padded); + } + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + xn += BK; + loader_w.next(); + } + + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_w.load_safe(tile_w); + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + const short psk = min(int(SK), max(0, (BK - kk1))); + Atile.load_safe(xn + kk1, K, short2(psk, sgp_sm)); + + if constexpr (transpose) { + Btile.template load( + Ws + tn * BK_padded + kk1); + } else { + Btile.template load( + Ws + tn + kk1 * BN_padded); + } + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + const short m_lo_lim = min(int(sgp_sm), max(0, offset - tm)); + const short m_hi_lim = min(int(sgp_sm), max(0, offset_next - tm)); + + // Store results to device memory + if constexpr (kAlignedN.value) { + if (m_lo_lim == 0 && m_hi_lim == SM) { + Dtile.store(y + tm * N + tn, N); + } else { + Dtile.store_slice( + y + tm * N + tn, N, short2(0, m_lo_lim), short2(SN, m_hi_lim)); + } + } else { + Dtile.store_slice( + y + tm * N + tn, + N, + short2(0, m_lo_lim), + short2(sgp_sn, m_hi_lim)); + } + }); + }); + } +} diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/gemv_masked.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/gemv_masked.h new file mode 100644 index 00000000..407d14bb --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/gemv_masked.h @@ -0,0 +1,827 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/backend/metal/kernels/steel/utils.h" + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const +#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") + +struct _NoMask { + char x; + + constexpr METAL_FUNC operator bool() { + return true; + } + constexpr METAL_FUNC operator bool() const threadgroup { + return true; + } + constexpr METAL_FUNC operator bool() const device { + return true; + } + constexpr METAL_FUNC operator bool() const constant { + return true; + } +}; + +typedef struct _NoMask nomask_t; + +template +struct ScaleOp { + OutT scale; + + METAL_FUNC OutT apply(InT x) const { + return static_cast(x) * scale; + } +}; + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + typename AccT = float> +struct GEMVKernel { + MLX_MTL_CONST int threadsM = BM * SM; + MLX_MTL_CONST int threadsN = BN * SN; + + MLX_MTL_CONST int blockM = threadsM * TM; + MLX_MTL_CONST int blockN = threadsN * TN; + + static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); + + static_assert( + SN == 8 || SN == 16 || SN == 32, + "gemv block must have a width of 8, 16, or 32"); + + static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM"); + + MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v; + MLX_MTL_CONST bool has_output_mask = !metal::is_same_v; + + MLX_MTL_CONST bool has_mul_operand_mask = + has_operand_mask && !metal::is_same_v; + MLX_MTL_CONST bool has_mul_output_mask = + has_output_mask && !metal::is_same_v; + + // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up + // into blocks of (blockM, blockN) divided among threadgroups + // - Every thread works on a block of (TM, TN) + // - We assume each threadgroup has (threadsN, threadsM, 1) threads + // + // 1. A thread loads TN elements each from mat along TM rows + // and the corresponding scalar from the vector + // 2. The thread then multiplies and adds to accumulate its local result for + // the block + // 3. At the end, each thread has accumulated results over all blocks across + // the rows. These are then summed up across the threadgroup + // 4. Each threadgroup writes its accumulated blockM outputs + // + // Edge case handling: + // - The threadgroup with the largest tid has blocks that exceed the matrix + // * The blocks that start outside the matrix are never read (thread results + // remain zero) + // * The last thread that partially overlaps with the matrix is shifted + // inwards such that the thread block fits exactly in the matrix + + MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0; + MLX_MTL_CONST bool needs_tgp_reduction = BN > 1; + + template + static METAL_FUNC void + load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = static_cast(src[src_offset + tn]); + } + } + + template + static METAL_FUNC void load_safe( + const device T* src, + thread U dst[TN], + const int src_offset = 0, + const int src_size = TN) { + if (src_offset + TN <= src_size) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = static_cast(src[src_offset + tn]); + } + } else { // Edgecase + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = src_offset + tn < src_size + ? static_cast(src[src_offset + tn]) + : U(0); + } + } + } + + static METAL_FUNC void run( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& matrix_ld [[buffer(6)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + threadgroup AccT* tgp_memory [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + // Appease compiler + (void)lid; + + // Thread local accumulation results + thread AccT result[TM] = {0}; + thread T inter[TN]; + thread AccT v_coeff[TN]; + + const int thrM = SN != 32 ? simd_lid / SN : 0; + const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); + + const int sgN = BN != 1 ? (simd_gid % BN) : 0; + + const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid); + const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0; + + int bm = (simdM + thrM) * TM; + int bn = (simdN + thrN) * TN; + + // Block position + int out_row = tid.x * blockM + bm; + + // Exit simdgroup if rows out of bound + if (out_row >= out_vec_size) + return; + + // Adjust tail simdgroup to ensure in bound reads + out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM; + + // Prepare mask offsets + const constant int* out_mask_strides = mask_strides; + const constant int* mat_mask_strides = + mask_strides + (has_output_mask ? 2 : 0); + const constant int* vec_mask_strides = + mat_mask_strides + (has_operand_mask ? 2 : 0); + + const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x); + + const int out_mask_offset = + !has_output_mask ? 0 : m_block_idx * out_mask_strides[1]; + + int mat_mask_offset = + !has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1]; + int vec_mask_offset = 0; + const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0]; + const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[1]; + + T out_scale{1}; + + // Check output mask + if (has_output_mask) { + auto mask_out = out_mask[out_mask_offset]; + + // Write zeros and return if mask is 0 + if (!mask_out) { + if (simdN == 0 && thrN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + out_vec[out_row + tm] = T(0.); + } + } + + return; + } + + // Store scalar if multiplicative mask + if (has_mul_output_mask) { + out_scale = T(mask_out); + } + } + + // Advance matrix + mat += out_row * matrix_ld; + + // Prepare for loop + constexpr const uniform loop_stride = make_uniform(blockN); + const uniform in_size = make_uniform(in_vec_size); + const uniform n_iter = in_size / loop_stride; + const uniform last_iter = loop_stride * n_iter; + const uniform leftover = in_size - last_iter; + + // Loop over in_vec in blocks of blockN + for (int i = 0; i < n_iter; ++i) { + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + + load_unsafe(in_vec, v_coeff, bn); + + // Apply scale + if (has_mul_operand_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + v_coeff[tn] *= block_scale; + } + } + + // Per thread work loop + int mat_offset = 0; + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + // Load for the row + load_unsafe(mat, inter, mat_offset + bn); + + // Accumulate results + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } + + mat_offset += matrix_ld; + } + } + + bn += blockN; + mat_mask_offset += mat_mask_step; + vec_mask_offset += vec_mask_step; + } + + if (leftover > 0) { + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + + load_safe(in_vec, v_coeff, bn, in_size); + + // Apply scale + if (has_mul_operand_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + v_coeff[tn] *= block_scale; + } + } + + // Per thread work loop + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + // Load for the row + load_safe(&mat[tm * matrix_ld], inter, bn, in_size); + + // Accumulate results + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } + } + } + } + + // Apply out scale + if (has_mul_output_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + result[tm] *= out_scale; + } + } + + // Simdgroup accumulations + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + MLX_MTL_PRAGMA_UNROLL + for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) { + result[tm] += simd_shuffle_down(result[tm], sn); + } + } + + // Threadgroup accumulation results + if (needs_tgp_reduction) { + threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm; + if (thrN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + tgp_results[tm] = result[tm]; + } + + threadgroup_barrier(mem_flags::mem_none); + + if (sgN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int sgn = 1; sgn < BN; sgn++) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + result[tm] += tgp_results[sgn * (blockM + TM) + tm]; + } + } + } + } + } + + // Write outputs + if (simdN == 0 && thrN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + out_vec[out_row + tm] = static_cast(result[tm]); + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +/// Vector matrix multiplication +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + typename AccT = float> +struct GEMVTKernel { + MLX_MTL_CONST int threadsM = BM * SM; + MLX_MTL_CONST int threadsN = BN * SN; + + MLX_MTL_CONST int blockM = threadsM * TM; + MLX_MTL_CONST int blockN = threadsN * TN; + + static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); + + MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v; + MLX_MTL_CONST bool has_output_mask = !metal::is_same_v; + + MLX_MTL_CONST bool has_mul_operand_mask = + has_operand_mask && !metal::is_same_v; + MLX_MTL_CONST bool has_mul_output_mask = + has_output_mask && !metal::is_same_v; + + // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up + // into blocks of (blockM, blockN) divided among threadgroups + // - Every thread works on a block of (TM, TN) + // - We assume each threadgroup has (threadsN, threadsM, 1) threads + // + // 1. A thread loads TN elements each from mat along TM contiguous rows + // and the corresponding scalar from the vector + // 2. The thread then accumulates its local result for the block + // 3. At the end, each thread has accumulated results over all blocks across + // the rows. These are then summed up across the threadgroup + // 4. Each threadgroup writes its accumulated BN * TN outputs + // + // Edge case handling: + // - The threadgroup with the largest tid has blocks that exceed the matrix + // * The blocks that start outside the matrix are never read (thread results + // remain zero) + // * The last thread that partially overlaps with the matrix is shifted + // inwards such that the thread block fits exactly in the matrix + + MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0; + MLX_MTL_CONST bool needs_tgp_reduction = BM > 1; + + static METAL_FUNC void run( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + threadgroup AccT* tgp_memory [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + // Appease compiler + (void)lid; + + // Thread local accumulation results + AccT result[TN] = {0}; + T inter[TN]; + AccT v_coeff[TM]; + + const int thrM = SN != 32 ? simd_lid / SN : 0; + const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); + + const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid); + const int sgN = BN != 1 ? (simd_gid % BN) : 0; + + const int simdM = SM * sgM; + const int simdN = SN * sgN; + + int cm = (simdM + thrM); + int cn = (simdN + thrN); + + int bm = cm * TM; + int bn = cn * TN; + + int out_col = tid.x * blockN + bn; + + // Prepare mask offsets + const constant int* out_mask_strides = mask_strides; + const constant int* mat_mask_strides = + out_mask_strides + (has_output_mask ? 2 : 0); + const constant int* vec_mask_strides = + mat_mask_strides + (has_operand_mask ? 2 : 0); + + const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x); + + const int out_mask_offset = + !has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0]; + + int mat_mask_offset = + !has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0]; + int vec_mask_offset = 0; + const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1]; + const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0]; + + T out_scale{1}; + + // Check output mask + if (has_output_mask) { + auto mask_out = out_mask[out_mask_offset]; + + // Write zeros and return if mask is 0 + if (!mask_out) { + if (cm == 0 && out_col < out_vec_size) { + if (out_col + TN <= out_vec_size) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + out_vec[out_col + tn] = T(0.); + } + } else { + for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) { + out_vec[out_col + tn] = T(0.); + } + } + } + + return; + } + + // Store scalar if multiplicative mask + if (has_mul_output_mask) { + out_scale = T(mask_out); + } + } + + // Prepare for loop + constexpr const uniform loop_stride = make_uniform(blockM); + const uniform in_size = make_uniform(in_vec_size); + const uniform n_iter = in_size / loop_stride; + const uniform last_iter = loop_stride * n_iter; + const uniform leftover = in_size - last_iter; + + // Edgecase handling + if (out_col < out_vec_size) { + out_col = (out_col + TN) <= out_vec_size ? out_col : out_vec_size - TN; + + // Per thread accumulation main loop + for (int i = 0; i < n_iter; ++i) { + // Adding a threadgroup_barrier improves performance slightly + // This is possibly it may help exploit cache better + threadgroup_barrier(mem_flags::mem_none); + + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + v_coeff[tm] = static_cast(in_vec[bm + tm]); + } + + // Apply scale + if (has_mul_operand_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + v_coeff[tm] *= block_scale; + } + } + + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + for (int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; + } + for (int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } + } + } + + bm += blockM; + mat_mask_offset += mat_mask_step; + vec_mask_offset += vec_mask_step; + } + + if (leftover > 0) { + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + + for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) { + v_coeff[tm] = static_cast(in_vec[bm + tm]); + + if (has_mul_operand_mask) { + v_coeff[tm] *= block_scale; + } + + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; + } + + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } + } + } + } + } + + // Apply out scale + if (has_mul_output_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] *= out_scale; + } + } + + // Simdgroup accumulations + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + MLX_MTL_PRAGMA_UNROLL + for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) { + result[tn] += simd_shuffle_down(result[tn], SN * sm); + } + } + + // Threadgroup accumulation results + if (needs_tgp_reduction) { + threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn; + if (thrM == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + tgp_results[tn] = result[tn]; + } + + threadgroup_barrier(mem_flags::mem_none); + + if (sgM == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int sgm = 1; sgm < BM; sgm++) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] += tgp_results[sgm * (blockN + TN) + tn]; + } + } + } + } + } + + // Threadgroup accumulation and writing out results + if (cm == 0 && out_col < out_vec_size) { + MLX_MTL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + out_vec[out_col + j] = static_cast(result[j]); + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +/// Matrix vector multiplication +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoNCBatch> /* Batch ndim > 1 */ +[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_masked( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant int64_t* vector_batch_stride [[buffer(11)]], + const constant int64_t* matrix_batch_stride [[buffer(12)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + const constant int64_t* mask_batch_strides [[buffer(24)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using gemv_kernel = + GEMVKernel; + threadgroup float tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; + + constexpr bool has_operand_mask = !metal::is_same_v; + constexpr bool has_output_mask = !metal::is_same_v; + + // Update batch offsets + if (kDoNCBatch) { + in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); + mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); + + if (has_output_mask) { + out_mask += + elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim); + mask_batch_strides += batch_ndim; + } + + if (has_operand_mask) { + const constant auto* mask_strides_mat = mask_batch_strides; + const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim); + + mat_mask += batch_offsets.x; + vec_mask += batch_offsets.y; + } + + } else { + in_vec += tid.z * vector_batch_stride[0]; + mat += tid.z * matrix_batch_stride[0]; + + if (has_output_mask) { + out_mask += tid.z * mask_batch_strides[0]; + mask_batch_strides += batch_ndim; + } + + if (has_operand_mask) { + mat_mask += tid.z * mask_batch_strides[0]; + vec_mask += tid.z * mask_batch_strides[batch_ndim]; + } + } + + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + out_mask, + mat_mask, + vec_mask, + mask_strides, + gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} + +/////////////////////////////////////////////////////////////////////////////// +/// Vector matrix multiplication +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoNCBatch> /* Batch ndim > 1 */ +[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t_masked( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant int64_t* vector_batch_stride [[buffer(11)]], + const constant int64_t* matrix_batch_stride [[buffer(12)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + const constant int64_t* mask_batch_strides [[buffer(24)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using gemv_kernel = + GEMVTKernel; + threadgroup float tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; + + constexpr bool has_operand_mask = !metal::is_same_v; + constexpr bool has_output_mask = !metal::is_same_v; + + // Update batch offsets + if (kDoNCBatch) { + in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); + mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); + + if (has_output_mask) { + out_mask += + elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim); + mask_batch_strides += batch_ndim; + } + + if (has_operand_mask) { + const constant auto* mask_strides_mat = mask_batch_strides; + const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim); + + mat_mask += batch_offsets.x; + vec_mask += batch_offsets.y; + } + + } else { + in_vec += tid.z * vector_batch_stride[0]; + mat += tid.z * matrix_batch_stride[0]; + + if (has_output_mask) { + out_mask += tid.z * mask_batch_strides[0]; + mask_batch_strides += batch_ndim; + } + + if (has_operand_mask) { + mat_mask += tid.z * mask_batch_strides[0]; + vec_mask += tid.z * mask_batch_strides[batch_ndim]; + } + } + + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + out_mask, + mat_mask, + vec_mask, + mask_strides, + gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/hadamard.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/hadamard.h new file mode 100644 index 00000000..9f2311c1 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/hadamard.h @@ -0,0 +1,182 @@ +// Copyright © 2024 Apple Inc. +#include +#include + +#include "mlx/backend/metal/kernels/steel/defines.h" + +using namespace metal; + +// Thread local Hadamard transform for 2^R +template +METAL_FUNC void radix_func(thread float* x) { + constexpr short logR = __builtin_ctz(R); + short h = 1; + STEEL_PRAGMA_UNROLL + for (short s = 0; s < logR; s++) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < R / 2; i++) { + short k = i & (h - 1); + short j = ((i - k) << 1) + k; + float a = x[j]; + float b = x[j + h]; + x[j] = a + b; + x[j + h] = a - b; + } + h <<= 1; + } +} + +template +[[kernel]] void hadamard_n( + const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], + constant const float& scale, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + // Compute a Hadamard transform of size N = 2^k + // + // Equivalent to: + // from scipy.linalg import hadamard + // y = hadamard(len(x)) @ x + + constexpr short num_threads = N / max_radix; + constexpr short logN = __builtin_ctz(N); + constexpr short logR = __builtin_ctz(max_radix); + constexpr short num_steps = logN / logR; + constexpr short logFinal = logN % logR; + constexpr short final_radix = 1 << (logFinal); + + int batch_idx = elem.y * N * stride + elem.z; + short i = elem.x; + + threadgroup T buf[N]; + + // Read values from device + if (stride == 1) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + buf[index + r] = in[batch_idx + index + r]; + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix; j++) { + buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride]; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + float x[max_radix]; + short h = 1; + + STEEL_PRAGMA_UNROLL + for (short s = 0; s < num_steps; s++) { + short k = i & (h - 1); + short j = ((i - k) << logR) + k; + + STEEL_PRAGMA_UNROLL + for (short r = 0; r < max_radix; r++) { + x[r] = buf[j + h * r]; + } + + radix_func(x); + + STEEL_PRAGMA_UNROLL + for (short r = 0; r < max_radix; r++) { + buf[j + h * r] = T(x[r]); + } + + h <<= logR; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Do the final radix + // e.g. max_radix = 16 + // N = 1024 = 16 * 16 * 4 + if (final_radix > 1) { + // Each thread does multiple butterflies + STEEL_PRAGMA_UNROLL + for (int t = 0; t < max_radix / final_radix; t++) { + short index = i + t * num_threads; + short k = index & (h - 1); + short j = ((index - k) << logFinal) + k; + STEEL_PRAGMA_UNROLL + for (short r = 0; r < final_radix; r++) { + x[r] = buf[j + h * r]; + } + + radix_func(x); + + STEEL_PRAGMA_UNROLL + for (short r = 0; r < final_radix; r++) { + buf[j + h * r] = T(x[r]); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Write values to device + if (stride == 1) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + out[batch_idx + index + r] = T(buf[index + r] * scale); + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix; j++) { + out[batch_idx + (j * num_threads + i) * stride] = + buf[j * num_threads + i]; + } + } +} + +template +[[kernel]] void hadamard_m( + const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], + constant const float& scale, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + // Compute a Hadamard transform of size M + // using a naive O(M^2) codelet. + // + // This kernel is the second stage in the computation + // of a Hadamard transform of size M*N where N = 2^k. + + int index = elem.x * grid.y + elem.y; + short i = index % (N / read_width); + int batch_idx = index / (N / read_width) * M * N; + + float x[read_width][M]; + STEEL_PRAGMA_UNROLL + for (short c = 0; c < M; c++) { + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + x[r][c] = in[batch_idx + c * N + i * read_width + r]; + } + } + + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + // This function is JIT compiled for M + // using the Hadamard matrix strings in `metal/hadamard.cpp` + hadamard_radix_m(x[r]); + } + + // Write back to device + STEEL_PRAGMA_UNROLL + for (short c = 0; c < M; c++) { + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + out[batch_idx + c * N + i * read_width + r] = T(x[r][c] * scale); + } + } +} diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/indexing/gather.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/indexing/gather.h new file mode 100644 index 00000000..8b93c016 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/indexing/gather.h @@ -0,0 +1,51 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/indexing/indexing.h" + +template +METAL_FUNC void gather_impl( + const device T* src [[buffer(0)]], + device T* out [[buffer(1)]], + const constant int* src_shape [[buffer(2)]], + const constant int64_t* src_strides [[buffer(3)]], + const constant size_t& src_ndim [[buffer(4)]], + const constant int* slice_sizes [[buffer(5)]], + const constant int* axes [[buffer(6)]], + const thread Indices& indices, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + LocT src_idx = 0; + for (int i = 0; i < NIDX; ++i) { + LocT idx_loc; + if (IDX_NDIM == 0) { + idx_loc = 0; + } else if (IDX_NDIM == 1) { + idx_loc = index.x * static_cast(indices.strides[indices.ndim * i]); + } else { + idx_loc = index.x * static_cast(indices.strides[indices.ndim * i]); + idx_loc += indices.row_contiguous[i] + ? index.y + : elem_to_loc( + index.y, + &indices.shapes[indices.ndim * i + 1], + &indices.strides[indices.ndim * i + 1], + indices.ndim - 1); + } + auto ax = axes[i]; + auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]); + src_idx += static_cast(idx_val) * static_cast(src_strides[ax]); + } + + auto src_offset = + elem_to_loc(index.z, slice_sizes, src_strides, src_ndim); + + LocT out_idx = index.z; + if (IDX_NDIM == 1) { + out_idx += static_cast(grid_dim.z) * index.x; + } else if (IDX_NDIM >= 2) { + out_idx += grid_dim.z * (index.x * static_cast(grid_dim.y) + index.y); + } + out[out_idx] = src[src_offset + src_idx]; +} diff --git a/Source/Cmlx/mlx-generated/metal/indexing/gather_axis.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/indexing/gather_axis.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/indexing/gather_front.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/indexing/gather_front.h new file mode 100644 index 00000000..1389e4c6 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/indexing/gather_front.h @@ -0,0 +1,24 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/indexing/indexing.h" + +template +[[kernel]] void gather_front( + const device T* src, + const device IdxT* indices, + device T* out, + const constant int64_t& stride, + const constant int& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto idx = offset_neg_idx(indices[index.y], size); + LocT src_idx = static_cast(stride) * idx; + LocT out_idx = static_cast(stride) * index.y; + + int s_idx = N * index.x; + for (int i = 0; i < N && s_idx < stride; ++i, ++s_idx) { + out[out_idx + s_idx] = src[src_idx + s_idx]; + } +} diff --git a/Source/Cmlx/mlx-generated/metal/indexing/indexing.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/indexing/indexing.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/indexing/indexing.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/indexing/indexing.h diff --git a/Source/Cmlx/mlx-generated/metal/indexing/masked_scatter.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/indexing/masked_scatter.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/indexing/scatter.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/indexing/scatter.h new file mode 100644 index 00000000..f0217b33 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/indexing/scatter.h @@ -0,0 +1,59 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/indexing/indexing.h" + +template < + typename T, + typename IdxT, + typename Op, + int NIDX, + bool UPD_ROW_CONTIG, + int NWORK, + typename LocT> +METAL_FUNC void scatter_impl( + const device T* updates, + device mlx_atomic* out, + const constant int* upd_shape, + const constant int64_t* upd_strides, + const constant size_t& upd_ndim, + const constant size_t& upd_size, + const constant int* out_shape, + const constant int64_t* out_strides, + const constant size_t& out_ndim, + const constant int* axes, + const constant size_t& idx_size, + const thread Indices& indices, + uint2 gid [[thread_position_in_grid]]) { + Op op; + + auto ind_idx = gid.y * NWORK; + LocT out_offset = 0; + if (upd_size > 1) { + out_offset = elem_to_loc( + gid.x, upd_shape + indices.ndim, out_strides, out_ndim); + } + + for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) { + LocT out_idx = out_offset; + for (int i = 0; i < NIDX; ++i) { + auto idx_loc = indices.row_contiguous[i] + ? ind_idx + : elem_to_loc( + ind_idx, + &indices.shapes[indices.ndim * i], + &indices.strides[indices.ndim * i], + indices.ndim); + auto ax = axes[i]; + auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]); + out_idx += + static_cast(idx_val) * static_cast(out_strides[ax]); + } + auto upd_idx = ind_idx * static_cast(upd_size) + gid.x; + if constexpr (!UPD_ROW_CONTIG) { + upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim); + } + op.atomic_update(out, updates[upd_idx], out_idx); + } +} diff --git a/Source/Cmlx/mlx-generated/metal/indexing/scatter_axis.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/indexing/scatter_axis.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h diff --git a/Source/Cmlx/mlx-generated/metal/logging.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/logging.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/logging.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/logging.h diff --git a/Source/Cmlx/mlx-generated/metal/logsumexp.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/logsumexp.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/logsumexp.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/logsumexp.h diff --git a/Source/Cmlx/mlx-generated/metal/quantized.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/quantized.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/quantized.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/quantized.h diff --git a/Source/Cmlx/mlx-generated/metal/quantized_nax.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/quantized_nax.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/quantized_nax.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/quantized_nax.h diff --git a/Source/Cmlx/mlx-generated/metal/quantized_utils.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/quantized_utils.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/quantized_utils.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/quantized_utils.h diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/reduce.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/reduce.h new file mode 100644 index 00000000..ee5c3d5c --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/reduce.h @@ -0,0 +1,5 @@ +#pragma once +#include "mlx/backend/metal/kernels/reduction/reduce_all.h" +#include "mlx/backend/metal/kernels/reduction/reduce_col.h" +#include "mlx/backend/metal/kernels/reduction/reduce_init.h" +#include "mlx/backend/metal/kernels/reduction/reduce_row.h" diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/reduce_utils.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/reduce_utils.h new file mode 100644 index 00000000..279a7afe --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/reduce_utils.h @@ -0,0 +1,6 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/atomic.h" +#include "mlx/backend/metal/kernels/reduction/ops.h" diff --git a/Source/Cmlx/mlx-generated/metal/reduction/ops.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/reduction/ops.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/reduction/ops.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/reduction/ops.h diff --git a/Source/Cmlx/mlx-generated/metal/reduction/reduce_all.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/reduction/reduce_all.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h diff --git a/Source/Cmlx/mlx-generated/metal/reduction/reduce_col.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/reduction/reduce_col.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h diff --git a/Source/Cmlx/mlx-generated/metal/reduction/reduce_init.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/reduction/reduce_init.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h diff --git a/Source/Cmlx/mlx-generated/metal/reduction/reduce_row.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/reduction/reduce_row.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/scan.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/scan.h new file mode 100644 index 00000000..16682613 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/scan.h @@ -0,0 +1,514 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/binary_ops.h" + +#define DEFINE_SIMD_SCAN() \ + template = true> \ + T simd_scan(T val) { \ + return simd_scan_impl(val); \ + } \ + \ + template = true> \ + T simd_scan(T val) { \ + for (int i = 1; i <= 16; i *= 2) { \ + val = operator()(val, simd_shuffle_and_fill_up(val, init, i)); \ + } \ + return val; \ + } + +#define DEFINE_SIMD_EXCLUSIVE_SCAN() \ + template = true> \ + T simd_exclusive_scan(T val) { \ + return simd_exclusive_scan_impl(val); \ + } \ + \ + template = true> \ + T simd_exclusive_scan(T val) { \ + val = simd_scan(val); \ + return simd_shuffle_and_fill_up(val, init, 1); \ + } + +template +struct CumSum { + DEFINE_SIMD_SCAN() + DEFINE_SIMD_EXCLUSIVE_SCAN() + + static constexpr constant U init = static_cast(0); + + template + U operator()(U a, T b) { + return a + b; + } + + U simd_scan_impl(U x) { + return simd_prefix_inclusive_sum(x); + } + + U simd_exclusive_scan_impl(U x) { + return simd_prefix_exclusive_sum(x); + } +}; + +template +struct CumProd { + DEFINE_SIMD_SCAN() + DEFINE_SIMD_EXCLUSIVE_SCAN() + + static constexpr constant U init = static_cast(1.0f); + + template + U operator()(U a, T b) { + return a * b; + } + + U simd_scan_impl(U x) { + return simd_prefix_inclusive_product(x); + } + + U simd_exclusive_scan_impl(U x) { + return simd_prefix_exclusive_product(x); + } +}; + +template <> +struct CumProd { + static constexpr constant bool init = true; + + template + bool operator()(bool a, T b) { + return a & static_cast(b); + } + + bool simd_scan(bool x) { + for (int i = 1; i <= 16; i *= 2) { + bool other = simd_shuffle_and_fill_up(x, init, i); + x &= other; + } + return x; + } + + bool simd_exclusive_scan(bool x) { + x = simd_scan(x); + return simd_shuffle_and_fill_up(x, init, 1); + } +}; + +template +struct CumMax { + static constexpr constant U init = Limits::min; + + template + U operator()(U a, T b) { + return (a >= b) ? a : b; + } + + U simd_scan(U x) { + for (int i = 1; i <= 16; i *= 2) { + U other = simd_shuffle_and_fill_up(x, init, i); + x = (x >= other) ? x : other; + } + return x; + } + + U simd_exclusive_scan(U x) { + x = simd_scan(x); + return simd_shuffle_and_fill_up(x, init, 1); + } +}; + +template +struct CumMin { + static constexpr constant U init = Limits::max; + + template + U operator()(U a, T b) { + return (a <= b) ? a : b; + } + + U simd_scan(U x) { + for (int i = 1; i <= 16; i *= 2) { + U other = simd_shuffle_and_fill_up(x, init, i); + x = (x <= other) ? x : other; + } + return x; + } + + U simd_exclusive_scan(U x) { + x = simd_scan(x); + return simd_shuffle_and_fill_up(x, init, 1); + } +}; + +template +struct CumLogaddexp { + static constexpr constant U init = Limits::min; + + template + U operator()(U a, T b) { + return LogAddExp{}(a, static_cast(b)); + } + + U simd_scan(U x) { + for (int i = 1; i <= 16; i *= 2) { + U other = simd_shuffle_and_fill_up(x, init, i); + x = LogAddExp{}(x, other); + } + return x; + } + + U simd_exclusive_scan(U x) { + x = simd_scan(x); + return simd_shuffle_and_fill_up(x, init, 1); + } +}; + +template +inline void load_unsafe(U values[N_READS], const device T* input) { + if (reverse) { + for (int i = 0; i < N_READS; i++) { + values[N_READS - i - 1] = input[i]; + } + } else { + for (int i = 0; i < N_READS; i++) { + values[i] = input[i]; + } + } +} + +template +inline void load_safe( + U values[N_READS], + const device T* input, + int start, + int total, + U init) { + if (reverse) { + for (int i = 0; i < N_READS; i++) { + values[N_READS - i - 1] = + (start + N_READS - i - 1 < total) ? input[i] : init; + } + } else { + for (int i = 0; i < N_READS; i++) { + values[i] = (start + i < total) ? input[i] : init; + } + } +} + +template +inline void write_unsafe(U values[N_READS], device U* out) { + if (reverse) { + for (int i = 0; i < N_READS; i++) { + out[i] = values[N_READS - i - 1]; + } + } else { + for (int i = 0; i < N_READS; i++) { + out[i] = values[i]; + } + } +} + +template +inline void write_safe(U values[N_READS], device U* out, int start, int total) { + if (reverse) { + for (int i = 0; i < N_READS; i++) { + if (start + N_READS - i - 1 < total) { + out[i] = values[N_READS - i - 1]; + } + } + } else { + for (int i = 0; i < N_READS; i++) { + if (start + i < total) { + out[i] = values[i]; + } + } + } +} + +template < + typename T, + typename U, + typename Op, + int N_READS, + bool inclusive, + bool reverse> +[[kernel]] void contiguous_scan( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& axis_size [[buffer(2)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int simd_size = 32; + Op op; + + // Position the pointers + size_t offset = (gid.y + gsize.y * size_t(gid.z)) * axis_size; + in += offset; + out += offset; + + // Compute the number of simd_groups + uint simd_groups = lsize.x / simd_size; + + // Allocate memory + U prefix = Op::init; + U values[N_READS]; + threadgroup U simdgroup_sums[32]; + + // Loop over the reduced axis in blocks of size ceildiv(axis_size, + // N_READS*lsize) + // Read block + // Compute inclusive scan of the block + // Compute inclusive scan per thread + // Compute exclusive scan of thread sums in simdgroup + // Write simdgroup sums in SM + // Compute exclusive scan of simdgroup sums + // Compute the output by scanning prefix, prev_simdgroup, prev_thread, + // value + // Write block + + for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) { + // Compute the block offset + uint offset = r * lsize.x * N_READS + lid.x * N_READS; + + // Read the values + if (reverse) { + if ((offset + N_READS) < axis_size) { + load_unsafe( + values, in + axis_size - offset - N_READS); + } else { + load_safe( + values, + in + axis_size - offset - N_READS, + offset, + axis_size, + Op::init); + } + } else { + if ((offset + N_READS) < axis_size) { + load_unsafe(values, in + offset); + } else { + load_safe( + values, in + offset, offset, axis_size, Op::init); + } + } + + // Compute an inclusive scan per thread + for (int i = 1; i < N_READS; i++) { + values[i] = op(values[i], values[i - 1]); + } + + // Compute exclusive scan of thread sums + U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]); + + // Write simdgroup_sums to SM + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_lane_id == simd_size - 1) { + simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Compute exclusive scan of simdgroup_sums + if (simd_group_id == 0) { + U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]); + simdgroup_sums[simd_lane_id] = prev_simdgroup; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Compute the output + for (int i = 0; i < N_READS; i++) { + values[i] = op(values[i], prefix); + values[i] = op(values[i], simdgroup_sums[simd_group_id]); + values[i] = op(values[i], prev_thread); + } + + // Write the values + if (reverse) { + if (inclusive) { + if ((offset + N_READS) < axis_size) { + write_unsafe( + values, out + axis_size - offset - N_READS); + } else { + write_safe( + values, out + axis_size - offset - N_READS, offset, axis_size); + } + } else { + if (lid.x == 0 && offset == 0) { + out[axis_size - 1] = Op::init; + } + if ((offset + N_READS + 1) < axis_size) { + write_unsafe( + values, out + axis_size - offset - 1 - N_READS); + } else { + write_safe( + values, + out + axis_size - offset - 1 - N_READS, + offset + 1, + axis_size); + } + } + } else { + if (inclusive) { + if ((offset + N_READS) < axis_size) { + write_unsafe(values, out + offset); + } else { + write_safe( + values, out + offset, offset, axis_size); + } + } else { + if (lid.x == 0 && offset == 0) { + out[0] = Op::init; + } + if ((offset + N_READS + 1) < axis_size) { + write_unsafe(values, out + offset + 1); + } else { + write_safe( + values, out + offset + 1, offset + 1, axis_size); + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Share the prefix + if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) { + simdgroup_sums[0] = values[N_READS - 1]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + prefix = simdgroup_sums[0]; + } +} + +template < + typename T, + typename U, + typename Op, + int N_READS, + bool inclusive, + bool reverse> +[[kernel]] void strided_scan( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& axis_size [[buffer(2)]], + const constant size_t& stride [[buffer(3)]], + const constant size_t& stride_blocks [[buffer(4)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int simd_size = 32; + constexpr int BM = 32; + constexpr int BN = 32; + constexpr int BN_pad = 32 + 16 / sizeof(U); + constexpr int n_simds = BN / N_READS; + constexpr int n_scans = BN / n_simds; + Op op; + + threadgroup U read_buffer[BM * BN_pad]; + U values[n_scans]; + U prefix[n_scans]; + for (int i = 0; i < n_scans; i++) { + prefix[i] = Op::init; + } + + // Compute offsets + size_t full_gid = gid.y + gsize.y * size_t(gid.z); + size_t offset = full_gid / stride_blocks * axis_size * stride; + size_t global_index_x = full_gid % stride_blocks * BN; + uint read_offset_y = (lid.x * N_READS) / BN; + uint read_offset_x = (lid.x * N_READS) % BN; + uint scan_offset_y = simd_lane_id; + uint scan_offset_x = simd_group_id * n_scans; + + uint stride_limit = stride - global_index_x; + in += offset + global_index_x + read_offset_x; + out += offset + global_index_x + read_offset_x; + threadgroup U* read_into = + read_buffer + read_offset_y * BN_pad + read_offset_x; + threadgroup U* read_from = + read_buffer + scan_offset_y * BN_pad + scan_offset_x; + + for (uint j = 0; j < axis_size; j += BM) { + // Calculate the indices for the current thread + uint index_y = j + read_offset_y; + uint check_index_y = index_y; + if (reverse) { + index_y = axis_size - 1 - index_y; + } + + // Read in SM + threadgroup_barrier(mem_flags::mem_threadgroup); + if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) { + for (int i = 0; i < N_READS; i++) { + read_into[i] = in[index_y * stride + i]; + } + } else { + for (int i = 0; i < N_READS; i++) { + if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) { + read_into[i] = in[index_y * stride + i]; + } else { + read_into[i] = Op::init; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Read strided into registers + for (int i = 0; i < n_scans; i++) { + values[i] = read_from[i]; + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + // Perform the scan + for (int i = 0; i < n_scans; i++) { + values[i] = op.simd_scan(values[i]); + values[i] = op(values[i], prefix[i]); + prefix[i] = simd_shuffle(values[i], simd_size - 1); + } + + // Write to SM + for (int i = 0; i < n_scans; i++) { + read_from[i] = values[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write to device memory + if (!inclusive) { + if (check_index_y == 0) { + if ((read_offset_x + N_READS) < stride_limit) { + for (int i = 0; i < N_READS; i++) { + out[index_y * stride + i] = Op::init; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((read_offset_x + i) < stride_limit) { + out[index_y * stride + i] = Op::init; + } + } + } + } + if (reverse) { + index_y -= 1; + check_index_y += 1; + } else { + index_y += 1; + check_index_y += 1; + } + } + if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) { + for (int i = 0; i < N_READS; i++) { + out[index_y * stride + i] = read_into[i]; + } + } else { + for (int i = 0; i < N_READS; i++) { + if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) { + out[index_y * stride + i] = read_into[i]; + } + } + } + } +} diff --git a/Source/Cmlx/mlx-generated/metal/sdpa_vector.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/sdpa_vector.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/sdpa_vector.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/sdpa_vector.h diff --git a/Source/Cmlx/mlx-generated/metal/softmax.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/softmax.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/softmax.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/softmax.h diff --git a/Source/Cmlx/mlx-generated/metal/sort.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/sort.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/sort.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/sort.h diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/attn.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/attn.h new file mode 100644 index 00000000..991d4d69 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/attn.h @@ -0,0 +1,296 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/attn/loader.h" +#include "mlx/backend/metal/kernels/steel/attn/mma.h" +#include "mlx/backend/metal/kernels/steel/attn/params.h" +#include "mlx/backend/metal/kernels/steel/attn/transforms.h" +#include "mlx/backend/metal/kernels/steel/gemm/params.h" +#include "mlx/backend/metal/kernels/steel/utils.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernel class +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct LoopAlignment {}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct GEMMKernel { + STEEL_CONST short tgp_padding_a = 16 / sizeof(T); + STEEL_CONST short tgp_padding_b = 16 / sizeof(T); + STEEL_CONST short tgp_mem_size_a = + transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); + STEEL_CONST short tgp_mem_size_b = + transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); + STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; + + STEEL_CONST short tgp_size = WM * WN * 32; + + using loader_a_t = BlockLoader< + T, + transpose_a ? BK : BM, + transpose_a ? BM : BK, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + !transpose_a, + tgp_size>; + using loader_b_t = BlockLoader< + T, + transpose_b ? BN : BK, + transpose_b ? BK : BN, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + transpose_b, + tgp_size>; + using mma_t = BlockMMA< + T, + U, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + AccumType, + Epilogue>; + + /* Main kernel function */ + template + static METAL_FUNC void gemm_loop( + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + const int gemm_k_iterations, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + thread mma_t& mma_op, + thread const short& tgp_bm, + thread const short& tgp_bn, + thread const short& lbk, + LoopAlignment l = {}) { + // Appease the compiler + (void)l; + + short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + + short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + if (M_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(tile_dims_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + if (!K_aligned_) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + short2 tile_dims_A_last = + transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); + short2 tile_dims_B_last = + transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk); + + loader_a.load_safe(tile_dims_A_last); + loader_b.load_safe(tile_dims_B_last); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + } + + /* Main kernel function */ + static METAL_FUNC void run( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device U* D [[buffer(2)]], + const constant GEMMParams* params [[buffer(3)]], + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Pacifying compiler + (void)lid; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Loop tail + if (!K_aligned) { + int lbk = params->K - params->gemm_k_iterations_aligned * BK; + short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); + short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + + // Store results to device memory + mma_op.store_result(D, params->ldd); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; + + if (tgp_bm == BM && tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result(D, params->ldd); + return; + + } else if (tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + + } else if (tgp_bm == BM) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + + } else { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + } + } + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h new file mode 100644 index 00000000..49183094 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h @@ -0,0 +1,471 @@ +// Copyright © 2024-25 Apple Inc. + +#include "mlx/backend/metal/kernels/steel/attn/attn.h" + +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +constant bool align_Q [[function_constant(200)]]; +constant bool align_K [[function_constant(201)]]; + +constant bool has_mask [[function_constant(300)]]; +constant bool do_causal [[function_constant(301)]]; +constant bool has_sinks [[function_constant(302)]]; + +struct MaxOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return metal::max(x, y); + } +}; + +struct SumOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x + y; + } +}; + +struct MulOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x * y; + } +}; + +struct SubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x - y; + } +}; + +struct ExpSubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return fast::exp2(x - y); + } +}; + +struct DivOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x / y; + } +}; + +// clang-format off +template < + typename T, + int BQ, + int BK, + int BD, + int WM, + int WN, + typename MaskType = float, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + device T* O [[buffer(3)]], + const constant AttnParams* params [[buffer(4)]], + const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]], + const device MaskType* mask [[buffer(6), function_constant(has_mask)]], + const device T* sinks [[buffer(7), function_constant(has_sinks)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on + + // Pacifying compiler + (void)lid; + + // Move to correct block + ulong3 tidl{tid.x, tid.y, tid.z}; + + Q += tidl.z * params->Q_strides[0] + // Batch + tidl.y * params->Q_strides[1] + // Head + tidl.x * BQ * params->Q_strides[2]; // Sequence + + ulong kv_head_idx = int(tid.y) / params->gqa_factor; + K += tidl.z * params->K_strides[0] + // Batch + kv_head_idx * params->K_strides[1]; // Head + + V += tidl.z * params->V_strides[0] + // Batch + kv_head_idx * params->V_strides[1]; // Head + + O += tidl.z * params->O_strides[0] + // Batch + tidl.y * params->O_strides[1] + // Head + tidl.x * BQ * params->O_strides[2]; // Sequence + + if (has_mask) { + mask += tidl.z * mask_params->M_strides[0] + // Batch + tidl.y * mask_params->M_strides[1]; // Head + } + + // Prepare threadgroup memory + constexpr short padQ = 16 / sizeof(T); + constexpr short padK = 16 / sizeof(T); + constexpr short padV = 16 / sizeof(T); + + constexpr short LDQ_tgp = BD + padQ; + constexpr short LDK_tgp = BK + padK; + constexpr short LDV_tgp = BD + padV; + + constexpr short tgp_mem_0 = (BK + padK) * (BD); + constexpr short tgp_mem_1 = BK * (BD + padV); + constexpr short tgp_mem_s = tgp_mem_0 > tgp_mem_1 ? tgp_mem_0 : tgp_mem_1; + + threadgroup T Q_smem[BQ * (BD + padQ)]; + threadgroup T KV_smem[tgp_mem_s]; + + threadgroup T* Qs = Q_smem; + threadgroup T* Ks = KV_smem; + threadgroup T* Vs = KV_smem; + + // Prepare block loaders + using QBlockLoader = BlockLoaderT< + /* typename T = */ T, + /* short BROWS = */ BQ, + /* short BCOLS = */ BD, + /* short kDstStrRow = */ LDQ_tgp, + /* short kDstStrCol = */ 1, + /* short reduction_dim = */ 1, + /* short tgp_size = */ WM * WN * 32>; + + // K is loaded in transposed + using KBlockLoader = BlockLoaderT< + /* typename T = */ T, + /* short BROWS = */ BK, + /* short BCOLS = */ BD, + /* short kDstStrRow = */ 1, + /* short kDstStrCol = */ LDK_tgp, + /* short reduction_dim = */ 0, + /* short tgp_size = */ WM * WN * 32>; + + using VBlockLoader = BlockLoaderT< + /* typename T = */ T, + /* short BROWS = */ BK, + /* short BCOLS = */ BD, + /* short kDstStrRow = */ LDV_tgp, + /* short kDstStrCol = */ 1, + /* short reduction_dim = */ 0, + /* short tgp_size = */ WM * WN * 32>; + + QBlockLoader loader_q( + Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id); + KBlockLoader loader_k( + K, params->K_strides[2], Ks, simd_group_id, simd_lane_id); + VBlockLoader loader_v( + V, params->V_strides[2], Vs, simd_group_id, simd_lane_id); + + const AccumType scale = params->scale * M_LOG2E_F; + + // Prepare MMA tiles + constexpr short kFragSize = 8; // MMAFrag size + using MMAFrag_acc_t = BaseMMAFrag; + + constexpr int kNWarps = WM * WN; + static_assert( + BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0, + "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence."); + + // Q seq frags per warp + constexpr int TQ = BQ / (kNWarps * kFragSize); + // KV sequence frags (all warps load the same frags) + constexpr int TK = BK / kFragSize; + // HeadDim frags (all warps load the same frags) + constexpr int TD = BD / kFragSize; + + static_assert(TQ == 1, "Check TQ"); + + MMATile Qtile; + MMATile Ktile; + MMATile Stile; + MMATile Vtile; + MMATile Otile; + + Otile.clear(); + + // Prepare mma tile offsets + const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + const short sm = simd_coord.y; + const short sn = simd_coord.x; + const short tm = kFragSize * TQ * simd_group_id; + + const short Qs_offset = (tm + sm) * LDQ_tgp + sn; + const short Ks_offset = sm * LDK_tgp + sn; + const short Vs_offset = sm * LDV_tgp + sn; + + constexpr short Qs_tile_stride = kFragSize; + constexpr short Ks_tile_stride = kFragSize * LDK_tgp; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load Q blocks + if (!align_Q && int(tid.x) == (params->NQ_aligned)) { + loader_q.load_safe(short2(BD, params->qL_rem)); + } else { + loader_q.load_unsafe(); + } + + // Init row reduction variables + constexpr short kRowsPT = decltype(Stile)::kRowsPerThread; + + AccumType max_score[kRowsPT]; + AccumType sum_score[kRowsPT] = {0}; + + // Init to -Inf + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = Limits::finite_min; + } + + if (has_sinks) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = M_LOG2E_F * static_cast(sinks[tidl.y]); + sum_score[i] = 1; + } + } + + int kb_lim = params->NK; + + if (do_causal) { + int q_max = (tid.x + 1) * BQ + params->qL_off; + kb_lim = (q_max + BK - 1) / BK; + kb_lim = min(params->NK, kb_lim); + } + + // Loop over KV seq length + for (int kb = 0; kb < kb_lim; kb++) { + // Load K block and apply scale + threadgroup_barrier(mem_flags::mem_threadgroup); + if (!align_K && kb == (params->NK_aligned)) { + loader_k.load_safe(short2(BD, params->kL_rem)); + } else { + loader_k.load_unsafe(); + } + + // Do S = Q @ K.T + Stile.clear(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_UNROLL + for (short dd = 0; dd < TD; dd++) { + simdgroup_barrier(mem_flags::mem_none); + + Qtile.template load( + &Qs[Qs_offset + dd * Qs_tile_stride]); + Ktile.template load( + &Ks[Ks_offset + dd * Ks_tile_stride]); + + simdgroup_barrier(mem_flags::mem_none); + + tile_matmad(Stile, Qtile, Ktile, Stile); + } + + // Apply scale in float32 + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) { + Stile.elems()[ii] *= scale; + } + + // Mask out length sequence + if (!align_K && kb == (params->NK_aligned)) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = Limits::finite_min; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + short col_pos = sn + (j * stile_t::kFragCols); + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) { + if ((col_pos + jj) >= params->kL_rem) { + Stile.frag_at(i, j)[jj] = neg_inf; + } + } + } + } + } + + // Mask out if causal + if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = Limits::finite_min; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + const int row_pos = + tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows); + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + const int col_pos = kb * BK + sn + (j * stile_t::kFragCols); + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) { + if (row_pos < (col_pos + jj)) { + Stile.frag_at(i, j)[jj] = neg_inf; + } + } + } + } + } + + // Other masking as needed + if (has_mask) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = Limits::finite_min; + + constexpr bool is_bool = is_same_v; + using melem_t = typename metal::conditional_t; + + using MMAFrag_mask_t = BaseMMAFrag; + using frag_t = typename MMAFrag_mask_t::frag_type; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + const int row_pos = tid.x * BQ + tm + sm + (i * stile_t::kFragRows); + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + const int col_pos = kb * BK + sn + (j * stile_t::kFragCols); + + frag_t mfrag; + + MMAFrag_mask_t::load_safe( + mfrag, + mask, + int64_t(mask_params->M_strides[2]), + Int<1>{}, + params->qL, + params->kL, + row_pos, + col_pos); + + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemsPerFrag; jj++) { + if constexpr (is_bool) { + Stile.frag_at(i, j)[jj] = + mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf; + } else { + Stile.frag_at(i, j)[jj] += M_LOG2E_F * selem_t(mfrag[jj]); + } + } + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load V blocks + if (!align_K && kb == (params->NK_aligned)) { + loader_v.load_safe(short2(BD, params->kL_rem)); + } else { + loader_v.load_unsafe(); + } + + // Do softmax + + // Temp variables + AccumType new_max[kRowsPT]; + AccumType factor[kRowsPT]; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + new_max[i] = max_score[i]; + } + + // Row max + Stile.template row_reduce(new_max); + + // exp(Si - rowmax(Si)) + Stile.template row_bin_op(new_max); + + // Factor exp(rowmax(Si) - rowmax(Si-1)) + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + factor[i] = fast::exp2(max_score[i] - new_max[i]); + } + + // Save max for next iteration + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = new_max[i]; + } + + // Row Sum + AccumType sum_score_tmp[kRowsPT] = {0}; + Stile.template row_reduce(sum_score_tmp); + + // Update norm + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i]; + } + + // Update O + Otile.template row_bin_op(factor); + + // Load V into registers + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short id = 0; id < TD; id++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + if constexpr (BD == 128) { + simdgroup_barrier(mem_flags::mem_none); + } + + const short kk = ik * kFragSize; + const short dd = id * kFragSize; + + Vtile.template load( + &Vs[Vs_offset + kk * LDV_tgp + dd]); + + if constexpr (BD == 128) { + simdgroup_barrier(mem_flags::mem_none); + } + + MMAFrag_acc_t::mma( + Otile.frag_at(iq, id), + Stile.frag_at(iq, ik), + Vtile.frag_at(0, 0), + Otile.frag_at(iq, id)); + } + } + } + + // Prepare for next iteration + loader_k.next(); + loader_v.next(); + } + + // Normalize output + Otile.template row_bin_op(sum_score); + threadgroup_barrier(mem_flags::mem_none); + + // Store results + O += (tm + sm) * params->O_strides[2] + sn; + + if (!align_Q && int(tid.x) == (params->NQ_aligned)) { + auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm)); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + Otile.template store_safe(O, params->O_strides[2], dst_tile_dims); + } else { + Otile.template store(O, params->O_strides[2]); + } +} diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h new file mode 100644 index 00000000..6abee21f --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h @@ -0,0 +1,481 @@ +// Copyright © 2024-25 Apple Inc. + +#include "mlx/backend/metal/kernels/steel/attn/nax.h" +#include "mlx/backend/metal/kernels/steel/attn/params.h" +#include "mlx/backend/metal/kernels/steel/attn/transforms.h" +#include "mlx/backend/metal/kernels/steel/utils.h" + +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +constant bool align_Q [[function_constant(200)]]; +constant bool align_K [[function_constant(201)]]; + +constant bool has_mask [[function_constant(300)]]; +constant bool do_causal [[function_constant(301)]]; +constant bool has_sinks [[function_constant(302)]]; + +template +struct TransformScale { + T scale; + METAL_FUNC TransformScale(T scale_) : scale(scale_) {} + + METAL_FUNC T apply(T x) const { + return scale * x; + } +}; + +struct MaxOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return metal::max(x, y); + } +}; + +struct SumOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x + y; + } +}; + +struct MulOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x * y; + } +}; + +struct SubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x - y; + } +}; + +struct ExpSubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return fast::exp2(x - y); + } +}; + +struct DivOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x / y; + } +}; + +// clang-format off +template < + typename T, + int BQ, + int BK, + int BD, + int WM, + int WN, + typename MaskType = float, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention_nax( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + device T* O [[buffer(3)]], + const constant AttnParams* params [[buffer(4)]], + const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]], + const device MaskType* mask [[buffer(6), function_constant(has_mask)]], + const device T* sinks [[buffer(7), function_constant(has_sinks)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on + + // Pacifying compiler + (void)lid; + (void)simd_lane_id; + + // Move to correct block + ulong3 tidl{tid.x, tid.y, tid.z}; + + Q += tidl.z * params->Q_strides[0] + // Batch + tidl.y * params->Q_strides[1] + // Head + tidl.x * BQ * params->Q_strides[2]; // Sequence + + ulong kv_head_idx = int(tid.y) / params->gqa_factor; + K += tidl.z * params->K_strides[0] + // Batch + kv_head_idx * params->K_strides[1]; // Head + + V += tidl.z * params->V_strides[0] + // Batch + kv_head_idx * params->V_strides[1]; // Head + + O += tidl.z * params->O_strides[0] + // Batch + tidl.y * params->O_strides[1] + // Head + tidl.x * BQ * params->O_strides[2]; // Sequence + + if (has_mask) { + mask += tidl.z * mask_params->M_strides[0] + // Batch + tidl.y * mask_params->M_strides[1]; // Head + } + + const metal::uniform scale2 = + make_uniform(params->scale) * make_uniform(1.44269504089f); + + // Prepare MMA tiles + constexpr short UQ = 16; + constexpr short UD = 32; + + constexpr int kNWarps = WM * WN; + static_assert( + BQ >= (kNWarps * UQ) && BQ % (kNWarps * UQ) == 0, + "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence."); + + // Q seq frags per warp + constexpr int TQ = BQ / (kNWarps * UQ); + // HeadDim frags (all warps load the same frags) + constexpr int TD = BD / UD; + + static_assert(TQ == 1, "Check TQ"); + + using OSubTile = NAXSubTile; + NAXTile Otile; + + Otile.clear(); + + // Prepare mma tile offsets + const short2 simd_coord = OSubTile::NAXFrag_t::get_coord(); + const short sm = simd_coord.y; + const short sn = simd_coord.x; + const short tm = UQ * TQ * simd_group_id; + + Q += (tm + sm) * int(params->Q_strides[2]) + sn; + K += sm * int(params->K_strides[2]) + sn; + V += sm * int(params->V_strides[2]) + sn; + + // Init row reduction variables + constexpr short kRowsPT = decltype(Otile)::kRowsPerThread; + + metal::vec max_score; + metal::vec sum_score{0}; + + // Init to -Inf + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = Limits::finite_min; + } + + if (has_sinks) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = M_LOG2E_F * static_cast(sinks[tidl.y]); + sum_score[i] = 1; + } + } + + int kb_lim = params->NK; + + if (do_causal) { + int q_max = (tid.x + 1) * BQ + params->qL_off; + kb_lim = (q_max + BK - 1) / BK; + kb_lim = min(params->NK, kb_lim); + } + + const bool is_last_bq = int(tid.x) == (params->NQ_aligned); + // const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ); + const bool is_last_q = is_last_bq; + + const short lim_rows_q = params->qL_rem - (tm + sm); + const short lim_rows_k = params->kL_rem - sm; + + // Loop over KV seq length + for (int kb = 0; kb < kb_lim; kb++) { + const int is_last_k = (kb == (params->NK_aligned)); + + // Do S = Q @ K.T + constexpr short UDs = 16; + constexpr short UKs = 32; + + constexpr short TDs = BD / UDs; + constexpr short TKs = BK / UKs; + + using SSubTile = NAXSubTile; + using QSubTile = NAXSubTile; + using KSubTile = NAXSubTile; + + NAXTile Stile; + + Stile.clear(); + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TKs; ik++) { + STEEL_PRAGMA_UNROLL + for (short id = 0; id < TDs; id++) { + NAXTile Qtile; + NAXTile Ktile; + + const int Q_load_off = iq * UQ * int(params->Q_strides[2]) + id * UDs; + const int K_load_off = + ik * UKs * int(params->K_strides[2]) + id * UDs; + + if (!align_Q && is_last_q) { + // Qtile.load_rows( + // Q + Q_load_off, + // int(params->Q_strides[2]), + // lim_rows_q - iq * UQ); + Qtile.load_safe( + Q + Q_load_off, + int(params->Q_strides[2]), + short2(BD, lim_rows_q - iq * UQ)); + } else { + Qtile.load(Q + Q_load_off, int(params->Q_strides[2])); + } + + if (!align_K && is_last_k) { + // Ktile.load_rows( + // K + K_load_off, + // int(params->K_strides[2]), + // lim_rows_k - ik * UKs); + Ktile.load_safe( + K + K_load_off, + int(params->K_strides[2]), + short2(BD, lim_rows_k - ik * UKs)); + } else { + Ktile.load(K + K_load_off, int(params->K_strides[2])); + } + + subtile_matmad_nax( + Stile.subtile_at(iq, ik), + Qtile.subtile_at(0, 0), + metal::false_type{}, + Ktile.subtile_at(0, 0), + metal::true_type{}); + } + } + } + + // Scale S + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) { + Stile.elems()[ii] *= float(scale2); + } + + // Scale and Retile S + constexpr short UK = 16; + constexpr short TK = BK / UK; + using PSubTile = NAXSubTile; + + NAXTile Ptile; + + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) { + Ptile.elems()[ii] = Stile.elems()[ii]; + } + + // Mask out length sequence + if (!align_K && is_last_k) { + constexpr auto neg_inf = Limits::finite_min; + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + const short col_pos = sn + ik * UK; + + thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0); + + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) { + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) { + const auto loc = ii * PSubTile::kFragThrCols + jj; + fg[loc] = ((col_pos + jj) >= params->kL_rem) ? neg_inf : fg[loc]; + } + } + } + } + } + + // Mask out if causal + if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) { + constexpr auto neg_inf = Limits::finite_min; + + const int base_row = tid.x * BQ + params->qL_off + tm; + const int base_col = kb * BK; + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + const short row_pos = base_row + iq * UQ; + const short col_pos = base_col + ik * UK; + + thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0); + + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) { + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) { + const auto r = row_pos + ii * PSubTile::kFragRowsJump + sm; + const auto c = col_pos + jj + sn; + const auto loc = ii * PSubTile::kFragThrCols + jj; + fg[loc] = (r < c) ? neg_inf : fg[loc]; + } + } + } + } + } + + // Other masking as needed + if (has_mask) { + constexpr auto neg_inf = Limits::finite_min; + + const int base_row = tid.x * BQ + tm; + const int base_col = kb * BK; + + constexpr bool is_bool = is_same_v; + using melem_t = typename metal::conditional_t; + using MSubTile = NAXSubTile; + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + const short row_pos = base_row + iq * UQ + sm; + const short col_pos = base_col + ik * UK + sn; + + MSubTile mfrag; + mfrag.load_safe( + mask, + int64_t(mask_params->M_strides[2]), + Int<1>{}, + params->qL, + params->kL, + row_pos, + col_pos); + + thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0); + + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < MSubTile::kElemsPerFrag; jj++) { + if constexpr (is_bool) { + fg[jj] = mfrag.elems()[jj] ? fg[jj] : neg_inf; + } else { + fg[jj] += M_LOG2E_F * AccumType(mfrag.elems()[jj]); + } + } + } + } + } + + // Do softmax + + // Temp variables + metal::vec new_max; + metal::vec factor; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + new_max[i] = max_score[i]; + } + + // Row max + Ptile.template row_reduce(new_max); + + // exp(Si - rowmax(Si)) + Ptile.template row_bin_op(new_max); + + // Factor exp(rowmax(Si) - rowmax(Si-1)) + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + factor[i] = fast::exp2(max_score[i] - new_max[i]); + max_score[i] = new_max[i]; + } + + // Row Sum + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + sum_score[i] = sum_score[i] * factor[i]; + } + + Ptile.template row_reduce(sum_score); + + // Update O + Otile.template row_bin_op(factor); + + simdgroup_barrier(mem_flags::mem_none); + + // Do O = P @ V + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short id = 0; id < TD; id++) { + if constexpr (BD == 128) { + if (id == 2) { + threadgroup_barrier(mem_flags::mem_none); + } + } + + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + using VSubTile = NAXSubTile; + NAXTile Vtile; + + const int V_load_off = ik * UK * int(params->V_strides[2]) + id * UD; + + if (!align_K && is_last_k) { + // Vtile.load_rows( + // V + V_load_off, + // int(params->V_strides[2]), + // lim_rows_k - ik * UK); + Vtile.load_safe( + V + V_load_off, + int(params->V_strides[2]), + short2(BD, lim_rows_k - ik * UK)); + } else { + Vtile.load(V + V_load_off, int(params->V_strides[2])); + } + + subtile_matmad_nax( + Otile.subtile_at(iq, id), + Ptile.subtile_at(iq, ik), + metal::bool_constant{}, + Vtile.subtile_at(0, 0), + metal::bool_constant{}); + } + } + } + + // Prepare for next iteration + K += BK * int(params->K_strides[2]); + V += BK * int(params->V_strides[2]); + } + + // Normalize output + + threadgroup_barrier(mem_flags::mem_none); + + metal::vec rcp; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + rcp[i] = 1.f / sum_score[i]; + } + + Otile.template row_bin_op(rcp); + + // Store results + O += (tm + sm) * int(params->O_strides[2]) + sn; + + if (!align_Q && is_last_q) { + if (lim_rows_q <= 0) + return; + + // Otile.store_rows(O, params->O_strides[2], lim_rows_q); + Otile.store_safe(O, params->O_strides[2], short2(BD, lim_rows_q)); + } else { + Otile.store(O, int(params->O_strides[2])); + } +} diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/loader.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/loader.h new file mode 100644 index 00000000..7ec79814 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/loader.h @@ -0,0 +1,264 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/defines.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short alignment = 1, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoader { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; + + // Leading dimension for src + const int src_ld; + const int tile_stride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + struct alignas(alignment * sizeof(T)) ReadVector { + uint8_t v[sizeof(T) * vec_size]; + }; + + /* Constructor */ + METAL_FUNC BlockLoader( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride(reduction_dim ? BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj) {} + + /* Apply operation to threadgroup without bound checking */ + template + METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]); + } + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + *((threadgroup ReadVector*)(&dst[i * dst_ld])) = + *((const device ReadVector*)(&src[i * src_ld])); + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); + + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + return; + } + + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } + + // Zero out unneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } +}; + +template +struct CShape { + STEEL_CONST int kRows = R; + STEEL_CONST int kCols = C; +}; + +template < + typename T, + short BROWS, + short BCOLS, + short kDstStrRow, + short kDstStrCol, + short reduction_dim, + short tgp_size, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoaderT { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; + + // Leading dimension for src + const int src_ld; + const int tile_stride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + /* Constructor */ + METAL_FUNC BlockLoaderT( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride(reduction_dim ? BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * kDstStrRow + bj * kDstStrCol), + src(src_ + bi * src_ld + bj) {} + + /* Apply operation to threadgroup without bound checking */ + template + METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = + op.apply(dst[i * kDstStrRow + j * kDstStrCol]); + } + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j]; + } + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); + + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = T(0); + } + } + return; + } + + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } + + // Zero out unneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } +}; + +} // namespace steel +} // namespace mlx diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/mma.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/mma.h new file mode 100644 index 00000000..737e930d --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/mma.h @@ -0,0 +1,750 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include +#include +#include + +#include "mlx/backend/metal/kernels/steel/attn/transforms.h" +#include "mlx/backend/metal/kernels/steel/defines.h" +#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct Shape2D { + RInt r; + CInt c; + + Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {} +}; + +template +struct Layout2D { + Shape shape; + Layout layout; +}; + +template +struct BaseMMAFrag { + static_assert( + kFragRows_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); + static_assert( + kFragCols_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); +}; + +template +struct BaseMMAFrag { + STEEL_CONST int kFragRows = 8; + STEEL_CONST int kFragCols = 8; + + STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32; + + STEEL_CONST int kElemRows = 1; + STEEL_CONST int kElemCols = 2; + + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag size"); + + typedef metal::simdgroup_matrix mat_type; + typedef metal::vec frag_type; + typedef metal::vec row_frag_type; + typedef metal::vec col_frag_type; + + template + using dtype_mat_t = typename metal::simdgroup_matrix; + + template + using dtype_frag_t = typename metal::vec; + + METAL_FUNC static constexpr short2 get_coord( + ushort simd_lane_id [[thread_index_in_simdgroup]]) { + const short qid = simd_lane_id / 4; + const short fm = (qid & 4) + ((simd_lane_id / 2) % 4); + const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + return short2{fn, fm}; + } + + template + METAL_FUNC static constexpr void + load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[i * str_x + j * str_y]); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void load_safe( + thread frag_type& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + src += off_x * str_x + off_y * str_y; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[i * kElemCols + j] = static_cast(src[0]); + } else { + dst[i * kElemCols + j] = T(0); + } + src += str_y; + } + src -= kElemCols * str_y; + src += str_x; + } + } + + template + METAL_FUNC static constexpr void + store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * str_x + j * str_y] = static_cast(src[i * kElemCols + j]); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void store_safe( + const thread frag_type& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[(off_x + i) * str_x + (off_y + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template + METAL_FUNC static constexpr void mma( + thread frag_type& D, + thread dtype_frag_t& A, + thread dtype_frag_t& B, + thread dtype_frag_t& C) { + mat_type D_mat; + dtype_mat_t A_mat; + dtype_mat_t B_mat; + dtype_mat_t C_mat; + + reinterpret_cast&>(A_mat.thread_elements()) = A; + reinterpret_cast&>(B_mat.thread_elements()) = B; + reinterpret_cast&>(C_mat.thread_elements()) = C; + + mma(D_mat, A_mat, B_mat, C_mat); + + D = reinterpret_cast(D_mat.thread_elements()); + } + + template + METAL_FUNC static constexpr void mma( + thread mat_type& D, + thread dtype_mat_t& A, + thread dtype_mat_t& B, + thread dtype_mat_t& C) { + simdgroup_multiply_accumulate(D, A, B, C); + } + + template + METAL_FUNC static constexpr void row_reduce( + thread const frag_type& inp_vals, + thread T* reduced_vals) { + T thr_reduce = Op::apply(inp_vals.x, inp_vals.y); + + T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1)); + qgr_reduce = Op::apply(thr_reduce, qgr_reduce); + + T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8)); + sgr_reduce = Op::apply(qgr_reduce, sgr_reduce); + + reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce); + } + + template + METAL_FUNC static constexpr void row_bin_op( + thread frag_type& inp_vals, + thread T* row_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + inp_vals[i * kElemCols + j] = + Op::apply(inp_vals[i * kElemCols + j], row_vals[i]); + } + } + } +}; + +template < + typename T, + int kTileRows_, + int kTileCols_, + class MMAFrag_ = BaseMMAFrag> +struct MMATile { + using MMAFrag_t = MMAFrag_; + using elem_type = T; + STEEL_CONST int kFragRows = MMAFrag_t::kFragRows; + STEEL_CONST int kFragCols = MMAFrag_t::kFragCols; + STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag; + + STEEL_CONST int kTileRows = kTileRows_; + STEEL_CONST int kTileCols = kTileCols_; + + STEEL_CONST int kRows = kTileRows * kFragRows; + STEEL_CONST int kCols = kTileCols * kFragCols; + + STEEL_CONST int kNumFrags = kTileRows * kTileCols; + STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag; + + STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows; + STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols; + + typedef typename MMAFrag_t::mat_type mat_type; + typedef typename MMAFrag_t::frag_type frag_type; + + frag_type val_frags[kNumFrags]; // = {frag_type(0)}; + + METAL_FUNC MMATile() thread {} + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC constexpr const thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC mat_type mat_at(const short i, const short j) { + mat_type val_mat; + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < kElemsPerFrag; ++ii) { + val_mat.thread_elements()[ii] = frag_at(i, j)[ii]; + } + return val_mat; + } + + METAL_FUNC thread elem_type* elems() { + return reinterpret_cast(val_frags); + } + + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_frags); + } + + template + METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::template row_reduce( + frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::template row_bin_op( + frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]); + } + } + } + + template + METAL_FUNC void load(const threadgroup U* src) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &( + src[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void store(threadgroup U* dst) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &( + dst[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void load(const device U* src, const int ld) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void store(device U* dst, const int ld) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::load_safe( + frag_at(i, j), + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } + + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::store_safe( + frag_at(i, j), + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } +}; + +template < + typename Dtype, + typename Atype, + typename Btype, + typename Ctype, + int M, + int N, + int K, + class MMAFragD, + class MMAFragA, + class MMAFragB, + class MMAFragC> +METAL_FUNC void tile_matmad( + thread MMATile& D, + thread MMATile& A, + thread MMATile& B, + thread MMATile& C) { + STEEL_PRAGMA_UNROLL + for (short m = 0; m < M; ++m) { + STEEL_PRAGMA_UNROLL + for (short n = 0; n < N; ++n) { + short m_serp = m; //(n % 2) ? (M - 1 - m) : m; + short n_serp = (m % 2) ? (N - 1 - n) : n; + + STEEL_PRAGMA_UNROLL + for (short k = 0; k < K; ++k) { + MMAFragD::mma( + D.frag_at(m_serp, n_serp), + A.frag_at(m_serp, k), + B.frag_at(k, n_serp), + C.frag_at(m_serp, n_serp)); + } + } + } +} + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + short lda_tgp, + short ldb_tgp, + typename AccumType = float, + typename Epilogue = TransformNone> +struct BlockMMA { + // MMAFrag size + STEEL_CONST short kFragSize = 8; + using MMAFrag_acc_t = BaseMMAFrag; + + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TM_stride = kFragSize * WM; + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TN_stride = kFragSize * WN; + + // Warp tile size along M + STEEL_CONST short TM = BM / TM_stride; + // Warp tile size along N + STEEL_CONST short TN = BN / TN_stride; + + // Threadgroup A strides + STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M + STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K + + // Threadgroup B strides + STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K + STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N + + // Threadgroup strides along K + STEEL_CONST short tile_stride_a = kFragSize * A_str_k; + STEEL_CONST short tile_stride_b = kFragSize * B_str_k; + + // Simdgroup matrices + MMATile Atile; + MMATile Btile; + MMATile Ctile; + + // Offsets within threadgroup + short sm; + short sn; + + short As_offset; + short Bs_offset; + + /* Constructor */ + METAL_FUNC BlockMMA( + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) { + // Determine thread position in simdgroup matrix + short tm = kFragSize * (simd_group_id / WN); + short tn = kFragSize * (simd_group_id % WN); + + short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + sm = simd_coord.y; + sn = simd_coord.x; + + // Determine thread and simdgroup offset + As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K + Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N + + sm += tm; + sn += tn; + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { + // Adjust for simdgroup and thread location + As += As_offset; + Bs += Bs_offset; + + // Iterate over BK in blocks of kFragSize + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < BK; kk += kFragSize) { + simdgroup_barrier(mem_flags::mem_none); + + Atile.template load(As); + + simdgroup_barrier(mem_flags::mem_none); + + Btile.template load(Bs); + + simdgroup_barrier(mem_flags::mem_none); + + tile_matmad(Ctile, Atile, Btile, Ctile); + + // Progress to next simdgroup tile + As += tile_stride_a; + Bs += tile_stride_b; + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device U* D, const int ldd) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + + // Adjust for simdgroup and thread location + D += sm * ldd + sn; + + Ctile.template store(D, ldd); + } + + METAL_FUNC void + store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + + // Adjust for simdgroup and thread location + D += sm * ldd + sn; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + Ctile.template store_safe(D, ldd, dst_tile_dims); + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]); + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue( + const device U* C, + const int ldc, + const int fdc, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) { + accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } + } + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue_safe( + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + + // Read C + U c_elems[kelems] = {0}; + + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + c_elems[k] = C[offset_c + k * fdc]; + } + } + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + accum[k] = epilogue_op.apply(accum[k], c_elems[k]); + } + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } + } + } + } + + METAL_FUNC void store_result_safe( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + D[offset_d + k] = + epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } + } + } + } + } + } +}; + +} // namespace steel +} // namespace mlx diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/nax.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/nax.h new file mode 100644 index 00000000..c8f3ea5e --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/nax.h @@ -0,0 +1,1076 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +#include "mlx/backend/metal/kernels/steel/defines.h" +#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h" + +#include + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +/////////////////////////////////////////////////////////////////////////////// +// NAX Steel with new tiles +/////////////////////////////////////////////////////////////////////////////// + +struct BaseNAXFrag { + STEEL_CONST short kFragRows = 16; + STEEL_CONST short kFragCols = 16; + + STEEL_CONST short kElemsPerFrag = (kFragRows * kFragCols) / 32; + + STEEL_CONST short kElemRows = 2; + STEEL_CONST short kElemCols = 4; + + STEEL_CONST short kElemRowsJump = 8; + + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag size"); + + template + using dtype_frag_t = typename metal::vec; + + METAL_FUNC static short2 get_coord() { + const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); + const short qid = simd_lane_id >> 2; + const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)); + const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4; + return short2{fn, fm}; + } + + METAL_FUNC static short2 get_coord(short idx) { + const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); + const short qid = simd_lane_id >> 2; + const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)) + (idx >> 2) * 8; + const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4 + idx % 4; + return short2{fn, fm}; + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[r * str_x + c + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } + } + } + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load_rows( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if (r < lim_x) { + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[r * str_x + (c + j)]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } + } + + } else { + dst = dtype_frag_t(0); + } + } + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load_safe( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if (r < lim_x && (c + j) < lim_y) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } else { + dst[i * kElemCols + j] = T(0); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_rows( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if (r < lim_x) { + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_safe( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if (r < lim_x && (c + j) < lim_y) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename StartX, + typename StopX, + typename StartY, + typename StopY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_slice( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + StartX start_x, + StopX stop_x, + StartY start_y, + StopY stop_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + const short2 sc = short2{0, 0}; // get_coord(); + + const_for_loop<0, kElemRows, 1>([&](auto idx_row) { + const auto r = off_x + idx_row * Int{}; + if (r >= stop_x - sc.y || r < start_x - sc.y) { + return; + } + + const_for_loop<0, kElemCols, 1>([&](auto idx_col) { + const auto c = off_y + idx_col; + if (c >= stop_y - sc.x || c < start_y - sc.x) { + return; + } + + const auto src_idx = idx_row * Int{} + idx_col; + dst[(r + sc.y) * str_x + (c + sc.x) * str_y] = + static_cast(src[src_idx]); + }); + }); + } + + template + METAL_FUNC static constexpr void row_reduce( + thread const dtype_frag_t& inp_vals, + thread T* reduced_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + T thr_reduce = Op::apply( + Op::apply(inp_vals[i * kElemCols + 0], inp_vals[i * kElemCols + 1]), + Op::apply(inp_vals[i * kElemCols + 2], inp_vals[i * kElemCols + 3])); + + T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1)); + qgr_reduce = Op::apply(thr_reduce, qgr_reduce); + + T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8)); + sgr_reduce = Op::apply(qgr_reduce, sgr_reduce); + + reduced_vals[i] = Op::apply(reduced_vals[i], sgr_reduce); + } + } + + template + METAL_FUNC static constexpr void row_bin_op( + thread dtype_frag_t& inp_vals, + thread T* row_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + inp_vals[i * kElemCols + j] = + Op::apply(inp_vals[i * kElemCols + j], row_vals[i]); + } + } + } +}; + +template < + typename T, + short kRows_, + short kCols_, + typename NAXFrag_ = BaseNAXFrag> +struct NAXSubTile { + using NAXFrag_t = NAXFrag_; + STEEL_CONST short kRows = kRows_; + STEEL_CONST short kCols = kCols_; + + STEEL_CONST short kFragRows = NAXFrag_t::kFragRows; + STEEL_CONST short kFragCols = NAXFrag_t::kFragCols; + STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag; + + STEEL_CONST short kSubTileRows = kRows / kFragRows; + STEEL_CONST short kSubTileCols = kCols / kFragCols; + + STEEL_CONST short kNumFrags = kSubTileRows * kSubTileCols; + STEEL_CONST short kElemsPerSubTile = kNumFrags * kElemsPerFrag; + + STEEL_CONST int kRowsPerThread = kSubTileRows * NAXFrag_t::kElemRows; + STEEL_CONST int kColsPerThread = kSubTileCols * NAXFrag_t::kElemCols; + + STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows; + STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols; + STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump; + + using frag_type = typename NAXFrag_t::template dtype_frag_t; + + frag_type val_frags[kNumFrags]; + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kSubTileCols + j]; + } + + METAL_FUNC constexpr const thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kSubTileCols + j]; + } + + template + METAL_FUNC constexpr thread frag_type& frag_at() { + return val_frags[i * kSubTileCols + j]; + } + + template + METAL_FUNC constexpr const thread frag_type& frag_at() const { + return val_frags[i * kSubTileCols + j]; + } + + METAL_FUNC thread T* elems() { + return reinterpret_cast(val_frags); + } + + METAL_FUNC const thread T* elems() const { + return reinterpret_cast(val_frags); + } + + template + METAL_FUNC void row_reduce(thread metal::vec& vals) const { + thread T* vptr = (thread T*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::template row_reduce( + frag_at(i, j), &vptr[i * kFragThrRows]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread metal::vec& vals) { + thread T* vptr = (thread T*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::template row_bin_op( + frag_at(i, j), &vptr[i * kFragThrRows]); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load( + SrcPtrType src, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load( + frag_at(i, j), + src, + str_x, + str_y, + off_x + i * kFragRows, + off_y + j * kFragCols); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store( + DstPtrType dst, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store( + frag_at(i, j), + dst, + str_x, + str_y, + off_x + i * kFragRows, + off_y + j * kFragCols); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load_rows( + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load_rows( + frag_at(i, j), + src, + str_x, + str_y, + lim_x, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load_safe( + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load_safe( + frag_at(i, j), + src, + str_x, + str_y, + lim_x, + lim_y, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_rows( + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store_safe( + frag_at(i, j), + dst, + str_x, + str_y, + lim_x, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_safe( + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store_safe( + frag_at(i, j), + dst, + str_x, + str_y, + lim_x, + lim_y, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename StartX, + typename StopX, + typename StartY, + typename StopY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_slice( + DstPtrType dst, + StrX str_x, + StrY str_y, + StartX start_x, + StopX stop_x, + StartY start_y, + StopY stop_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) const { + const_for_loop<0, kSubTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kSubTileCols, 1>([&](auto idx_col) { + NAXFrag_t::store_slice( + frag_at(), + dst, + str_x, + str_y, + start_x, + stop_x, + start_y, + stop_y, + off_x + idx_row * Int{}, + off_y + idx_col * Int{}); + }); + }); + } +}; + +template < + short RC, + short CC, + short RA, + short CA, + short RB, + short CB, + typename CType, + typename AType, + typename BType, + bool transpose_a, + bool transpose_b, + typename NAXFrag_t = BaseNAXFrag> +METAL_FUNC void subtile_matmad_nax( + thread NAXSubTile& C, + thread NAXSubTile& A, + metal::bool_constant, + thread NAXSubTile& B, + metal::bool_constant) { + // Static checks + constexpr short FMa = transpose_a ? CA : RA; + constexpr short FMc = RC; + static_assert(FMa == FMc, "NAX matmul: M dimensions do not match"); + + constexpr short FNb = transpose_b ? RB : CB; + constexpr short FNc = CC; + static_assert(FNb == FNc, "NAX matmul: N dimensions do not match"); + + constexpr short FKa = transpose_a ? RA : CA; + constexpr short FKb = transpose_b ? CB : RB; + static_assert(FKa == FKb, "NAX matmul: N dimensions do not match"); + + constexpr short FM = FMc; + constexpr short FN = FNc; + constexpr short FK = FKa; + + constexpr int TM = FM / 16; + constexpr int TN = FN / 16; + constexpr int TK = FK / 16; + + constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor( + FM, + FN, + FK, + transpose_a, + transpose_b, + true, + mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); + + mpp::tensor_ops::matmul2d gemm_op; + + auto ct_a = + gemm_op.template get_left_input_cooperative_tensor(); + auto ct_b = + gemm_op + .template get_right_input_cooperative_tensor(); + auto ct_c = gemm_op.template get_destination_cooperative_tensor< + decltype(ct_a), + decltype(ct_b), + CType>(); + + STEEL_PRAGMA_UNROLL + for (short mm = 0; mm < TM; mm++) { + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < TK; kk++) { + const short fi = transpose_a ? kk : mm; + const short fj = transpose_a ? mm : kk; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 8; i++) { + ct_a[(TK * mm + kk) * 8 + i] = A.frag_at(fi, fj)[i]; + } + } + } + + STEEL_PRAGMA_UNROLL + for (short nn = 0; nn < TN; nn++) { + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < TK; kk++) { + const short fi = transpose_b ? nn : kk; + const short fj = transpose_b ? kk : nn; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 8; i++) { + ct_b[(TN * kk + nn) * 8 + i] = B.frag_at(fi, fj)[i]; + } + } + } + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < ct_c.get_capacity(); i++) { + ct_c[i] = C.elems()[i]; + } + + gemm_op.run(ct_a, ct_b, ct_c); + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < ct_c.get_capacity(); i++) { + C.elems()[i] = ct_c[i]; + } +} + +template +struct NAXTile { + using NAXSubTile_t = NAXSubTile_; + using elem_type = T; + STEEL_CONST short kSubTileRows = NAXSubTile_t::kRows; + STEEL_CONST short kSubTileCols = NAXSubTile_t::kCols; + STEEL_CONST short kElemsPerSubTile = NAXSubTile_t::kElemsPerSubTile; + + STEEL_CONST short kTileRows = kTileRows_; + STEEL_CONST short kTileCols = kTileCols_; + + STEEL_CONST short kRows = kTileRows * kSubTileRows; + STEEL_CONST short kCols = kTileCols * kSubTileCols; + + STEEL_CONST short kSubTiles = kTileRows * kTileCols; + STEEL_CONST short kElemsPerTile = kSubTiles * kElemsPerSubTile; + + STEEL_CONST short kRowsPerThread = kTileRows * NAXSubTile_t::kRowsPerThread; + STEEL_CONST short kColsPerThread = kTileCols * NAXSubTile_t::kColsPerThread; + + STEEL_CONST short kSubTileThrRows = NAXSubTile_t::kRowsPerThread; + STEEL_CONST short kSubTileThrCols = NAXSubTile_t::kColsPerThread; + + NAXSubTile_t val_subtiles[kSubTiles]; + + METAL_FUNC NAXTile() thread {} + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTiles; ++i) { + val_subtiles[i].clear(); + } + } + + METAL_FUNC constexpr thread NAXSubTile_t& subtile_at( + const short i, + const short j) { + return val_subtiles[i * kTileCols + j]; + } + + METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at( + const short i, + const short j) const { + return val_subtiles[i * kTileCols + j]; + } + + template + METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at() const { + return val_subtiles[i * kTileCols + j]; + } + + METAL_FUNC thread elem_type* elems() { + return reinterpret_cast(val_subtiles[0].elems()); + } + + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_subtiles[0].elems()); + } + + template + METAL_FUNC void row_reduce(thread metal::vec& vals) const { + auto sub_rows = (thread metal::vec*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).template row_reduce(sub_rows[i]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread metal::vec& vals) { + auto sub_rows = (thread metal::vec*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).template row_bin_op(sub_rows[i]); + } + } + } + + template + METAL_FUNC void load(const threadgroup U* src) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load( + src, + Int{}, + Int{}, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void store(threadgroup U* dst) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store( + dst, + Int{}, + Int{}, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void load(const device U* src, const int ld) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load( + &src[(i * kSubTileRows) * ld + (j * kSubTileCols)], ld, Int<1>{}); + } + } + } + + template + METAL_FUNC void store(device U* dst, const int ld) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store( + &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)], ld, Int<1>{}); + } + } + } + + template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load_safe( + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void + load_rows(const device U* src, const int ld, const short n_rows) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load_rows( + &src[(i * kSubTileRows) * ld + (j * kSubTileCols)], + ld, + Int<1>{}, + n_rows - i * kSubTileRows); + } + } + } + + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store_safe( + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void store_rows(device U* dst, const int ld, const short n_rows) + const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store_rows( + &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)], + ld, + Int<1>{}, + n_rows - i * kSubTileRows); + } + } + } + + template + METAL_FUNC void store_slice( + device U* dst, + const int ld, + const short2 start, + const short2 stop) const { + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + subtile_at().store_slice( + dst, + ld, + Int<1>{}, + start.y, + stop.y, + start.x, + stop.x, + idx_row * Int{}, + idx_col * Int{}); + }); + }); + } +}; + +template < + class CTile, + class ATile, + class BTile, + bool transpose_a, + bool transpose_b> +METAL_FUNC void tile_matmad_nax( + thread CTile& C, + thread ATile& A, + metal::bool_constant, + thread BTile& B, + metal::bool_constant) { + // Static checks + constexpr short TMa = transpose_a ? ATile::kTileCols : ATile::kTileRows; + constexpr short TMc = CTile::kTileRows; + static_assert(TMa == TMc, "NAX tile matmul: M dimensions do not match"); + + constexpr short FMa = transpose_a ? ATile::kSubTileCols : ATile::kSubTileRows; + constexpr short FMc = CTile::kSubTileRows; + static_assert(FMa == FMc, "NAX subtile matmul: M dimensions do not match"); + + constexpr short TNb = transpose_b ? BTile::kTileRows : BTile::kTileCols; + constexpr short TNc = CTile::kTileCols; + static_assert(TNb == TNc, "NAX tile matmul: N dimensions do not match"); + + constexpr short FNb = transpose_b ? BTile::kSubTileRows : BTile::kSubTileCols; + constexpr short FNc = CTile::kSubTileCols; + static_assert(FNb == FNc, "NAX subtile matmul: N dimensions do not match"); + + constexpr short TKa = transpose_a ? ATile::kTileRows : ATile::kTileCols; + constexpr short TKb = transpose_b ? BTile::kTileCols : BTile::kTileRows; + static_assert(TKa == TKb, "NAX tile matmul: K dimensions do not match"); + + constexpr short FKa = transpose_a ? ATile::kSubTileRows : ATile::kSubTileCols; + constexpr short FKb = transpose_b ? BTile::kSubTileCols : BTile::kSubTileRows; + static_assert(FKa == FKb, "NAX subtile matmul: K dimensions do not match"); + + constexpr short TM = TMc; + constexpr short TN = TNc; + constexpr short TK = TKa; + + // Do matmul here + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; ++j) { + STEEL_PRAGMA_UNROLL + for (short k = 0; k < TK; ++k) { + const short ra = transpose_a ? k : i; + const short ca = transpose_a ? i : k; + const short rb = transpose_b ? j : k; + const short cb = transpose_b ? k : j; + + subtile_matmad_nax( + C.subtile_at(i, j), + A.subtile_at(ra, ca), + metal::bool_constant{}, + B.subtile_at(rb, cb), + metal::bool_constant{}); + } + } + } +} + +} // namespace steel +} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/params.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/params.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/attn/params.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/params.h diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h new file mode 100644 index 00000000..c0624d21 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h @@ -0,0 +1,71 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/utils.h" + +/////////////////////////////////////////////////////////////////////////////// +// Transforms and Epilogues +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct TransformNone { + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT) { + return static_cast(x); + } +}; + +template +struct TransformAdd { + TransformAdd(const float, const float) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT c) { + return static_cast(x) + c; + } +}; + +template +struct TransformAxpby { + const float alpha; + const float beta; + + TransformAxpby(const float alpha_, const float beta_) + : alpha(alpha_), beta(beta_) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + METAL_FUNC OutT apply(InT x, OutT c) const { + return static_cast(x * alpha + (beta * c)); + } +}; + +template +struct AccumHelper { + typedef float accum_type; +}; + +struct BlockSwizzle { + static METAL_FUNC int2 + swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { + const int tid_x = (tid.x) >> swizzle_log; + const int tid_y = + ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); + return int2(tid_x, tid_y); + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/conv.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/conv.h new file mode 100644 index 00000000..d2e718f2 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/conv.h @@ -0,0 +1,13 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/defines.h" +#include "mlx/backend/metal/kernels/steel/utils.h" + +#include "mlx/backend/metal/kernels/steel/conv/loader.h" +#include "mlx/backend/metal/kernels/steel/conv/params.h" +#include "mlx/backend/metal/kernels/steel/gemm/mma.h" + +using namespace metal; +using namespace mlx::steel; diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_3d.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_3d.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_3d.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_3d.h diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h new file mode 100644 index 00000000..1241f773 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h @@ -0,0 +1,225 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h" + +constant bool align_C [[function_constant(200)]]; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + typename AccumType = float, + typename Epilogue = TransformNone> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void +implicit_gemm_conv_2d_general( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* C [[buffer(2)]], + const constant MLXConvParams<2>* params [[buffer(3)]], + const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], + const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], + const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], + const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr bool transpose_a = false; + constexpr bool transpose_b = true; + constexpr short tgp_padding_a = 16 / sizeof(T); + constexpr short tgp_padding_b = 16 / sizeof(T); + + constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a; + constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b; + constexpr short shape_a_rows = (transpose_a ? BK : BM); + constexpr short shape_b_rows = (transpose_b ? BN : BK); + constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows; + constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows; + + constexpr short tgp_size = WM * WN * 32; + + // Input loader + using loader_a_t = + Conv2DInputBlockLoaderGeneral; + + // Weight loader + using loader_b_t = + Conv2DWeightBlockLoaderGeneral; + + using mma_t = BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + shape_a_cols, + shape_b_cols>; + + threadgroup T As[tgp_mem_size_a]; + threadgroup T Bs[tgp_mem_size_b]; + + const int tid_y = ((tid.y) << gemm_params->swizzle_log) + + ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> gemm_params->swizzle_log; + + if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) { + return; + } + + const int tid_z = tid.z; + + const int base_oh = tid_z / jump_params->f_out_jump_w; + const int base_ow = tid_z % jump_params->f_out_jump_w; + + const int base_wh = base_h[base_oh].weight_base; + const int base_ww = base_w[base_ow].weight_base; + + const int base_wh_size = base_h[base_oh].weight_size; + const int base_ww_size = base_w[base_ow].weight_size; + + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const int K = gemm_params->K; + + B += c_col * K; + + const int4 offsets_a(0, c_row, base_oh, base_ow); + const int2 offsets_b(0, c_col); + + // Prepare threadgroup loading operations + loader_a_t loader_a( + A, + As, + offsets_a, + params, + jump_params, + base_wh, + base_ww, + simd_gid, + simd_lid); + loader_b_t loader_b( + B, + Bs, + offsets_b, + params, + jump_params, + base_wh, + base_ww, + simd_gid, + simd_lid); + + // Prepare threadgroup mma operation + mma_t mma_op(simd_gid, simd_lid); + + if (align_C) { + int gemm_k_iterations = + base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + } + + else { + for (int k = 1; k < gemm_params->gemm_k_iterations; k++) { + for (int j = 0; j < base_wh_size * base_ww_size; j++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + } + const short remaining_k = params->C % BK; + for (int j = 0; j < base_wh_size * base_ww_size; j++) { + // Load elements into threadgroup + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(remaining_k); + loader_b.load_safe(remaining_k); + threadgroup_barrier(mem_flags::mem_threadgroup); + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + } + + threadgroup_barrier(mem_flags::mem_none); + + // Store results to device memory + { + // Adjust for simdgroup and thread location + int offset_m = c_row + mma_op.sm; + int offset_n = c_col + mma_op.sn; + C += offset_n; + + if (offset_n >= gemm_params->N) + return; + + short diff = gemm_params->N - offset_n; + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < mma_t::TM; i++) { + int cm = offset_m + i * mma_t::TM_stride; + + int n = cm / jump_params->adj_out_hw; + int hw = cm % jump_params->adj_out_hw; + int oh = + (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh; + int ow = + (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow; + + if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) { + int offset_cm = n * params->out_strides[0] + + oh * params->out_strides[1] + ow * params->out_strides[2]; + + STEEL_PRAGMA_UNROLL + for (int j = 0; j < mma_t::TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = mma_op.Ctile.frag_at(i, j); + int offset = offset_cm + (j * mma_t::TN_stride); + + constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag; + + // Apply epilogue and output C + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * mma_t::TN_stride + k) < diff) { + C[offset + k] = Epilogue::apply(accum[k]); + } + } + } + } + } + } +} diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/loader.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/loader.h new file mode 100644 index 00000000..f84a640f --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/loader.h @@ -0,0 +1,6 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h" +#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h" \ No newline at end of file diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h new file mode 100644 index 00000000..9124e304 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h @@ -0,0 +1,955 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/utils.h" + +#include "mlx/backend/metal/kernels/steel/conv/params.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv2DInputBlockLoaderLargeFilter { + // Destination dimensions + STEEL_CONST short BROWS = BM; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + + const constant MLXConvParams<2>* params; + const constant ImplicitGemmConv2DParams* gemm_params; + + short weight_h; + short weight_w; + + const device T* src[n_rows]; + + int read_n[n_rows]; + int read_ih[n_rows]; + int read_iw[n_rows]; + + /* Constructor */ + METAL_FUNC Conv2DInputBlockLoaderLargeFilter( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant ImplicitGemmConv2DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + params(params_), + gemm_params(gemm_params_), + weight_h(0), + weight_w(0) { + int out_n_pixels = params->oS[0] * params->oS[1]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int offset_nhw = offsets.y + bi + i * TROWS; + int n = offset_nhw / out_n_pixels; + int hw = offset_nhw % out_n_pixels; + int oh = hw / params->oS[1]; + int ow = hw % params->oS[1]; + + int ih = oh * params->str[0] - params->pad[0]; + int iw = ow * params->str[1] - params->pad[1]; + + read_n[i] = n; + read_ih[i] = ih; + read_iw[i] = iw; + + // Adjust for flip + if (params->flip) { + ih += (params->wS[0] - 1) * params->kdil[0]; + iw += (params->wS[1] - 1) * params->kdil[1]; + } + + // Read from input if in bounds + src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] + + iw * params->in_strides[2] + bj; + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Find bounds + int n = read_n[i]; + int ih = read_ih[i] + weight_h * params->kdil[0]; + int iw = read_iw[i] + weight_w * params->kdil[1]; + + // Read from input if in bounds + if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) && + (iw >= 0 && iw < params->iS[1])) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = src[i][j]; + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + if (++weight_w < params->wS[1]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_w; + } + + return; + } + + weight_w = 0; + + if (++weight_h < params->wS[0]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_h; + } + + return; + } + + weight_h = 0; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_c; + } + } +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv2DInputBlockLoaderSmallFilter { + // Destination dimensions + STEEL_CONST short BROWS = BM; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + using mask_t = short; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + + const constant MLXConvParams<2>* params; + const constant ImplicitGemmConv2DParams* gemm_params; + + short weight_h; + short weight_w; + + const device T* src[n_rows]; + + mask_t mask_h[n_rows]; + mask_t mask_w[n_rows]; + + /* Constructor */ + METAL_FUNC Conv2DInputBlockLoaderSmallFilter( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant ImplicitGemmConv2DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + params(params_), + gemm_params(gemm_params_), + weight_h(0), + weight_w(0) { + int out_n_pixels = params->oS[0] * params->oS[1]; + + int read_n[n_rows]; + int read_ih[n_rows]; + int read_iw[n_rows]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int offset_nhw = offsets.y + bi + i * TROWS; + int n = offset_nhw / out_n_pixels; + int hw = offset_nhw % out_n_pixels; + int oh = hw / params->oS[1]; + int ow = hw % params->oS[1]; + + int ih = oh * params->str[0] - params->pad[0]; + int iw = ow * params->str[1] - params->pad[1]; + + read_n[i] = n; + read_ih[i] = ih; + read_iw[i] = iw; + + // Adjust for flip + if (params->flip) { + ih += (params->wS[0] - 1) * params->kdil[0]; + iw += (params->wS[1] - 1) * params->kdil[1]; + } + + // Read from input if in bounds + src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] + + iw * params->in_strides[2] + bj; + } + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + mask_h[i] = 0; + mask_w[i] = 0; + } + + for (short kh = 0; kh < params->wS[0]; kh++) { + short flip_h = params->flip ? params->wS[0] - kh - 1 : kh; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int n = read_n[i]; + int ih = read_ih[i] + flip_h * params->kdil[0]; + + bool in_bounds = n < params->N && ih >= 0 && ih < params->iS[0]; + + mask_h[i] |= (in_bounds << kh); + } + } + + for (short kw = 0; kw < params->wS[1]; kw++) { + short flip_w = params->flip ? params->wS[1] - kw - 1 : kw; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int iw = read_iw[i] + flip_w * params->kdil[1]; + + bool in_bounds = iw >= 0 && iw < params->iS[1]; + + mask_w[i] |= (in_bounds << kw); + } + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + mask_t h_mask = mask_t(1) << weight_h; + mask_t w_mask = mask_t(1) << weight_w; + + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Read from input if in bounds + if ((mask_h[i] & h_mask) && (mask_w[i] & w_mask)) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = src[i][j]; + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + if (++weight_w < params->wS[1]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_w; + } + + return; + } + + weight_w = 0; + + if (++weight_h < params->wS[0]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_h; + } + + return; + } + + weight_h = 0; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_c; + } + } +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv2DWeightBlockLoader { + // Destination dimensions + STEEL_CONST short BROWS = BN; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = + (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4); + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Leading dimension for src + const int src_ld; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + const constant MLXConvParams<2>* params; + + int weight_hw; + int weight_step; + + const int read_n; + const bool do_read; + + /* Constructor */ + METAL_FUNC Conv2DWeightBlockLoader( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant ImplicitGemmConv2DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(params_->wt_strides[0]), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj), + params(params_), + weight_hw(0), + weight_step(params->C / params->groups), + read_n(offsets.y + bi), + do_read(read_n + n_rows * TROWS <= gemm_params_->N) {} + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + if (BN != 8 || do_read) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BN; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = src[i * src_ld + j]; + } + } + } else { + for (short i = 0; i < BN; i += TROWS) { + if ((read_n + i) < params->O) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = src[i * src_ld + j]; + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + if (++weight_hw < (params->wS[1] * params->wS[0])) { + src += weight_step; + return; + } + + weight_hw = 0; + + src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step; + } +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv3DInputBlockLoaderLargeFilter { + // Destination dimensions + STEEL_CONST short BROWS = BM; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + + const constant MLXConvParams<3>* params; + const constant ImplicitGemmConv3DParams* gemm_params; + + short weight_d; + short weight_h; + short weight_w; + + short kdil_d; + short kdil_h; + short kdil_w; + + const device T* src[n_rows]; + + int read_n[n_rows]; + int read_id[n_rows]; + int read_ih[n_rows]; + int read_iw[n_rows]; + + /* Constructor */ + METAL_FUNC Conv3DInputBlockLoaderLargeFilter( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<3>* params_, + const constant ImplicitGemmConv3DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + params(params_), + gemm_params(gemm_params_), + weight_d(0), + weight_h(0), + weight_w(0), + kdil_d(params_->flip ? -params_->kdil[0] : params_->kdil[0]), + kdil_h(params_->flip ? -params_->kdil[1] : params_->kdil[1]), + kdil_w(params_->flip ? -params_->kdil[2] : params_->kdil[2]) { + int out_n_pixels = params->oS[0] * params->oS[1] * params->oS[2]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int offset_ndhw = offsets.y + bi + i * TROWS; + int n = offset_ndhw / out_n_pixels; + int dhw = offset_ndhw % out_n_pixels; + int od = dhw / (params->oS[1] * params->oS[2]); + int hw = dhw % (params->oS[1] * params->oS[2]); + int oh = hw / params->oS[2]; + int ow = hw % params->oS[2]; + + int id = od * params->str[0] - params->pad[0]; + int ih = oh * params->str[1] - params->pad[1]; + int iw = ow * params->str[2] - params->pad[2]; + + read_n[i] = n; + + if (params->flip) { + read_id[i] = id + (params->wS[0] - 1) * params->kdil[0]; + read_ih[i] = ih + (params->wS[1] - 1) * params->kdil[1]; + read_iw[i] = iw + (params->wS[2] - 1) * params->kdil[2]; + } else { + read_id[i] = id; + read_ih[i] = ih; + read_iw[i] = iw; + } + + // Adjust for flip + if (params->flip) { + id += (params->wS[0] - 1) * params->kdil[0]; + ih += (params->wS[1] - 1) * params->kdil[1]; + iw += (params->wS[2] - 1) * params->kdil[2]; + } + + // Read from input if in bounds + src[i] = src_ + n * params->in_strides[0] + id * params->in_strides[1] + + ih * params->in_strides[2] + iw * params->in_strides[3] + bj; + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Find bounds + int n = read_n[i]; + int id = read_id[i] + weight_d * kdil_d; + int ih = read_ih[i] + weight_h * kdil_h; + int iw = read_iw[i] + weight_w * kdil_w; + + // Read from input if in bounds + if ((n < params->N) && (id >= 0 && id < params->iS[0]) && + (ih >= 0 && ih < params->iS[1]) && (iw >= 0 && iw < params->iS[2])) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = src[i][j]; + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + if (++weight_w < params->wS[2]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_w; + } + + return; + } + + weight_w = 0; + + if (++weight_h < params->wS[1]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_h; + } + + return; + } + + weight_h = 0; + + if (++weight_d < params->wS[0]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_d; + } + + return; + } + + weight_d = 0; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_c; + } + } +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv3DInputBlockLoaderSmallFilter { + // Destination dimensions + STEEL_CONST short BROWS = BM; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + using mask_t = short; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + + const constant MLXConvParams<3>* params; + const constant ImplicitGemmConv3DParams* gemm_params; + + short weight_d; + short weight_h; + short weight_w; + + const device T* src[n_rows]; + + mask_t mask_d[n_rows]; + mask_t mask_h[n_rows]; + mask_t mask_w[n_rows]; + + /* Constructor */ + METAL_FUNC Conv3DInputBlockLoaderSmallFilter( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<3>* params_, + const constant ImplicitGemmConv3DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + params(params_), + gemm_params(gemm_params_), + weight_d(0), + weight_h(0), + weight_w(0) { + int out_n_pixels = params->oS[0] * params->oS[1] * params->oS[2]; + + int read_n[n_rows]; + int read_id[n_rows]; + int read_ih[n_rows]; + int read_iw[n_rows]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int offset_ndhw = offsets.y + bi + i * TROWS; + int n = offset_ndhw / out_n_pixels; + int dhw = offset_ndhw % out_n_pixels; + int od = dhw / (params->oS[1] * params->oS[2]); + int hw = dhw % (params->oS[1] * params->oS[2]); + int oh = hw / params->oS[2]; + int ow = hw % params->oS[2]; + + int id = od * params->str[0] - params->pad[0]; + int ih = oh * params->str[1] - params->pad[1]; + int iw = ow * params->str[2] - params->pad[2]; + + read_n[i] = n; + read_id[i] = id; + read_ih[i] = ih; + read_iw[i] = iw; + + // Adjust for flip + if (params->flip) { + id += (params->wS[0] - 1) * params->kdil[0]; + ih += (params->wS[1] - 1) * params->kdil[1]; + iw += (params->wS[2] - 1) * params->kdil[2]; + } + + // Read from input if in bounds + src[i] = src_ + n * params->in_strides[0] + id * params->in_strides[1] + + ih * params->in_strides[2] + iw * params->in_strides[3] + bj; + } + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + mask_d[i] = 0; + mask_h[i] = 0; + mask_w[i] = 0; + } + + for (short kd = 0; kd < params->wS[0]; kd++) { + short flip_d = params->flip ? params->wS[0] - kd - 1 : kd; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int n = read_n[i]; + int id = read_id[i] + flip_d * params->kdil[0]; + + bool in_bounds = n < params->N && id >= 0 && id < params->iS[0]; + + mask_d[i] |= (in_bounds << kd); + } + } + + for (short kh = 0; kh < params->wS[1]; kh++) { + short flip_h = params->flip ? params->wS[1] - kh - 1 : kh; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int ih = read_ih[i] + flip_h * params->kdil[1]; + + bool in_bounds = ih >= 0 && ih < params->iS[1]; + + mask_h[i] |= (in_bounds << kh); + } + } + + for (short kw = 0; kw < params->wS[2]; kw++) { + short flip_w = params->flip ? params->wS[2] - kw - 1 : kw; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int iw = read_iw[i] + flip_w * params->kdil[2]; + + bool in_bounds = iw >= 0 && iw < params->iS[2]; + + mask_w[i] |= (in_bounds << kw); + } + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + mask_t d_mask = mask_t(1) << weight_d; + mask_t h_mask = mask_t(1) << weight_h; + mask_t w_mask = mask_t(1) << weight_w; + + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Read from input if in bounds + if ((mask_d[i] & d_mask) && (mask_h[i] & h_mask) && + (mask_w[i] & w_mask)) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = src[i][j]; + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + if (++weight_w < params->wS[2]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_w; + } + + return; + } + + weight_w = 0; + + if (++weight_h < params->wS[1]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_h; + } + + return; + } + + weight_h = 0; + + if (++weight_d < params->wS[0]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_d; + } + + return; + } + + weight_d = 0; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_c; + } + } +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv3DWeightBlockLoader { + // Destination dimensions + STEEL_CONST short BROWS = BN; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = + (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4); + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Leading dimension for src + const int src_ld; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + const constant MLXConvParams<3>* params; + + int weight_dhw; + int weight_step; + + const int read_n; + const bool do_read; + + /* Constructor */ + METAL_FUNC Conv3DWeightBlockLoader( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<3>* params_, + const constant ImplicitGemmConv3DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(params_->wt_strides[0]), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj), + params(params_), + weight_dhw(0), + weight_step(params->C / params->groups), + read_n(offsets.y + bi), + do_read(read_n + n_rows * TROWS <= gemm_params_->N) {} + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + if (BN != 8 || do_read) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BN; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = src[i * src_ld + j]; + } + } + } else { + for (short i = 0; i < BN; i += TROWS) { + if ((read_n + i) < params->O) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = src[i * src_ld + j]; + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + if (++weight_dhw < (params->wS[0] * params->wS[1] * params->wS[2])) { + src += weight_step; + return; + } + + weight_dhw = 0; + + src += + BK - (params->wS[0] * params->wS[1] * params->wS[2] - 1) * weight_step; + } +}; + +} // namespace steel +} // namespace mlx diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h new file mode 100644 index 00000000..2312e1ca --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h @@ -0,0 +1,319 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/utils.h" + +#include "mlx/backend/metal/kernels/steel/conv/params.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct ChannelHelper { + STEEL_CONST short n_channels = n_channels_; + STEEL_CONST short vec_size = n_channels_ <= 4 ? 4 : 8; + STEEL_CONST short excess = vec_size - n_channels_; +}; + +template <> +struct ChannelHelper<1> { + STEEL_CONST short n_channels = 1; + STEEL_CONST short vec_size = 1; + STEEL_CONST short excess = 0; +}; + +template <> +struct ChannelHelper<2> { + STEEL_CONST short n_channels = 2; + STEEL_CONST short vec_size = 2; + STEEL_CONST short excess = 0; +}; + +template <> +struct ChannelHelper<3> { + STEEL_CONST short n_channels = 3; + STEEL_CONST short vec_size = 4; + STEEL_CONST short excess = 1; +}; + +template <> +struct ChannelHelper<4> { + STEEL_CONST short n_channels = 4; + STEEL_CONST short vec_size = 4; + STEEL_CONST short excess = 0; +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short n_channels, + short tgp_padding = 0> +struct Conv2DInputBlockLoaderSmallChannels { + // Destination dimensions + STEEL_CONST short BROWS = BM; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = ChannelHelper::vec_size; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + + const constant MLXConvParams<2>* params; + const constant ImplicitGemmConv2DParams* gemm_params; + + int weight_hw; + + const device T* src[n_rows]; + + int read_n[n_rows]; + int read_ih[n_rows]; + int read_iw[n_rows]; + + /* Constructor */ + METAL_FUNC Conv2DInputBlockLoaderSmallChannels( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant ImplicitGemmConv2DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + params(params_), + gemm_params(gemm_params_), + weight_hw(thread_idx % TCOLS) { + int out_n_pixels = params->oS[0] * params->oS[1]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int offset_nhw = offsets.y + bi + i * TROWS; + int n = offset_nhw / out_n_pixels; + int hw = offset_nhw % out_n_pixels; + int oh = hw / params->oS[1]; + int ow = hw % params->oS[1]; + + int ih = oh * params->str[0] - params->pad[0]; + int iw = ow * params->str[1] - params->pad[1]; + + // Read from input if in bounds + src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] + + iw * params->in_strides[2]; + + read_n[i] = n; + read_ih[i] = ih; + read_iw[i] = iw; + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + if (weight_hw >= params->wS[1] * params->wS[0]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + return; + } + + int wh = (weight_hw / params->wS[1]); + int ww = (weight_hw % params->wS[1]); + + int flip_h = params->flip ? params->wS[0] - wh - 1 : wh; + int flip_w = params->flip ? params->wS[1] - ww - 1 : ww; + + int weight_h = flip_h * params->kdil[0]; + int weight_w = flip_w * params->kdil[1]; + + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Find bounds + int n = read_n[i]; + int ih = read_ih[i] + weight_h; + int iw = read_iw[i] + weight_w; + + // Read from input if in bounds + if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) && + (iw >= 0 && iw < params->iS[1])) { + const device T* curr_src = src[i] + weight_h * params->in_strides[1] + + weight_w * params->in_strides[2]; + + STEEL_PRAGMA_UNROLL + for (short j = 0; j < n_channels; ++j) { + dst[is * dst_ld + j] = curr_src[j]; + } + + STEEL_PRAGMA_UNROLL + for (short j = n_channels; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + weight_hw += TCOLS; + } +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short n_channels, + short tgp_padding = 0> +struct Conv2DWeightBlockLoaderSmallChannels { + // Destination dimensions + STEEL_CONST short BROWS = BN; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = ChannelHelper::vec_size; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Leading dimension for src + const int src_ld; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + const constant MLXConvParams<2>* params; + + int weight_hw; + + const int read_n; + const bool do_read; + + /* Constructor */ + METAL_FUNC Conv2DWeightBlockLoaderSmallChannels( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant ImplicitGemmConv2DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(params_->wt_strides[0]), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld), + params(params_), + weight_hw(thread_idx % TCOLS), + read_n(offsets.y + bi), + do_read(read_n + BN <= gemm_params_->N) {} + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + if (bi >= BROWS || bj >= BCOLS) + return; + + if (read_n >= params->O || weight_hw >= params->wS[1] * params->wS[0]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + + return; + } + + const device T* curr_src = src + weight_hw * (params->C / params->groups); + + if (BN != 8 || do_read) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < n_channels; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + + STEEL_PRAGMA_UNROLL + for (short j = n_channels; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } else { + for (short i = 0; i < BROWS; i += TROWS) { + if (((read_n + i) < params->O)) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < n_channels; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + + STEEL_PRAGMA_UNROLL + for (short j = n_channels; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + weight_hw += TCOLS; + } +}; + +} // namespace steel +} // namespace mlx diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h new file mode 100644 index 00000000..9b7ddc2e --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h @@ -0,0 +1,381 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/defines.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv2DInputBlockLoaderGeneral { + // Destination dimensions + STEEL_CONST short BROWS = BM; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + + const constant MLXConvParams<2>* params; + const constant Conv2DGeneralJumpParams* jump_params; + + const short base_wh; + const short base_ww; + + short weight_h; + short weight_w; + + const device T* src[n_rows]; + + int read_n[n_rows]; + int read_ih[n_rows]; + int read_iw[n_rows]; + + /* Constructor */ + METAL_FUNC Conv2DInputBlockLoaderGeneral( + const device T* src_, + threadgroup T* dst_, + const int4 offsets, + const constant MLXConvParams<2>* params_, + const constant Conv2DGeneralJumpParams* jump_params_, + const short base_wh_, + const short base_ww_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + params(params_), + jump_params(jump_params_), + base_wh(base_wh_), + base_ww(base_ww_), + weight_h(base_wh_), + weight_w(base_ww_) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int offset_nhw = offsets.y + bi + i * TROWS; + int n = offset_nhw / jump_params->adj_out_hw; + int hw = offset_nhw % jump_params->adj_out_hw; + int oh = + (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + offsets.z; + int ow = + (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + offsets.w; + + int ih = oh * params->str[0] - params->pad[0]; + int iw = ow * params->str[1] - params->pad[1]; + + read_n[i] = n; + read_ih[i] = ih; + read_iw[i] = iw; + + // Read from input if in bounds + src[i] = src_ + n * params->in_strides[0] + bj; + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Find bounds + int n = read_n[i]; + + int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h; + int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w; + + int ih_dil = read_ih[i] + h_flip * params->kdil[0]; + int iw_dil = read_iw[i] + w_flip * params->kdil[1]; + + int ih = ih_dil / params->idil[0]; + int iw = iw_dil / params->idil[1]; + + size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2]; + + // Read from input if in bounds + if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) && + (iw_dil >= 0 && iw < params->iS[1])) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = (src[i])[offset + j]; + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + METAL_FUNC void load_safe(const short remaining_k) const { + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Find bounds + int n = read_n[i]; + + int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h; + int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w; + + int ih_dil = read_ih[i] + h_flip * params->kdil[0]; + int iw_dil = read_iw[i] + w_flip * params->kdil[1]; + + int ih = ih_dil / params->idil[0]; + int iw = iw_dil / params->idil[1]; + + size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2]; + + // Read from input if in bounds + if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) && + (iw_dil >= 0 && iw < params->iS[1])) { + if (bj + vec_size <= remaining_k) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = (src[i])[offset + j]; + } + } else { + for (short j = 0; j < vec_size; ++j) { + if (bj + j < remaining_k) { + dst[is * dst_ld + j] = (src[i])[offset + j]; + } else { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + weight_w += jump_params->f_wgt_jump_w; + if (weight_w < params->wS[1]) { + return; + } + + weight_w = base_ww; + + weight_h += jump_params->f_wgt_jump_h; + if (weight_h < params->wS[0]) { + return; + } + + weight_h = base_wh; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += BK; + } + } +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv2DWeightBlockLoaderGeneral { + // Destination dimensions + STEEL_CONST short BROWS = BN; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = + (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4); + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Leading dimension for src + const int src_ld; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + const constant MLXConvParams<2>* params; + const constant Conv2DGeneralJumpParams* jump_params; + + const short base_wh; + const short base_ww; + + short weight_h; + short weight_w; + + const int start_row; + + /* Constructor */ + METAL_FUNC Conv2DWeightBlockLoaderGeneral( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant Conv2DGeneralJumpParams* jump_params_, + const short base_wh_, + const short base_ww_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(params_->wt_strides[0]), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj), + params(params_), + jump_params(jump_params_), + base_wh(base_wh_), + base_ww(base_ww_), + weight_h(base_wh_), + weight_w(base_ww_), + start_row(offsets.y + bi) {} + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + const device T* curr_src = src + weight_h * params->wt_strides[1] + + weight_w * params->wt_strides[2]; + + if ((start_row + BN <= params->O)) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BN; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } + } else { + for (short i = 0; i < BN; i += TROWS) { + if ((start_row + i) < params->O) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } + + METAL_FUNC void load_safe(const short remaining_k) const { + const device T* curr_src = src + weight_h * params->wt_strides[1] + + weight_w * params->wt_strides[2]; + + if ((start_row + BN <= params->O)) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BN; i += TROWS) { + if (bj + vec_size <= remaining_k) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } else { + for (short j = 0; j < vec_size; j++) { + if (bj + j < remaining_k) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } else { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } else { + for (short i = 0; i < BN; i += TROWS) { + if ((start_row + i) < params->O) { + if (bj + vec_size <= remaining_k) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } else { + for (short j = 0; j < vec_size; j++) { + if (bj + j < remaining_k) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } else { + dst[i * dst_ld + j] = T(0); + } + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + weight_w += jump_params->f_wgt_jump_w; + if (weight_w < params->wS[1]) { + return; + } + + weight_w = base_ww; + + weight_h += jump_params->f_wgt_jump_h; + if (weight_h < params->wS[0]) { + return; + } + + weight_h = base_wh; + + src += BK; + } +}; + +} // namespace steel +} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/params.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/params.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/conv/params.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/conv/params.h diff --git a/Source/Cmlx/mlx-generated/metal/steel/defines.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/defines.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/defines.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/defines.h diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h new file mode 100644 index 00000000..bbe1d96c --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h @@ -0,0 +1,295 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/gemm/loader.h" +#include "mlx/backend/metal/kernels/steel/gemm/mma.h" +#include "mlx/backend/metal/kernels/steel/gemm/params.h" +#include "mlx/backend/metal/kernels/steel/gemm/transforms.h" +#include "mlx/backend/metal/kernels/steel/utils.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernel class +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct LoopAlignment {}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct GEMMKernel { + STEEL_CONST short tgp_padding_a = 16 / sizeof(T); + STEEL_CONST short tgp_padding_b = 16 / sizeof(T); + STEEL_CONST short tgp_mem_size_a = + transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); + STEEL_CONST short tgp_mem_size_b = + transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); + STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; + + STEEL_CONST short tgp_size = WM * WN * 32; + + using loader_a_t = BlockLoader< + T, + transpose_a ? BK : BM, + transpose_a ? BM : BK, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + !transpose_a, + tgp_size>; + using loader_b_t = BlockLoader< + T, + transpose_b ? BN : BK, + transpose_b ? BK : BN, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + transpose_b, + tgp_size>; + using mma_t = BlockMMA< + T, + U, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + AccumType, + Epilogue>; + + /* Main kernel function */ + template + static METAL_FUNC void gemm_loop( + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + const int gemm_k_iterations, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + thread mma_t& mma_op, + thread const short& tgp_bm, + thread const short& tgp_bn, + thread const short& lbk, + LoopAlignment l = {}) { + // Appease the compiler + (void)l; + + short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + + short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + if (M_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(tile_dims_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + if (!K_aligned_) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + short2 tile_dims_A_last = + transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); + short2 tile_dims_B_last = + transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk); + + loader_a.load_safe(tile_dims_A_last); + loader_b.load_safe(tile_dims_B_last); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + } + + /* Main kernel function */ + static METAL_FUNC void run( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device U* D [[buffer(2)]], + const constant GEMMParams* params [[buffer(3)]], + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Pacifying compiler + (void)lid; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Loop tail + if (!K_aligned) { + int lbk = params->K - params->gemm_k_iterations_aligned * BK; + short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); + short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + + // Store results to device memory + mma_op.store_result(D, params->ldd); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; + + if (tgp_bm == BM && tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result(D, params->ldd); + return; + + } else if (tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + + } else if (tgp_bm == BM) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + + } else { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + } + } + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h new file mode 100644 index 00000000..26645cfb --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h @@ -0,0 +1,157 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/gemm/nax.h" +#include "mlx/backend/metal/kernels/steel/gemm/params.h" +#include "mlx/backend/metal/kernels/steel/gemm/transforms.h" +#include "mlx/backend/metal/kernels/steel/utils.h" + +using namespace metal; + +namespace mlx::steel { + +template < + typename T, + short SM, + short SN, + short SK, + short BK, + bool transpose_a, + bool transpose_b, + bool kAlignedM, + bool kAlignedN, + bool kAlignedK, + short UM, + short UN, + short UK, + typename AccumType = float> +auto gemm_loop( + const device T* A, + const device T* B, + int lda, + int ldb, + int K, + int gemm_k_iterations_aligned, + const short sgp_sm, + const short sgp_sn) { + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + constexpr int RA = transpose_a ? TK : TM; + constexpr int CA = transpose_a ? TM : TK; + + constexpr int RB = transpose_b ? TN : TK; + constexpr int CB = transpose_b ? TK : TN; + + using DSubTile = NAXSubTile; + using ASubTile = + NAXSubTile; + using BSubTile = + NAXSubTile; + + NAXTile Dtile; + Dtile.clear(); + + int gemm_k_iterations_ = gemm_k_iterations_aligned; + + STEEL_PRAGMA_NO_UNROLL + for (int kk0 = 0; kk0 < gemm_k_iterations_; kk0++) { + threadgroup_barrier(mem_flags::mem_none); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + const int k = kk1; + + volatile int compiler_barrier; + + const int A_offset = transpose_a ? k * lda : k; + const int B_offset = transpose_b ? k : k * ldb; + + if constexpr (kAlignedM) { + Atile.load(A + A_offset, lda); + } else { + const short rmax = transpose_a ? SK : sgp_sm; + const short cmax = transpose_a ? sgp_sm : SK; + Atile.load_safe(A + A_offset, lda, short2(cmax, rmax)); + } + + if constexpr (kAlignedN) { + Btile.load(B + B_offset, ldb); + } else { + const short rmax = transpose_b ? sgp_sn : SK; + const short cmax = transpose_b ? SK : sgp_sn; + Btile.load_safe(B + B_offset, ldb, short2(cmax, rmax)); + } + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + A += transpose_a ? (BK * lda) : BK; + B += transpose_b ? BK : (BK * ldb); + } + + if constexpr (!kAlignedK) { + simdgroup_barrier(mem_flags::mem_none); + + const short rem_bk = K - gemm_k_iterations_ * BK; + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < rem_bk; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + STEEL_PRAGMA_UNROLL + for (int mm = 0; mm < TM; mm++) { + STEEL_PRAGMA_UNROLL + for (int nn = 0; nn < TN; nn++) { + STEEL_PRAGMA_UNROLL + for (int kk = 0; kk < TK; kk++) { + const int m = mm * UM; + const int n = nn * UN; + const int k = kk1 + kk * UK; + const short psk = max(0, rem_bk - k); + + const int A_offset = transpose_a ? (m + k * lda) : (m * lda + k); + const int B_offset = transpose_b ? (k + n * ldb) : (k * ldb + n); + + { + const short psm = kAlignedM ? SM : max(0, sgp_sm - m); + const short rmax = transpose_a ? psk : psm; + const short cmax = transpose_a ? psm : psk; + Atile.load_safe(A + A_offset, lda, short2(cmax, rmax)); + } + + { + const short psn = kAlignedN ? SN : max(0, sgp_sn - n); + const short rmax = transpose_b ? psn : psk; + const short cmax = transpose_b ? psk : psn; + Btile.load_safe(B + B_offset, ldb, short2(cmax, rmax)); + } + + subtile_matmad_nax( + Dtile.subtile_at(mm, nn), + Atile.subtile_at(0, 0), + metal::bool_constant{}, + Btile.subtile_at(0, 0), + metal::bool_constant{}); + } + } + } + } + } + + return Dtile; +} + +} // namespace mlx::steel diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused_nax.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused_nax.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_gather.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_gather.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_gather_nax.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_gather_nax.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h new file mode 100644 index 00000000..cc3ddd93 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h @@ -0,0 +1,719 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/backend/metal/kernels/steel/defines.h" +using namespace metal; +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +struct _NoMask { + char x; + + constexpr METAL_FUNC operator bool() { + return true; + } + constexpr METAL_FUNC operator bool() const threadgroup { + return true; + } + constexpr METAL_FUNC operator bool() const device { + return true; + } + constexpr METAL_FUNC operator bool() const constant { + return true; + } +}; + +template +struct ScaleOp { + OutT scale; + + METAL_FUNC OutT apply(InT x) const { + return static_cast(x) * scale; + } +}; + +typedef struct _NoMask nomask_t; + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void +block_masked_gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant int* batch_shape [[buffer(6)]], + const constant int64_t* batch_strides [[buffer(7)]], + const device out_mask_t* out_mask [[buffer(10)]], + const device op_mask_t* lhs_mask [[buffer(11)]], + const device op_mask_t* rhs_mask [[buffer(12)]], + const constant int* mask_strides [[buffer(13)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Appease the compiler + (void)lid; + + static_assert( + BM == BN, + "block_masked_gemm must have the same block M and block N size"); + static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0"); + + constexpr bool has_operand_mask = !metal::is_same_v; + constexpr bool has_output_mask = !metal::is_same_v; + + constexpr bool has_mul_operand_mask = + has_operand_mask && !metal::is_same_v; + constexpr bool has_mul_output_mask = + has_output_mask && !metal::is_same_v; + + constexpr short k_mask_factor = short(BM / BK); + + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + MN_aligned, + K_aligned>; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + const constant auto* mask_batch_strides = + batch_strides + 2 * params->batch_ndim; + + if (params->batch_ndim > 1) { + if (has_output_mask) { + out_mask += elem_to_loc( + tid.z, batch_shape, mask_batch_strides, params->batch_ndim); + + mask_batch_strides += params->batch_ndim; + } + + if (has_operand_mask) { + const constant auto* mask_strides_lhs = mask_batch_strides; + const constant auto* mask_strides_rhs = + mask_strides_lhs + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, + batch_shape, + mask_strides_lhs, + mask_strides_rhs, + params->batch_ndim); + + lhs_mask += batch_offsets.x; + rhs_mask += batch_offsets.y; + } + } else { + if (has_output_mask) { + out_mask += tid.z * mask_batch_strides[0]; + mask_batch_strides += params->batch_ndim; + } + + if (has_operand_mask) { + lhs_mask += tid.z * mask_batch_strides[0]; + rhs_mask += tid.z * mask_batch_strides[params->batch_ndim]; + } + } + + // Adjust for batch + if (params->batch_ndim > 1) { + const constant auto* A_bstrides = batch_strides; + const constant auto* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + } + + D += params->batch_stride_d * tid.z; + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + const constant int* out_mask_strides = mask_strides; + const constant int* lhs_mask_strides = + mask_strides + (has_output_mask ? 2 : 0); + const constant int* rhs_mask_strides = + lhs_mask_strides + (has_operand_mask ? 2 : 0); + + const int out_mask_offset = !has_output_mask + ? 0 + : tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0]; + int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1]; + int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0]; + const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0]; + const int rhs_mask_step = !has_operand_mask ? 0 : rhs_mask_strides[1]; + short k_factor_cnt = k_mask_factor; + + ScaleOp out_mask_op; + ScaleOp lhs_mask_op; + ScaleOp rhs_mask_op; + + if (has_output_mask) { + auto mask_out = out_mask[out_mask_offset]; + + if (has_mul_output_mask) { + out_mask_op.scale = float(mask_out); + } + + // Write zeros and return + if (!mask_out) { + constexpr short tgp_size = WM * WN * 32; + constexpr short vec_size = 4; + + // Tile threads in threadgroup + constexpr short TN = BN / vec_size; + constexpr short TM = tgp_size / TN; + + const short thread_idx = simd_group_id * 32 + simd_lane_id; + const short bi = thread_idx / TN; + const short bj = vec_size * (thread_idx % TN); + + D += bi * params->ldd + bj; + + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + + if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + for (short ti = 0; ti < BM; ti += TM) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + D[ti * params->ldd + j] = T(0.); + } + } + } else { + short jmax = tgp_bn - bj; + jmax = jmax < vec_size ? jmax : vec_size; + for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) { + for (short j = 0; j < jmax; j++) { + D[ti * params->ldd + j] = T(0.); + } + } + } + + return; + } + } + + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id); + + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Prepare threadgroup loading operations + thread typename gemm_kernel::loader_a_t loader_a( + A, params->lda, As, simd_group_id, simd_lane_id); + thread typename gemm_kernel::loader_b_t loader_b( + B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup bounds + const short tgp_bm = + MN_aligned ? short(BM) : short(min(BM, params->M - c_row)); + const short tgp_bn = + MN_aligned ? short(BN) : short(min(BN, params->N - c_col)); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + /////////////////////////////////////////////////////////////////////////////// + // Do unaligned K iterations first + if (!K_aligned) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int mask_idx_last = k_last / BM; + + if (!has_operand_mask || + (bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) && + bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) { + if (has_mul_operand_mask) { + lhs_mask_op.scale = + lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]; + rhs_mask_op.scale = + rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]; + } + + // Move loader source ahead to end + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + if (has_mul_operand_mask) { + loader_a.apply_inplace_op(lhs_mask_op); + loader_b.apply_inplace_op(rhs_mask_op); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + } + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (; gemm_k_iterations > 0; gemm_k_iterations--) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (!has_operand_mask || + (bool(lhs_mask[lhs_mask_offset]) && + bool(rhs_mask[rhs_mask_offset]))) { + if (has_mul_operand_mask) { + lhs_mask_op.scale = lhs_mask[lhs_mask_offset]; + rhs_mask_op.scale = rhs_mask[rhs_mask_offset]; + } + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + if (has_mul_operand_mask) { + loader_a.apply_inplace_op(lhs_mask_op); + loader_b.apply_inplace_op(rhs_mask_op); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + } + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + + k_factor_cnt--; + lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0; + rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0; + k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt; + } + + if (has_mul_output_mask) { + mma_op.apply_epilogue(out_mask_op); + } + + // Store results to device memory + mma_op.store_result(D, params->ldd); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { + const bool M_aligned = (tgp_bm == BM); + const bool N_aligned = (tgp_bn == BN); + + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + for (; gemm_k_iterations > 0; gemm_k_iterations--) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (!has_operand_mask || + (bool(lhs_mask[lhs_mask_offset]) && + bool(rhs_mask[rhs_mask_offset]))) { + if (has_mul_operand_mask) { + lhs_mask_op.scale = lhs_mask[lhs_mask_offset]; + rhs_mask_op.scale = rhs_mask[rhs_mask_offset]; + } + + // Load elements into threadgroup + if (M_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(tile_dims_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + if (has_mul_operand_mask) { + loader_a.apply_inplace_op(lhs_mask_op); + loader_b.apply_inplace_op(rhs_mask_op); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + } + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + + k_factor_cnt--; + lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0; + rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0; + k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt; + } + + if (has_mul_output_mask) { + mma_op.apply_epilogue(out_mask_op); + } + + if (M_aligned && N_aligned) { + mma_op.store_result(D, params->ldd); + } else { + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned, + bool has_operand_mask = false> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void +block_masked_gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant int* batch_shape [[buffer(6)]], + const constant int64_t* batch_strides [[buffer(7)]], + const device bool* out_mask [[buffer(10)]], + const device bool* lhs_mask [[buffer(11)]], + const device bool* rhs_mask [[buffer(12)]], + const constant int* mask_strides [[buffer(13)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Appease the compiler + (void)lid; + + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + MN_aligned, + K_aligned>; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + if (params->batch_ndim > 1) { + const constant auto* mask_batch_strides = + batch_strides + 2 * params->batch_ndim; + out_mask += + elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim); + + if (has_operand_mask) { + const constant auto* mask_strides_lhs = + mask_batch_strides + params->batch_ndim; + const constant auto* mask_strides_rhs = + mask_strides_lhs + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, + batch_shape, + mask_strides_lhs, + mask_strides_rhs, + params->batch_ndim); + + lhs_mask += batch_offsets.x; + rhs_mask += batch_offsets.y; + } + } else { + out_mask += tid.z * batch_strides[2 * params->batch_ndim]; + if (has_operand_mask) { + lhs_mask += tid.z * batch_strides[3 * params->batch_ndim]; + rhs_mask += tid.z * batch_strides[4 * params->batch_ndim]; + } + } + + // Adjust for batch + if (params->batch_ndim > 1) { + const constant auto* A_bstrides = batch_strides; + const constant auto* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + } + + D += params->batch_stride_d * tid.z; + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]]; + + // Write zeros and return + if (!mask_out) { + constexpr short tgp_size = WM * WN * 32; + constexpr short vec_size = 4; + + // Tile threads in threadgroup + constexpr short TN = BN / vec_size; + constexpr short TM = tgp_size / TN; + + const short thread_idx = simd_group_id * 32 + simd_lane_id; + const short bi = thread_idx / TN; + const short bj = vec_size * (thread_idx % TN); + + D += bi * params->ldd + bj; + + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + + if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + for (short ti = 0; ti < BM; ti += TM) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + D[ti * params->ldd + j] = T(0.); + } + } + } else { + short jmax = tgp_bn - bj; + jmax = jmax < vec_size ? jmax : vec_size; + for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) { + for (short j = 0; j < jmax; j++) { + D[ti * params->ldd + j] = T(0.); + } + } + } + + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Prepare threadgroup loading operations + thread typename gemm_kernel::loader_a_t loader_a( + A, params->lda, As, simd_group_id, simd_lane_id); + thread typename gemm_kernel::loader_b_t loader_b( + B, params->ldb, Bs, simd_group_id, simd_lane_id); + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (!has_operand_mask || + (lhs_mask + [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] && + rhs_mask + [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + } + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Loop tail + if (!K_aligned) { + if (!has_operand_mask || + (lhs_mask + [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && + rhs_mask + [(params->K / BM) * mask_strides[5] + + tid_x * mask_strides[4]])) { + int lbk = params->K - params->gemm_k_iterations_aligned * BK; + short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); + short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + } + + // Store results to device memory + mma_op.store_result(D, params->ldd); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short lbk = params->K - params->gemm_k_iterations_aligned * BK; + + bool M_aligned = (tgp_bm == BM); + bool N_aligned = (tgp_bn == BN); + + short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (!has_operand_mask || + (lhs_mask + [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] && + rhs_mask + [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { + // Load elements into threadgroup + if (M_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(tile_dims_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + } + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + if (!K_aligned) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (!has_operand_mask || + (lhs_mask + [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && + rhs_mask + [(params->K / BM) * mask_strides[5] + + tid_x * mask_strides[4]])) { + short2 tile_dims_A_last = + transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); + short2 tile_dims_B_last = + transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk); + + loader_a.load_safe(tile_dims_A_last); + loader_b.load_safe(tile_dims_B_last); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + } + + if (M_aligned && N_aligned) { + mma_op.store_result(D, params->ldd); + } else { + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_segmented.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_segmented.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_splitk.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_splitk.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_splitk_nax.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_splitk_nax.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h new file mode 100644 index 00000000..d421b2d1 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h @@ -0,0 +1,137 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/defines.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short alignment = 1, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoader { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; + + // Leading dimension for src + const int src_ld; + const int tile_stride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + struct alignas(alignment * sizeof(T)) ReadVector { + uint8_t v[sizeof(T) * vec_size]; + }; + + /* Constructor */ + METAL_FUNC BlockLoader( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride(reduction_dim ? BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj) {} + + /* Apply operation to threadgroup without bound checking */ + template + METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]); + } + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + *((threadgroup ReadVector*)(&dst[i * dst_ld])) = + *((const device ReadVector*)(&src[i * src_ld])); + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); + + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + return; + } + + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } + + // Zero out unneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } +}; + +} // namespace steel +} // namespace mlx diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h new file mode 100644 index 00000000..11319940 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h @@ -0,0 +1,1146 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include +#include +#include + +#include "mlx/backend/metal/kernels/steel/defines.h" +#include "mlx/backend/metal/kernels/steel/gemm/transforms.h" +#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct BaseMMAFrag { + static_assert( + kFragRows_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); + static_assert( + kFragCols_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); +}; + +template +struct BaseMMAFrag { + STEEL_CONST int kFragRows = 8; + STEEL_CONST int kFragCols = 8; + + STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32; + + STEEL_CONST int kElemRows = 1; + STEEL_CONST int kElemCols = 2; + + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag size"); + + typedef metal::simdgroup_matrix mat_type; + typedef metal::vec frag_type; + + METAL_FUNC static constexpr short2 get_coord( + ushort simd_lane_id [[thread_index_in_simdgroup]]) { + const short qid = simd_lane_id / 4; + const short fm = (qid & 4) + ((simd_lane_id / 2) % 4); + const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + return short2{fn, fm}; + } + + template + METAL_FUNC static constexpr void + load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[i * str_x + j * str_y]); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void load_safe( + thread frag_type& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[i * kElemCols + j] = + static_cast(src[(off_x + i) * str_x + (off_x + j) * str_y]); + } else { + dst[i * kElemCols + j] = T(0); + } + } + } + } + + template + METAL_FUNC static constexpr void + store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * str_x + j * str_y] = static_cast(src[i * kElemCols + j]); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void store_safe( + const thread frag_type& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[(off_x + i) * str_x + (off_y + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename StartX, + typename StopX, + typename StartY, + typename StopY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void store_slice( + const thread frag_type& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + StartX start_x, + StopX stop_x, + StartY start_y, + StopY stop_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < stop_x && (off_x + i) >= start_x && + (off_y + j) < stop_y && (off_y + j) >= start_y) { + dst[(off_x + i) * str_x + (off_y + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + METAL_FUNC static constexpr void mma( + thread frag_type& D, + thread frag_type& A, + thread frag_type& B, + thread frag_type& C) { + mat_type D_mat; + mat_type A_mat; + mat_type B_mat; + mat_type C_mat; + + reinterpret_cast(A_mat.thread_elements()) = A; + reinterpret_cast(B_mat.thread_elements()) = B; + reinterpret_cast(C_mat.thread_elements()) = C; + + mma(D_mat, A_mat, B_mat, C_mat); + + D = reinterpret_cast(D_mat.thread_elements()); + } + + METAL_FUNC static constexpr void mma( + thread mat_type& D, + thread mat_type& A, + thread mat_type& B, + thread mat_type& C) { + simdgroup_multiply_accumulate(D, A, B, C); + } +}; + +template < + typename T, + int kTileRows_, + int kTileCols_, + class MMAFrag_ = BaseMMAFrag> +struct MMATile { + using MMAFrag_t = MMAFrag_; + using elem_type = T; + STEEL_CONST int kFragRows = MMAFrag_t::kFragRows; + STEEL_CONST int kFragCols = MMAFrag_t::kFragCols; + STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag; + + STEEL_CONST int kTileRows = kTileRows_; + STEEL_CONST int kTileCols = kTileCols_; + + STEEL_CONST int kRows = kTileRows * kFragRows; + STEEL_CONST int kCols = kTileCols * kFragCols; + + STEEL_CONST int kNumFrags = kTileRows * kTileCols; + STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag; + + typedef typename MMAFrag_t::mat_type mat_type; + typedef typename MMAFrag_t::frag_type frag_type; + + frag_type val_frags[kNumFrags] = {frag_type(0)}; + + METAL_FUNC MMATile() thread {} + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC constexpr const thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC mat_type mat_at(const short i, const short j) { + mat_type val_mat; + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < kElemsPerFrag; ++ii) { + val_mat.thread_elements()[ii] = frag_at(i, j)[ii]; + } + return val_mat; + } + + METAL_FUNC thread elem_type* elems() { + return reinterpret_cast(val_frags); + } + + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_frags); + } + + template + METAL_FUNC void load(const threadgroup U* src) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &( + src[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void store(threadgroup U* dst) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &( + dst[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void load(const device U* src, const int ld) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void store(device U* dst, const int ld) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::load_safe( + frag_at(i, j), + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } + + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::store_safe( + frag_at(i, j), + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } + + template + METAL_FUNC void store_slice( + device U* dst, + const int ld, + const short2 start, + const short2 stop) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::store_slice( + frag_at(i, j), + dst, + ld, + Int<1>{}, + start.y, + stop.y, + start.x, + stop.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } +}; + +template +METAL_FUNC void tile_matmad( + thread MMATile& D, + thread MMATile& A, + thread MMATile& B, + thread MMATile& C) { + STEEL_PRAGMA_UNROLL + for (short m = 0; m < M; ++m) { + STEEL_PRAGMA_UNROLL + for (short n = 0; n < N; ++n) { + short n_serp = (m % 2) ? (N - 1 - n) : n; + STEEL_PRAGMA_UNROLL + for (short k = 0; k < K; ++k) { + MMATile::MMAFrag_t::mma( + D.frag_at(m, n_serp), + A.frag_at(m, k), + B.frag_at(k, n_serp), + C.frag_at(m, n_serp)); + } + } + } +} + +template +struct TransformNone { + static METAL_FUNC complex64_t apply(complex64_t x) { + return x; + } + static METAL_FUNC complex64_t apply(complex64_t x, complex64_t) { + return x; + } +}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + short lda_tgp, + short ldb_tgp, + typename AccumType = float, + typename Epilogue = TransformNone> +struct BlockMMA { + // MMAFrag size + STEEL_CONST short kFragSize = 8; + using MMAFrag_acc_t = BaseMMAFrag; + + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TM_stride = kFragSize * WM; + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TN_stride = kFragSize * WN; + + // Warp tile size along M + STEEL_CONST short TM = BM / (kFragSize * WM); + // Warp tile size along N + STEEL_CONST short TN = BN / (kFragSize * WN); + + // Threadgroup A strides + STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M + STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K + + // Threadgroup B strides + STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K + STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N + + // Threadgroup strides along K + STEEL_CONST short tile_stride_a = kFragSize * A_str_k; + STEEL_CONST short tile_stride_b = kFragSize * B_str_k; + + // Simdgroup matrices + MMATile Atile; + MMATile Btile; + MMATile Ctile; + + // Offsets within threadgroup + short sm; + short sn; + + short As_offset; + short Bs_offset; + + /* Constructor */ + METAL_FUNC BlockMMA( + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) { + // Determine thread position in simdgroup matrix + short tm = kFragSize * (simd_group_id / WN); + short tn = kFragSize * (simd_group_id % WN); + + short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + sm = simd_coord.y; + sn = simd_coord.x; + + // Determine thread and simdgroup offset + As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K + Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N + + sm += tm; + sn += tn; + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { + // Adjust for simdgroup and thread location + As += As_offset; + Bs += Bs_offset; + + // Iterate over BK in blocks of kFragSize + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < BK; kk += kFragSize) { + simdgroup_barrier(mem_flags::mem_none); + + Atile.template load(As); + + simdgroup_barrier(mem_flags::mem_none); + + Btile.template load(Bs); + + simdgroup_barrier(mem_flags::mem_none); + + tile_matmad(Ctile, Atile, Btile, Ctile); + + // Progress to next simdgroup tile + As += tile_stride_a; + Bs += tile_stride_b; + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device U* D, const int ldd) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + + // Adjust for simdgroup and thread location + D += sm * ldd + sn; + + Ctile.template store(D, ldd); + } + + METAL_FUNC void + store_result_slice(device U* D, const int ldd, short2 start, short2 stop) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + + D += sm * ldd + sn; + start -= short2(sn, sm); + stop -= short2(sn, sm); + + // TODO: Check the start as well + if (stop.y <= 0 || stop.x <= 0) { + return; + } + + Ctile.template store_slice(D, ldd, start, stop); + } + + METAL_FUNC void + store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + + // Adjust for simdgroup and thread location + D += sm * ldd + sn; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + Ctile.template store_safe(D, ldd, dst_tile_dims); + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]); + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue( + const device U* C, + const int ldc, + const int fdc, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) { + accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } + } + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue_safe( + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + + // Read C + U c_elems[kelems] = {0}; + + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + c_elems[k] = C[offset_c + k * fdc]; + } + } + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + accum[k] = epilogue_op.apply(accum[k], c_elems[k]); + } + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } + } + } + } + + METAL_FUNC void store_result_safe( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + D[offset_d + k] = + epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } + } + } + } + } + } +}; + +template < + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + short lda_tgp, + short ldb_tgp, + typename AccumType, + typename Epilogue> +struct BlockMMA< + complex64_t, + U, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + lda_tgp, + ldb_tgp, + AccumType, + Epilogue> { + static_assert( + metal::is_same_v, + "BlockMMA expects float accumulators"); + static_assert( + metal::is_same_v, + "For complex BlockMMA, U must be complex64_t; use a different epilogue for projections"); + // MMAFrag size + STEEL_CONST short kFragSize = 8; + using MMAFrag_acc_t = BaseMMAFrag; + + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TM_stride = kFragSize * WM; + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TN_stride = kFragSize * WN; + + // Warp tile size along M + STEEL_CONST short TM = BM / (kFragSize * WM); + // Warp tile size along N + STEEL_CONST short TN = BN / (kFragSize * WN); + + // Threadgroup A strides + STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M + STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K + + // Threadgroup B strides + STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K + STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N + + // Threadgroup strides along K + STEEL_CONST short tile_stride_a = kFragSize * A_str_k; + STEEL_CONST short tile_stride_b = kFragSize * B_str_k; + + // When indexing complex as float[2] + STEEL_CONST short A_str_m_f = A_str_m * 2; + STEEL_CONST short A_str_k_f = A_str_k * 2; + STEEL_CONST short B_str_k_f = B_str_k * 2; + STEEL_CONST short B_str_n_f = B_str_n * 2; + STEEL_CONST short tile_stride_a_f = tile_stride_a * 2; + STEEL_CONST short tile_stride_b_f = tile_stride_b * 2; + + // Accumulators (real/imag) + MMATile Ctile_r; + MMATile Ctile_i; + + // Offsets within threadgroup + short sm, sn; + short As_offset, Bs_offset; + + /* Constructor */ + METAL_FUNC BlockMMA( + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) { + // Determine thread position in simdgroup matrix + short tm = kFragSize * (simd_group_id / WN); + short tn = kFragSize * (simd_group_id % WN); + + short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + sm = simd_coord.y; + sn = simd_coord.x; + + // Determine thread and simdgroup offset + As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // (M,K) + Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // (K,N) + + sm += tm; + sn += tn; + } + + /* Karatsuba MMA: 3 real MMAs per K-chunk */ + METAL_FUNC void mma( + const threadgroup complex64_t* As, + const threadgroup complex64_t* Bs) { + // Adjust for simdgroup and thread location + As += As_offset; + Bs += Bs_offset; + threadgroup const float* As_f = + reinterpret_cast(As); + threadgroup const float* Bs_f = + reinterpret_cast(Bs); + + // Iterate over BK in blocks of kFragSize + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < BK; kk += kFragSize) { + simdgroup_barrier(mem_flags::mem_none); + + MMATile Ar, Ai; + Ar.template load(As_f + 0); + Ai.template load(As_f + 1); + + simdgroup_barrier(mem_flags::mem_none); + + MMATile Br, Bi; + Br.template load(Bs_f + 0); + Bi.template load(Bs_f + 1); + + simdgroup_barrier(mem_flags::mem_none); + + // P = Ar*Br ; Q = Ai*Bi ; R = (Ar+Ai)*(Br+Bi) + MMATile P, Q, R; + + tile_matmad(P, Ar, Br, P); + tile_matmad(Q, Ai, Bi, Q); + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ar)::kElemsPerTile; ++i) + Ar.elems()[i] += Ai.elems()[i]; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Br)::kElemsPerTile; ++i) + Br.elems()[i] += Bi.elems()[i]; + + tile_matmad(R, Ar, Br, R); + + // C_r += P - Q ; C_i -= Q + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; ++i) { + const auto p = P.elems()[i]; + const auto q = Q.elems()[i]; + const auto r = R.elems()[i]; + Ctile_r.elems()[i] += (p - q); + Ctile_i.elems()[i] += (r - p - q); + } + + // Progress to next simdgroup tile + As_f += tile_stride_a_f; + Bs_f += tile_stride_b_f; + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device U* D, const int ldd) { + // Adjust for simdgroup and thread location + D += sm * ldd + sn; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + thread const auto& r = Ctile_r.frag_at(i, j); + thread const auto& im = Ctile_i.frag_at(i, j); + int off = (i * TM_stride) * ldd + (j * TN_stride); + STEEL_PRAGMA_UNROLL + for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) { + D[off + k] = Epilogue::apply(complex64_t(r[k], im[k])); + } + } + } + } + + METAL_FUNC void + store_result_slice(device U* D, const int ldd, short2 start, short2 stop) { + D += sm * ldd + sn; + start -= short2(sn, sm); + stop -= short2(sn, sm); + + if (stop.y <= 0 || stop.x <= 0) + return; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; ++i) { + const int row = i * TM_stride; + if (row >= start.y && row < stop.y) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; ++j) { + const int off = row * ldd + (j * TN_stride); + thread const auto& r = Ctile_r.frag_at(i, j); + thread const auto& im = Ctile_i.frag_at(i, j); + + STEEL_PRAGMA_UNROLL + for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; ++k) { + const int col = j * TN_stride + k; + if (col >= start.x && col < stop.x) { + D[off + k] = Epilogue::apply(complex64_t(r[k], im[k])); + } + } + } + } + } + } + + METAL_FUNC void + store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { + D += sm * ldd + sn; + dst_tile_dims -= short2(sn, sm); + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + int off = (i * TM_stride) * ldd + (j * TN_stride); + thread const auto& r = Ctile_r.frag_at(i, j); + thread const auto& im = Ctile_i.frag_at(i, j); + STEEL_PRAGMA_UNROLL + for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + D[off + k] = Epilogue::apply(complex64_t(r[k], im[k])); + } + } + } + } + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; i++) { + complex64_t out = epilogue_op.apply( + complex64_t(Ctile_r.elems()[i], Ctile_i.elems()[i])); + Ctile_r.elems()[i] = out.real; + Ctile_i.elems()[i] = out.imag; + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue( + const device U* C, + const int ldc, + const int fdc, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in Cr, Ci + thread auto& r = Ctile_r.frag_at(i, j); + thread auto& im = Ctile_i.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + STEEL_PRAGMA_UNROLL + for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) { + complex64_t out = epilogue_op.apply( + complex64_t(r[k], im[k]), C[offset_c + k * fdc]); + r[k] = out.real; + im[k] = out.imag; + } + } + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue_safe( + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in Cr, Ci + thread auto& r = Ctile_r.frag_at(i, j); + thread auto& im = Ctile_i.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag; + complex64_t tmp[kelems]; + + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x && + (i * TM_stride) < dst_tile_dims.y) { + tmp[k] = C[offset_c + k * fdc]; + } else { + tmp[k] = complex64_t(0.0f, 0.0f); + } + } + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + complex64_t out = epilogue_op.apply(complex64_t(r[k], im[k]), tmp[k]); + r[k] = out.real; + im[k] = out.imag; + } + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + + constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in Cr, Ci + thread const auto& r = Ctile_r.frag_at(i, j); + thread const auto& im = Ctile_i.frag_at(i, j); + int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int off_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + D[off_d + k] = + epilogue_op.apply(complex64_t(r[k], im[k]), C[off_c + k * fdc]); + } + } + } + } + + METAL_FUNC void store_result_safe( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag; + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in Cr, Ci + thread const auto& r = Ctile_r.frag_at(i, j); + thread const auto& im = Ctile_i.frag_at(i, j); + int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int off_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + D[off_d + k] = epilogue_op.apply( + complex64_t(r[k], im[k]), C[off_c + k * fdc]); + } + } + } + } + } + } +}; + +} // namespace steel +} // namespace mlx diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h new file mode 100644 index 00000000..5839176c --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h @@ -0,0 +1,1084 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +#include "mlx/backend/metal/kernels/steel/defines.h" +#include "mlx/backend/metal/kernels/steel/gemm/transforms.h" +#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h" + +#include + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +/////////////////////////////////////////////////////////////////////////////// +// NAX Steel with new tiles +/////////////////////////////////////////////////////////////////////////////// + +struct BaseNAXFrag { + STEEL_CONST short kFragRows = 16; + STEEL_CONST short kFragCols = 16; + + STEEL_CONST short kElemsPerFrag = (kFragRows * kFragCols) / 32; + + STEEL_CONST short kElemRows = 2; + STEEL_CONST short kElemCols = 4; + + STEEL_CONST short kElemRowsJump = 8; + + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag size"); + + template + using dtype_frag_t = typename metal::vec; + + METAL_FUNC static short2 get_coord() { + const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); + const short qid = simd_lane_id >> 2; + const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)); + const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4; + return short2{fn, fm}; + } + + METAL_FUNC static short2 get_coord(short idx) { + const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); + const short qid = simd_lane_id >> 2; + const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)) + (idx >> 2) * 8; + const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4 + idx % 4; + return short2{fn, fm}; + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[r * str_x + c + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } + } + } + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load_rows( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if (r < lim_x) { + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[r * str_x + (c + j)]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } + } + + } else { + dst = dtype_frag_t(0); + } + } + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load_safe( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if (r < lim_x && (c + j) < lim_y) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } else { + dst[i * kElemCols + j] = T(0); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_rows( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if (r < lim_x) { + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_safe( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if (r < lim_x && (c + j) < lim_y) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename StartX, + typename StopX, + typename StartY, + typename StopY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_slice( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + StartX start_x, + StopX stop_x, + StartY start_y, + StopY stop_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + const short2 sc = get_coord(); + + const_for_loop<0, kElemRows, 1>([&](auto idx_row) { + const auto r = off_x + idx_row * Int{}; + if (r >= stop_x - sc.y || r < start_x - sc.y) { + return; + } + + const_for_loop<0, kElemCols, 1>([&](auto idx_col) { + const auto c = off_y + idx_col; + if (c >= stop_y - sc.x || c < start_y - sc.x) { + return; + } + + const auto src_idx = idx_row * Int{} + idx_col; + dst[(r + sc.y) * str_x + (c + sc.x) * str_y] = + static_cast(src[src_idx]); + }); + }); + } + + template + METAL_FUNC static constexpr void row_reduce( + thread const dtype_frag_t& inp_vals, + thread T* reduced_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + T thr_reduce = Op::apply( + Op::apply(inp_vals[i * kElemCols + 0], inp_vals[i * kElemCols + 1]), + Op::apply(inp_vals[i * kElemCols + 2], inp_vals[i * kElemCols + 3])); + + T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1)); + qgr_reduce = Op::apply(thr_reduce, qgr_reduce); + + T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8)); + sgr_reduce = Op::apply(qgr_reduce, sgr_reduce); + + reduced_vals[i] = Op::apply(reduced_vals[i], sgr_reduce); + } + } + + template + METAL_FUNC static constexpr void row_bin_op( + thread dtype_frag_t& inp_vals, + thread T* row_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + inp_vals[i * kElemCols + j] = + Op::apply(inp_vals[i * kElemCols + j], row_vals[i]); + } + } + } +}; + +template < + typename T, + short kRows_, + short kCols_, + typename NAXFrag_t = BaseNAXFrag> +struct NAXSubTile { + STEEL_CONST short kRows = kRows_; + STEEL_CONST short kCols = kCols_; + + STEEL_CONST short kFragRows = NAXFrag_t::kFragRows; + STEEL_CONST short kFragCols = NAXFrag_t::kFragCols; + STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag; + + STEEL_CONST short kSubTileRows = kRows / kFragRows; + STEEL_CONST short kSubTileCols = kCols / kFragCols; + + STEEL_CONST short kNumFrags = kSubTileRows * kSubTileCols; + STEEL_CONST short kElemsPerSubTile = kNumFrags * kElemsPerFrag; + + STEEL_CONST int kRowsPerThread = kSubTileRows * NAXFrag_t::kElemRows; + STEEL_CONST int kColsPerThread = kSubTileCols * NAXFrag_t::kElemCols; + + STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows; + STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols; + STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump; + + using frag_type = typename NAXFrag_t::template dtype_frag_t; + + frag_type val_frags[kNumFrags]; + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kSubTileCols + j]; + } + + METAL_FUNC constexpr const thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kSubTileCols + j]; + } + + template + METAL_FUNC constexpr thread frag_type& frag_at() { + return val_frags[i * kSubTileCols + j]; + } + + template + METAL_FUNC constexpr const thread frag_type& frag_at() const { + return val_frags[i * kSubTileCols + j]; + } + + METAL_FUNC thread T* elems() { + return reinterpret_cast(val_frags); + } + + METAL_FUNC const thread T* elems() const { + return reinterpret_cast(val_frags); + } + + template + METAL_FUNC void row_reduce(thread metal::vec& vals) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::template row_reduce( + frag_at(i, j), &vals[i * kFragThrRows]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread metal::vec& vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::template row_bin_op( + frag_at(i, j), &vals[i * kFragThrRows]); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load( + SrcPtrType src, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load( + frag_at(i, j), + src, + str_x, + str_y, + off_x + i * kFragRows, + off_y + j * kFragCols); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store( + DstPtrType dst, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store( + frag_at(i, j), + dst, + str_x, + str_y, + off_x + i * kFragRows, + off_y + j * kFragCols); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load_rows( + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load_rows( + frag_at(i, j), + src, + str_x, + str_y, + lim_x, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load_safe( + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load_safe( + frag_at(i, j), + src, + str_x, + str_y, + lim_x, + lim_y, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_safe( + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store_safe( + frag_at(i, j), + dst, + str_x, + str_y, + lim_x, + lim_y, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_rows( + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store_safe( + frag_at(i, j), + dst, + str_x, + str_y, + lim_x, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename StartX, + typename StopX, + typename StartY, + typename StopY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_slice( + DstPtrType dst, + StrX str_x, + StrY str_y, + StartX start_x, + StopX stop_x, + StartY start_y, + StopY stop_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) const { + const_for_loop<0, kSubTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kSubTileCols, 1>([&](auto idx_col) { + NAXFrag_t::store_slice( + frag_at(), + dst, + str_x, + str_y, + start_x, + stop_x, + start_y, + stop_y, + off_x + idx_row * Int{}, + off_y + idx_col * Int{}); + }); + }); + } +}; + +template < + short RC, + short CC, + short RA, + short CA, + short RB, + short CB, + typename CType, + typename AType, + typename BType, + bool transpose_a, + bool transpose_b, + typename NAXFrag_t = BaseNAXFrag> +METAL_FUNC void subtile_matmad_nax( + thread NAXSubTile& C, + thread NAXSubTile& A, + metal::bool_constant, + thread NAXSubTile& B, + metal::bool_constant) { + // Static checks + constexpr short FMa = transpose_a ? CA : RA; + constexpr short FMc = RC; + static_assert(FMa == FMc, "NAX matmul: M dimensions do not match"); + + constexpr short FNb = transpose_b ? RB : CB; + constexpr short FNc = CC; + static_assert(FNb == FNc, "NAX matmul: N dimensions do not match"); + + constexpr short FKa = transpose_a ? RA : CA; + constexpr short FKb = transpose_b ? CB : RB; + static_assert(FKa == FKb, "NAX matmul: N dimensions do not match"); + + constexpr short FM = FMc; + constexpr short FN = FNc; + constexpr short FK = FKa; + + constexpr int TM = FM / 16; + constexpr int TN = FN / 16; + constexpr int TK = FK / 16; + + // Create Matmul descriptor + constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor( + FM, + FN, + FK, + transpose_a, + transpose_b, + true, + mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); + + // Create matmul op + mpp::tensor_ops::matmul2d gemm_op; + + // Create matmul operands in registers + auto ct_a = + gemm_op.template get_left_input_cooperative_tensor(); + auto ct_b = + gemm_op + .template get_right_input_cooperative_tensor(); + + // Create matmul output in register + auto ct_c = gemm_op.template get_destination_cooperative_tensor< + decltype(ct_a), + decltype(ct_b), + CType>(); + + // Load A in to left operand registers + STEEL_PRAGMA_UNROLL + for (short mm = 0; mm < TM; mm++) { + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < TK; kk++) { + const short fi = transpose_a ? kk : mm; + const short fj = transpose_a ? mm : kk; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 8; i++) { + ct_a[(TK * mm + kk) * 8 + i] = A.frag_at(fi, fj)[i]; + } + } + } + + // Load B into right operand registers + STEEL_PRAGMA_UNROLL + for (short nn = 0; nn < TN; nn++) { + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < TK; kk++) { + const short fi = transpose_b ? nn : kk; + const short fj = transpose_b ? kk : nn; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 8; i++) { + ct_b[(TN * kk + nn) * 8 + i] = B.frag_at(fi, fj)[i]; + } + } + } + + // Load C into output registers (op handles accumulation) + STEEL_PRAGMA_UNROLL + for (short i = 0; i < ct_c.get_capacity(); i++) { + ct_c[i] = C.elems()[i]; + } + + // Do matmul + gemm_op.run(ct_a, ct_b, ct_c); + + // Copy out results + STEEL_PRAGMA_UNROLL + for (short i = 0; i < ct_c.get_capacity(); i++) { + C.elems()[i] = ct_c[i]; + } +} + +template +struct NAXTile { + using NAXSubTile_t = NAXSubTile_; + using elem_type = T; + STEEL_CONST short kSubTileRows = NAXSubTile_t::kRows; + STEEL_CONST short kSubTileCols = NAXSubTile_t::kCols; + STEEL_CONST short kElemsPerSubTile = NAXSubTile_t::kElemsPerSubTile; + + STEEL_CONST short kTileRows = kTileRows_; + STEEL_CONST short kTileCols = kTileCols_; + + STEEL_CONST short kRows = kTileRows * kSubTileRows; + STEEL_CONST short kCols = kTileCols * kSubTileCols; + + STEEL_CONST short kSubTiles = kTileRows * kTileCols; + STEEL_CONST short kElemsPerTile = kSubTiles * kElemsPerSubTile; + + STEEL_CONST short kRowsPerThread = kTileRows * NAXSubTile_t::kRowsPerThread; + STEEL_CONST short kColsPerThread = kTileCols * NAXSubTile_t::kColsPerThread; + + STEEL_CONST short kSubTileThrRows = NAXSubTile_t::kRowsPerThread; + STEEL_CONST short kSubTileThrCols = NAXSubTile_t::kColsPerThread; + + NAXSubTile_t val_subtiles[kSubTiles]; + + METAL_FUNC NAXTile() thread {} + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTiles; ++i) { + val_subtiles[i].clear(); + } + } + + METAL_FUNC constexpr thread NAXSubTile_t& subtile_at( + const short i, + const short j) { + return val_subtiles[i * kTileCols + j]; + } + + METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at( + const short i, + const short j) const { + return val_subtiles[i * kTileCols + j]; + } + + template + METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at() const { + return val_subtiles[i * kTileCols + j]; + } + + METAL_FUNC thread elem_type* elems() { + return reinterpret_cast(val_subtiles[0].elems()); + } + + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_subtiles[0].elems()); + } + + template + METAL_FUNC void row_reduce(thread metal::vec& vals) const { + auto sub_rows = (thread metal::vec*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).template row_reduce(sub_rows[i]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread metal::vec& vals) { + auto sub_rows = (thread metal::vec*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).template row_bin_op(sub_rows[i]); + } + } + } + + template + METAL_FUNC void load(const threadgroup U* src) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load( + src, + Int{}, + Int{}, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void store(threadgroup U* dst) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store( + dst, + Int{}, + Int{}, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void load(const device U* src, const int ld) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load( + &src[(i * kSubTileRows * ld + j * kSubTileCols)], ld, Int<1>{}); + } + } + } + + template + METAL_FUNC void store(device U* dst, const int ld) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store( + &dst[(i * kSubTileRows * ld + j * kSubTileCols)], ld, Int<1>{}); + } + } + } + + template + METAL_FUNC void + load_rows(const device U* src, const int ld, const short n_rows) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load_rows( + &src[(i * kSubTileRows) * ld + (j * kSubTileCols)], + ld, + Int<1>{}, + n_rows - i * kSubTileRows); + } + } + } + + template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load_safe( + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void store_rows(device U* dst, const int ld, const short n_rows) + const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store_rows( + &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)], + ld, + Int<1>{}, + n_rows - i * kSubTileRows); + } + } + } + + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store_safe( + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void store_slice( + device U* dst, + const int ld, + const short2 start, + const short2 stop) const { + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + subtile_at().store_slice( + dst, + ld, + Int<1>{}, + start.y, + stop.y, + start.x, + stop.x, + idx_row * Int{}, + idx_col * Int{}); + }); + }); + } +}; + +template < + class CTile, + class ATile, + class BTile, + bool transpose_a, + bool transpose_b> +METAL_FUNC void tile_matmad_nax( + thread CTile& C, + thread ATile& A, + metal::bool_constant, + thread BTile& B, + metal::bool_constant) { + // Static checks + constexpr short TMa = transpose_a ? ATile::kTileCols : ATile::kTileRows; + constexpr short TMc = CTile::kTileRows; + static_assert(TMa == TMc, "NAX tile matmul: M dimensions do not match"); + + constexpr short FMa = transpose_a ? ATile::kSubTileCols : ATile::kSubTileRows; + constexpr short FMc = CTile::kSubTileRows; + static_assert(FMa == FMc, "NAX subtile matmul: M dimensions do not match"); + + constexpr short TNb = transpose_b ? BTile::kTileRows : BTile::kTileCols; + constexpr short TNc = CTile::kTileCols; + static_assert(TNb == TNc, "NAX tile matmul: N dimensions do not match"); + + constexpr short FNb = transpose_b ? BTile::kSubTileRows : BTile::kSubTileCols; + constexpr short FNc = CTile::kSubTileCols; + static_assert(FNb == FNc, "NAX subtile matmul: N dimensions do not match"); + + constexpr short TKa = transpose_a ? ATile::kTileRows : ATile::kTileCols; + constexpr short TKb = transpose_b ? BTile::kTileCols : BTile::kTileRows; + static_assert(TKa == TKb, "NAX tile matmul: K dimensions do not match"); + + constexpr short FKa = transpose_a ? ATile::kSubTileRows : ATile::kSubTileCols; + constexpr short FKb = transpose_b ? BTile::kSubTileCols : BTile::kSubTileRows; + static_assert(FKa == FKb, "NAX subtile matmul: K dimensions do not match"); + + constexpr short TM = TMc; + constexpr short TN = TNc; + constexpr short TK = TKa; + + // Do matmul here + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; ++j) { + STEEL_PRAGMA_UNROLL + for (short k = 0; k < TK; ++k) { + const short ra = transpose_a ? k : i; + const short ca = transpose_a ? i : k; + const short rb = transpose_b ? j : k; + const short cb = transpose_b ? k : j; + + subtile_matmad_nax( + C.subtile_at(i, j), + A.subtile_at(ra, ca), + metal::bool_constant{}, + B.subtile_at(rb, cb), + metal::bool_constant{}); + } + } + } +} + +} // namespace steel +} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/params.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/params.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/gemm/params.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/params.h diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h new file mode 100644 index 00000000..0282a122 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h @@ -0,0 +1,72 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/utils.h" + +/////////////////////////////////////////////////////////////////////////////// +// Transforms and Epilogues +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct TransformNone { + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT) { + return static_cast(x); + } +}; + +template +struct TransformAdd { + TransformAdd(const float, const float) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT c) { + return static_cast(x) + c; + } +}; + +template +struct TransformAxpby { + const float alpha; + const float beta; + + TransformAxpby(const float alpha_, const float beta_) + : alpha(alpha_), beta(beta_) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + METAL_FUNC OutT apply(InT x, OutT c) const { + return static_cast( + x * static_cast(alpha) + (static_cast(beta) * c)); + } +}; + +template +struct AccumHelper { + typedef float accum_type; +}; + +struct BlockSwizzle { + static METAL_FUNC int2 + swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { + const int tid_x = (tid.x) >> swizzle_log; + const int tid_y = + ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); + return int2(tid_x, tid_y); + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/steel/utils.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/utils.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/utils.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/utils.h diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h new file mode 100644 index 00000000..526f561e --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h @@ -0,0 +1,134 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include +#include "mlx/backend/metal/kernels/steel/utils/type_traits.h" + +#pragma METAL internals : enable + +namespace mlx { +namespace steel { + +/////////////////////////////////////////////////////////////////////////////// +// Integral constant with casting +/////////////////////////////////////////////////////////////////////////////// + +template +struct integral_constant { + static constexpr constant T value = v; + using value_type = T; + using type = integral_constant; + + METAL_FUNC constexpr operator value_type() const noexcept { + return value; + } + + // METAL_FUNC constexpr value_type operator()() const noexcept { + // return value; + // } +}; + +template +using bool_constant = integral_constant; +using true_type = bool_constant; +using false_type = bool_constant; + +template +struct is_integral : bool_constant::value> {}; + +template +struct is_integral> + : bool_constant::value> {}; + +template +constexpr constant bool is_integral_v = is_integral::value; + +template +using Int = integral_constant; + +/////////////////////////////////////////////////////////////////////////////// +// Binary Operators on Integral constants +/////////////////////////////////////////////////////////////////////////////// + +#define integral_const_binop(__op__, __operator__) \ + template \ + METAL_FUNC constexpr auto __operator__( \ + integral_constant, integral_constant) { \ + constexpr auto res = tv __op__ uv; \ + return integral_constant{}; \ + } + +integral_const_binop(+, operator+); +integral_const_binop(-, operator-); +integral_const_binop(*, operator*); +integral_const_binop(/, operator/); + +integral_const_binop(==, operator==); +integral_const_binop(!=, operator!=); +integral_const_binop(<, operator<); +integral_const_binop(>, operator>); +integral_const_binop(<=, operator<=); +integral_const_binop(>=, operator>=); + +integral_const_binop(&&, operator&&); +integral_const_binop(||, operator||); + +template >> +METAL_FUNC constexpr auto operator||(true_type, T) { + return true_type{}; +} +template >> +METAL_FUNC constexpr auto operator||(T, true_type) { + return true_type{}; +} + +template >> +METAL_FUNC constexpr auto operator&&(false_type, T) { + return false_type{}; +} + +template >> +METAL_FUNC constexpr auto operator&&(T, false_type) { + return false_type{}; +} + +// Dispatch utilities +template +void dispatch_bool(bool v, F f) { + if (v) { + f(true_type{}); + } else { + f(false_type{}); + } +} + +template +constexpr void const_for_loop(F f) { + if constexpr (start < stop) { + constexpr auto idx = Int{}; + f(idx); + const_for_loop(f); + } +} + +#undef integral_const_binop + +/////////////////////////////////////////////////////////////////////////////// +// Reduction operators +/////////////////////////////////////////////////////////////////////////////// + +template +METAL_FUNC constexpr T sum(T x) { + return x; +} + +template +METAL_FUNC constexpr auto sum(T x, Us... us) { + return x + sum(us...); +} + +} // namespace steel +} // namespace mlx + +#pragma METAL internals : disable \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/steel/utils/type_traits.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/utils/type_traits.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h diff --git a/Source/Cmlx/mlx-generated/metal/ternary.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/ternary.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/ternary.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/ternary.h diff --git a/Source/Cmlx/mlx-generated/metal/ternary_ops.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/ternary_ops.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/ternary_ops.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/ternary_ops.h diff --git a/Source/Cmlx/mlx-generated/metal/unary.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/unary.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/unary.h rename to Source/Cxxmlx/include/mlx/backend/metal/kernels/unary.h diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/unary_ops.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/unary_ops.h new file mode 100644 index 00000000..327bb5a9 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/unary_ops.h @@ -0,0 +1,454 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/backend/metal/kernels/cexpf.h" +#include "mlx/backend/metal/kernels/erf.h" +#include "mlx/backend/metal/kernels/expm1f.h" +#include "mlx/backend/metal/kernels/fp8.h" + +namespace { +constant float inf = metal::numeric_limits::infinity(); +} + +struct Abs { + template + T operator()(T x) { + return metal::abs(x); + }; + uint8_t operator()(uint8_t x) { + return x; + }; + uint16_t operator()(uint16_t x) { + return x; + }; + uint32_t operator()(uint32_t x) { + return x; + }; + uint64_t operator()(uint64_t x) { + return x; + }; + bool operator()(bool x) { + return x; + }; + complex64_t operator()(complex64_t x) { + return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0}; + }; +}; + +struct ArcCos { + template + T operator()(T x) { + return metal::precise::acos(x); + }; + + complex64_t operator()(complex64_t x); +}; + +struct ArcCosh { + template + T operator()(T x) { + return metal::precise::acosh(x); + }; +}; + +struct ArcSin { + template + T operator()(T x) { + return metal::precise::asin(x); + }; + + complex64_t operator()(complex64_t x); +}; + +struct ArcSinh { + template + T operator()(T x) { + return metal::precise::asinh(x); + }; +}; + +struct ArcTan { + template + T operator()(T x) { + return metal::precise::atan(x); + }; + + complex64_t operator()(complex64_t x); +}; + +struct ArcTanh { + template + T operator()(T x) { + return metal::precise::atanh(x); + }; +}; + +struct BitwiseInvert { + template + T operator()(T x) { + return ~x; + }; +}; + +struct Ceil { + template + T operator()(T x) { + return metal::ceil(x); + }; + int8_t operator()(int8_t x) { + return x; + }; + int16_t operator()(int16_t x) { + return x; + }; + int32_t operator()(int32_t x) { + return x; + }; + int64_t operator()(int64_t x) { + return x; + }; + uint8_t operator()(uint8_t x) { + return x; + }; + uint16_t operator()(uint16_t x) { + return x; + }; + uint32_t operator()(uint32_t x) { + return x; + }; + uint64_t operator()(uint64_t x) { + return x; + }; + bool operator()(bool x) { + return x; + }; +}; + +struct Cos { + template + T operator()(T x) { + return metal::precise::cos(x); + }; + + complex64_t operator()(complex64_t x) { + return { + metal::precise::cos(x.real) * metal::precise::cosh(x.imag), + -metal::precise::sin(x.real) * metal::precise::sinh(x.imag)}; + }; +}; + +struct Cosh { + template + T operator()(T x) { + return metal::precise::cosh(x); + }; + + complex64_t operator()(complex64_t x) { + return { + metal::precise::cosh(x.real) * metal::precise::cos(x.imag), + metal::precise::sinh(x.real) * metal::precise::sin(x.imag)}; + }; +}; + +struct Conjugate { + complex64_t operator()(complex64_t x) { + return complex64_t{x.real, -x.imag}; + } +}; + +struct Erf { + template + T operator()(T x) { + return static_cast(erf(static_cast(x))); + }; +}; + +struct ErfInv { + template + T operator()(T x) { + return static_cast(erfinv(static_cast(x))); + }; +}; + +struct Exp { + template + T operator()(T x) { + return metal::precise::exp(x); + }; + complex64_t operator()(complex64_t x) { + return cexpf(x); + } +}; + +struct Expm1 { + template + T operator()(T x) { + return static_cast(expm1f(static_cast(x))); + }; +}; + +struct Floor { + template + T operator()(T x) { + return metal::floor(x); + }; + int8_t operator()(int8_t x) { + return x; + }; + int16_t operator()(int16_t x) { + return x; + }; + int32_t operator()(int32_t x) { + return x; + }; + int64_t operator()(int64_t x) { + return x; + }; + uint8_t operator()(uint8_t x) { + return x; + }; + uint16_t operator()(uint16_t x) { + return x; + }; + uint32_t operator()(uint32_t x) { + return x; + }; + uint64_t operator()(uint64_t x) { + return x; + }; + bool operator()(bool x) { + return x; + }; +}; + +struct Imag { + float operator()(complex64_t x) { + return x.imag; + }; +}; + +struct Log { + template + T operator()(T x) { + return metal::precise::log(x); + }; + + complex64_t operator()(complex64_t x) { + auto r = metal::precise::log(Abs{}(x).real); + auto i = metal::precise::atan2(x.imag, x.real); + return {r, i}; + }; +}; + +struct Log2 { + template + T operator()(T x) { + return metal::precise::log2(x); + }; + + complex64_t operator()(complex64_t x) { + auto y = Log{}(x); + return {y.real / M_LN2_F, y.imag / M_LN2_F}; + }; +}; + +struct Log10 { + template + T operator()(T x) { + return metal::precise::log10(x); + }; + + complex64_t operator()(complex64_t x) { + auto y = Log{}(x); + return {y.real / M_LN10_F, y.imag / M_LN10_F}; + }; +}; + +struct Log1p { + template + T operator()(T x) { + return log1p(x); + }; +}; + +struct LogicalNot { + template + T operator()(T x) { + return !x; + }; +}; + +struct Negative { + template + T operator()(T x) { + return -x; + }; +}; + +struct Real { + float operator()(complex64_t x) { + return x.real; + }; +}; + +struct Round { + template + T operator()(T x) { + return metal::rint(x); + }; + complex64_t operator()(complex64_t x) { + return {metal::rint(x.real), metal::rint(x.imag)}; + }; +}; + +struct Sigmoid { + template + T operator()(T x) { + auto y = 1 / (1 + metal::exp(metal::abs(x))); + return (x < 0) ? y : 1 - y; + } +}; + +struct Sign { + template + T operator()(T x) { + return (x > T(0)) - (x < T(0)); + }; + uint32_t operator()(uint32_t x) { + return x != 0; + }; + complex64_t operator()(complex64_t x) { + if (x == complex64_t(0)) { + return x; + } + return x / + (complex64_t)metal::precise::sqrt(x.real * x.real + x.imag * x.imag); + }; +}; + +struct Sin { + template + T operator()(T x) { + return metal::precise::sin(x); + }; + + complex64_t operator()(complex64_t x) { + return { + metal::precise::sin(x.real) * metal::precise::cosh(x.imag), + metal::precise::cos(x.real) * metal::precise::sinh(x.imag)}; + }; +}; + +struct Sinh { + template + T operator()(T x) { + return metal::precise::sinh(x); + }; + + complex64_t operator()(complex64_t x) { + return { + metal::precise::sinh(x.real) * metal::precise::cos(x.imag), + metal::precise::cosh(x.real) * metal::precise::sin(x.imag)}; + }; +}; + +struct Square { + template + T operator()(T x) { + return x * x; + }; +}; + +struct Sqrt { + template + T operator()(T x) { + return metal::precise::sqrt(x); + }; + + complex64_t operator()(complex64_t x) { + if (x.real == 0.0 && x.imag == 0.0) { + return {0.0, 0.0}; + } + auto r = Abs{}(x).real; + auto a = metal::precise::sqrt((r + x.real) / 2.0); + auto b_abs = metal::precise::sqrt((r - x.real) / 2.0); + auto b = metal::copysign(b_abs, x.imag); + return {a, b}; + } +}; + +struct Rsqrt { + template + T operator()(T x) { + return metal::precise::rsqrt(x); + }; + + complex64_t operator()(complex64_t x) { + return 1.0 / Sqrt{}(x); + } +}; + +struct Tan { + template + T operator()(T x) { + return metal::precise::tan(x); + }; + + complex64_t operator()(complex64_t x) { + float tan_a = metal::precise::tan(x.real); + float tanh_b = metal::precise::tanh(x.imag); + float t1 = tan_a * tanh_b; + float denom = 1. + t1 * t1; + return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom}; + }; +}; + +struct Tanh { + template + T operator()(T x) { + return metal::precise::tanh(x); + }; + + complex64_t operator()(complex64_t x) { + float tanh_a = metal::precise::tanh(x.real); + float tan_b = metal::precise::tan(x.imag); + float t1 = tanh_a * tan_b; + float denom = 1. + t1 * t1; + return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom}; + }; +}; + +complex64_t ArcCos::operator()(complex64_t x) { + auto i = complex64_t{0.0, 1.0}; + auto y = Log{}(x + i * Sqrt{}(1.0 - x * x)); + return {y.imag, -y.real}; +}; + +complex64_t ArcSin::operator()(complex64_t x) { + auto i = complex64_t{0.0, 1.0}; + auto y = Log{}(i * x + Sqrt{}(1.0 - x * x)); + return {y.imag, -y.real}; +}; + +complex64_t ArcTan::operator()(complex64_t x) { + auto i = complex64_t{0.0, 1.0}; + auto ix = i * x; + return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix)); +}; + +struct ToFP8 { + template + uint8_t operator()(T f) { + return fp8_e4m3(f).bits; + } +}; + +struct FromFP8 { + float operator()(uint8_t x) { + return float(*(thread fp8_e4m3*)(&x)); + } +}; diff --git a/Source/Cxxmlx/include/mlx/backend/metal/kernels/utils.h b/Source/Cxxmlx/include/mlx/backend/metal/kernels/utils.h new file mode 100644 index 00000000..d3564501 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/kernels/utils.h @@ -0,0 +1,445 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/bf16_math.h" +#include "mlx/backend/metal/kernels/complex.h" +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/kernels/logging.h" + +typedef half float16_t; + +// Work per thread values for different types. The values here are expected to +// match get_work_per_thread in mlx/backend/metal/utils.h +template +struct WorkPerThread { + static_assert(sizeof(U) <= 8, "Type too large"); + static constexpr int constant n = 8 / sizeof(U); +}; + +/////////////////////////////////////////////////////////////////////////////// +// Type limits utils +/////////////////////////////////////////////////////////////////////////////// + +template +struct Limits { + static const constant U max = metal::numeric_limits::max(); + static const constant U min = metal::numeric_limits::min(); + static const constant U finite_max = metal::numeric_limits::max(); + static const constant U finite_min = metal::numeric_limits::min(); +}; + +#define instantiate_default_limit(type) \ + template <> \ + struct Limits { \ + static constexpr constant type max = metal::numeric_limits::max(); \ + static constexpr constant type min = metal::numeric_limits::min(); \ + static constexpr constant type finite_max = \ + metal::numeric_limits::max(); \ + static constexpr constant type finite_min = \ + metal::numeric_limits::min(); \ + }; + +instantiate_default_limit(uint8_t); +instantiate_default_limit(uint16_t); +instantiate_default_limit(uint32_t); +instantiate_default_limit(uint64_t); +instantiate_default_limit(int8_t); +instantiate_default_limit(int16_t); +instantiate_default_limit(int32_t); +instantiate_default_limit(int64_t); + +#define instantiate_float_limit(type) \ + template <> \ + struct Limits { \ + static constexpr constant type max = \ + metal::numeric_limits::infinity(); \ + static constexpr constant type min = \ + -metal::numeric_limits::infinity(); \ + static constexpr constant type finite_max = \ + metal::numeric_limits::max(); \ + static constexpr constant type finite_min = \ + -metal::numeric_limits::max(); \ + }; + +instantiate_float_limit(half); +instantiate_float_limit(float); +instantiate_float_limit(bfloat16_t); + +template <> +struct Limits { + static constexpr constant bool max = true; + static constexpr constant bool min = false; +}; + +template <> +struct Limits { + static constexpr constant complex64_t max = complex64_t( + metal::numeric_limits::infinity(), + metal::numeric_limits::infinity()); + static constexpr constant complex64_t min = complex64_t( + -metal::numeric_limits::infinity(), + -metal::numeric_limits::infinity()); +}; + +/////////////////////////////////////////////////////////////////////////////// +// Indexing utils +/////////////////////////////////////////////////////////////////////////////// + +#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") + +/////////////////////////////////////////////////////////////////////////////// +// Single Array with generic dims + +template +METAL_FUNC IdxT elem_to_loc( + IdxT elem, + constant const int* shape, + constant const int64_t* strides, + int ndim) { + IdxT loc = 0; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + loc += (elem % shape[i]) * IdxT(strides[i]); + elem /= shape[i]; + } + return loc; +} + +// Non templated version to handle arbitrary dims +template +METAL_FUNC IdxT elem_to_loc( + uint3 elem, + constant const int* shape, + constant const int64_t* strides, + int ndim) { + IdxT loc = + elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]); + for (int d = ndim - 3; d >= 0; --d) { + loc += (elem.z % shape[d]) * IdxT(strides[d]); + elem.z /= shape[d]; + } + return loc; +} + +/////////////////////////////////////////////////////////////////////////////// +// Single Array with fixed N dims + +template +METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const int64_t& stride) { + return elem * IdxT(stride); +} + +template +METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const int64_t strides[2]) { + return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]); +} + +template +METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const int64_t strides[3]) { + return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) + + elem.z * IdxT(strides[0]); +} + +/////////////////////////////////////////////////////////////////////////////// +// Multiple Arrays with generic dims + +template +METAL_FUNC vec elem_to_loc_2_nd( + uint3 elem, + constant const int* shape, + constant const int64_t* a_strides, + constant const int64_t* b_strides, + int ndim) { + vec loc = { + IdxT( + elem.x * IdxT(a_strides[ndim - 1]) + + IdxT(elem.y) * IdxT(a_strides[ndim - 2])), + IdxT( + elem.x * IdxT(b_strides[ndim - 1]) + + elem.y * IdxT(b_strides[ndim - 2]))}; + for (int d = ndim - 3; d >= 0; --d) { + uint l = elem.z % shape[d]; + loc.x += l * IdxT(a_strides[d]); + loc.y += l * IdxT(b_strides[d]); + elem.z /= shape[d]; + } + return loc; +} + +template +METAL_FUNC vec elem_to_loc_3_nd( + uint3 elem, + constant const int* shape, + constant const int64_t* a_strides, + constant const int64_t* b_strides, + constant const int64_t* c_strides, + int ndim) { + vec loc = { + IdxT(elem.x * IdxT(a_strides[ndim - 1])) + + IdxT(elem.y * IdxT(a_strides[ndim - 2])), + IdxT(elem.x * IdxT(b_strides[ndim - 1])) + + IdxT(elem.y * IdxT(b_strides[ndim - 2])), + IdxT(elem.x * IdxT(c_strides[ndim - 1])) + + IdxT(elem.y * IdxT(c_strides[ndim - 2]))}; + for (int d = ndim - 3; d >= 0; --d) { + uint l = elem.z % shape[d]; + loc.x += l * IdxT(a_strides[d]); + loc.y += l * IdxT(b_strides[d]); + loc.z += l * IdxT(c_strides[d]); + elem.z /= shape[d]; + } + return loc; +} + +/////////////////////////////////////////////////////////////////////////////// +// Elem to loc in a loop utils +/////////////////////////////////////////////////////////////////////////////// + +template +struct LoopedElemToLoc { + int dim; + LoopedElemToLoc inner_looper; + OffsetT offset{0}; + int index{0}; + + LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {} + + void next(const constant int* shape, const constant int64_t* strides) { + if (dim == 0) { + return; + } + index++; + offset += OffsetT(strides[dim - 1]); + if (index >= shape[dim - 1]) { + index = 0; + inner_looper.next(shape, strides); + offset = inner_looper.offset; + } + } + + void next(int n, const constant int* shape, const constant int64_t* strides) { + if (dim == 0) { + return; + } + index += n; + offset += n * OffsetT(strides[dim - 1]); + + if (index >= shape[dim - 1]) { + int extra = index - shape[dim - 1]; + if (extra >= shape[dim - 1]) { + inner_looper.next(1 + extra / shape[dim - 1], shape, strides); + extra = extra % shape[dim - 1]; + } else { + inner_looper.next(shape, strides); + } + index = 0; + offset = inner_looper.offset; + if (extra > 0) { + next(extra, shape, strides); + } + } + } + + OffsetT location() { + return offset; + } +}; + +template +struct LoopedElemToLoc<1, OffsetT, true> { + int dim; + OffsetT offset{0}; + uint index{0}; + + LoopedElemToLoc(int dim) : dim(dim) {} + + void next(const constant int* shape, const constant int64_t* strides) { + index++; + if (dim > 1) { + offset = elem_to_loc(index, shape, strides, dim); + } else { + offset += OffsetT(strides[0]); + } + } + + void next(int n, const constant int* shape, const constant int64_t* strides) { + index += n; + if (dim > 1) { + offset = elem_to_loc(index, shape, strides, dim); + } else { + offset = index * OffsetT(strides[0]); + } + } + + OffsetT location() { + return offset; + } +}; + +template +struct LoopedElemToLoc<1, OffsetT, false> { + OffsetT offset{0}; + + LoopedElemToLoc(int) {} + + void next(const constant int*, const constant int64_t* strides) { + offset += OffsetT(strides[0]); + } + + void next(int n, const constant int*, const constant int64_t* strides) { + offset += n * OffsetT(strides[0]); + } + + OffsetT location() { + return offset; + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Calculation utils +/////////////////////////////////////////////////////////////////////////////// + +/** Compute ceil((float)N/(float)M) */ +template +inline T ceildiv(T N, U M) { + return (N + M - 1) / M; +} + +// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202 +inline float log1p(float x) { + float xp1 = 1.0f + x; + if (xp1 == Limits::max) { + return Limits::max; + } + if (xp1 == 1.0f) { + return x; + } + + return x * (metal::log(xp1) / (xp1 - 1.0f)); +} + +inline bfloat16_t log1p(bfloat16_t x) { + float xp1 = 1.0f + static_cast(x); + if (xp1 == Limits::max) { + return Limits::max; + } + if (xp1 == 1.0f) { + return x; + } + + return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f))); +} + +inline complex64_t log1p(complex64_t in) { + float x = in.real; + float y = in.imag; + float zabs = metal::precise::sqrt(x * x + y * y); + float theta = metal::atan2(y, x + 1); + if (zabs < 0.5f) { + float r = x * (2 + x) + y * y; + if (r == 0) { // handle underflow + return {x, theta}; + } + return {0.5f * log1p(r), theta}; + } else { + auto z0 = metal::sqrt((x + 1) * (x + 1) + y * y); + return {metal::log(z0), theta}; + } +} + +/////////////////////////////////////////////////////////////////////////////// +// SIMD shuffle ops +/////////////////////////////////////////////////////////////////////////////// + +inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) { + return as_type( + metal::simd_shuffle_down(as_type(data), delta)); +} + +inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) { + return as_type( + metal::simd_shuffle_down(as_type(data), delta)); +} + +inline bool simd_shuffle_down(bool data, uint16_t delta) { + return simd_shuffle_down(static_cast(data), delta); +} + +inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) { + return complex64_t( + simd_shuffle_down(data.real, delta), simd_shuffle_down(data.imag, delta)); +} + +inline uint64_t simd_shuffle_up(uint64_t data, uint16_t delta) { + return as_type(metal::simd_shuffle_up(as_type(data), delta)); +} + +inline int64_t simd_shuffle_up(int64_t data, uint16_t delta) { + return as_type(metal::simd_shuffle_up(as_type(data), delta)); +} + +inline bool simd_shuffle_up(bool data, uint16_t delta) { + return simd_shuffle_up(static_cast(data), delta); +} + +inline complex64_t simd_shuffle_up(complex64_t data, uint16_t delta) { + return complex64_t( + simd_shuffle_up(data.real, delta), simd_shuffle_up(data.imag, delta)); +} + +inline uint64_t +simd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta) { + return as_type(metal::simd_shuffle_and_fill_up( + as_type(data), as_type(filling), delta)); +} + +inline int64_t +simd_shuffle_and_fill_up(int64_t data, int64_t filling, uint16_t delta) { + return as_type(metal::simd_shuffle_and_fill_up( + as_type(data), as_type(filling), delta)); +} + +inline bool simd_shuffle_and_fill_up(bool data, bool filling, uint16_t delta) { + return simd_shuffle_and_fill_up( + static_cast(data), static_cast(filling), delta); +} + +inline complex64_t simd_shuffle_and_fill_up( + complex64_t data, + complex64_t filling, + uint16_t delta) { + return complex64_t( + simd_shuffle_and_fill_up(data.real, filling.real, delta), + simd_shuffle_and_fill_up(data.imag, filling.imag, delta)); +} + +inline uint64_t simd_shuffle(uint64_t data, uint16_t lane) { + return as_type(metal::simd_shuffle(as_type(data), lane)); +} + +inline int64_t simd_shuffle(int64_t data, uint16_t lane) { + return as_type(metal::simd_shuffle(as_type(data), lane)); +} + +inline bool simd_shuffle(bool data, uint16_t lane) { + return simd_shuffle(static_cast(data), lane); +} + +inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) { + return complex64_t( + simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane)); +} + +// std::conditional is not included with Metal +template +struct ConditionalType { + using type = U; +}; + +template +struct ConditionalType { + using type = T; +}; diff --git a/Source/Cxxmlx/include/mlx/backend/metal/matmul.h b/Source/Cxxmlx/include/mlx/backend/metal/matmul.h new file mode 100644 index 00000000..218664b1 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/matmul.h @@ -0,0 +1,144 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/device.h" + +namespace mlx::core { + +template +void steel_matmul_regular_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + int ldd, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape, + Strides batch_strides, + int64_t A_batch_stride, + int64_t B_batch_stride, + int64_t matrix_stride_out, + int64_t C_batch_stride = 0, + float alpha = 1.0f, + float beta = 0.0f); + +inline void steel_matmul_regular( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + int ldd, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape, + Strides batch_strides, + int64_t A_batch_stride, + int64_t B_batch_stride, + int64_t matrix_stride_out) { + return steel_matmul_regular_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* int ldd = */ ldd, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides batch_strides = */ batch_strides, + /* int64_t A_batch_stride = */ A_batch_stride, + /* int64_t B_batch_stride = */ B_batch_stride, + /* int64_t matrix_stride_out = */ matrix_stride_out); +} + +template +void steel_matmul_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape = {}, + Strides A_batch_stride = {}, + Strides B_batch_stride = {}, + Strides C_batch_stride = {}, + float alpha = 1.0f, + float beta = 0.0f); + +inline void steel_matmul( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape = {}, + Strides A_batch_stride = {}, + Strides B_batch_stride = {}) { + return steel_matmul_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride); +} + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/metal/metal.h b/Source/Cxxmlx/include/mlx/backend/metal/metal.h new file mode 100644 index 00000000..6662e21e --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/metal.h @@ -0,0 +1,25 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include +#include + +#include "mlx/api.h" + +namespace mlx::core::metal { + +/* Check if the Metal backend is available. */ +MLX_API bool is_available(); + +/** Capture a GPU trace, saving it to an absolute file `path` */ +MLX_API void start_capture(std::string path = ""); +MLX_API void stop_capture(); + +/** Get information about the GPU and system settings. */ +MLX_API const + std::unordered_map>& + device_info(); + +} // namespace mlx::core::metal diff --git a/Source/Cxxmlx/include/mlx/backend/metal/reduce.h b/Source/Cxxmlx/include/mlx/backend/metal/reduce.h new file mode 100644 index 00000000..a997d7e2 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/reduce.h @@ -0,0 +1,41 @@ +// Copyright @ 2023 - 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/common/reduce.h" +#include "mlx/backend/metal/device.h" +#include "mlx/stream.h" + +namespace mlx::core { + +using metal::CommandEncoder; + +void all_reduce_dispatch( + const array& in, + array& out, + const std::string& op_name, + CommandEncoder& compute_encoder, + metal::Device& d, + const Stream& s); + +void row_reduce_general_dispatch( + const array& in, + array& out, + const std::string& op_name, + const ReductionPlan& plan, + const std::vector& axes, + CommandEncoder& compute_encoder, + metal::Device& d, + const Stream& s); + +void strided_reduce_general_dispatch( + const array& in, + array& out, + const std::string& op_name, + const ReductionPlan& plan, + const std::vector& axes, + CommandEncoder& compute_encoder, + metal::Device& d, + const Stream& s); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/metal/resident.h b/Source/Cxxmlx/include/mlx/backend/metal/resident.h new file mode 100644 index 00000000..5db55828 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/resident.h @@ -0,0 +1,32 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/device.h" + +namespace mlx::core::metal { + +class ResidencySet { + public: + ResidencySet(MTL::Device* d); + ~ResidencySet(); + + ResidencySet(const ResidencySet&) = delete; + ResidencySet& operator=(const ResidencySet&) = delete; + + const MTL::ResidencySet* mtl_residency_set() { + return wired_set_; + } + + void insert(MTL::Allocation* buf); + void erase(MTL::Allocation* buf); + + void resize(size_t size); + + private: + MTL::ResidencySet* wired_set_{nullptr}; + std::unordered_set unwired_set_; + size_t capacity_{0}; +}; + +} // namespace mlx::core::metal diff --git a/Source/Cxxmlx/include/mlx/backend/metal/scan.h b/Source/Cxxmlx/include/mlx/backend/metal/scan.h new file mode 100644 index 00000000..dab79c50 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/scan.h @@ -0,0 +1,17 @@ +#pragma once + +#include "mlx/array.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +void scan_gpu_inplace( + array in, + array& out, + Scan::ReduceType reduce_type, + int axis, + bool reverse, + bool inclusive, + const Stream& s); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/metal/ternary.h b/Source/Cxxmlx/include/mlx/backend/metal/ternary.h new file mode 100644 index 00000000..91c6fbbe --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/ternary.h @@ -0,0 +1,21 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +void ternary_op_gpu( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s); + +void ternary_op_gpu_inplace( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/metal/unary.h b/Source/Cxxmlx/include/mlx/backend/metal/unary.h new file mode 100644 index 00000000..1d6ecf02 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/unary.h @@ -0,0 +1,21 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +void unary_op_gpu( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s); + +void unary_op_gpu_inplace( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/metal/utils.h b/Source/Cxxmlx/include/mlx/backend/metal/utils.h new file mode 100644 index 00000000..c4cef8cb --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/metal/utils.h @@ -0,0 +1,99 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include + +#include "mlx/array.h" +#include "mlx/backend/metal/device.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +MLX_API std::string type_to_name(const Dtype& t); +MLX_API std::string type_to_name(const array& a); + +// Compute the grid and block dimensions, check backend/common/utils.h for docs. +MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10); +MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides); +MTL::Size +get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor); + +inline NS::String* make_string(std::ostringstream& os) { + std::string string = os.str(); + return NS::String::string(string.c_str(), NS::UTF8StringEncoding); +} + +inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) { +#ifdef MLX_METAL_DEBUG + std::ostringstream label; + label << "Stream " << index; + queue->setLabel(make_string(label)); +#endif +} + +inline void debug_set_primitive_buffer_label( + MTL::CommandBuffer* command_buffer, + Primitive& primitive) { +#ifdef MLX_METAL_DEBUG + std::ostringstream label; + if (auto cbuf_label = command_buffer->label(); cbuf_label) { + label << cbuf_label->utf8String(); + } + label << primitive.name(); + command_buffer->setLabel(make_string(label)); +#endif +} + +template +constexpr bool is_numeric_except_char = std::is_arithmetic_v && + !std::is_same_v && !std::is_same_v && + !std::is_same_v && !std::is_same_v; + +template +void concatenate(std::string& acc, T first) { + if constexpr (is_numeric_except_char) { + acc += std::to_string(first); + } else { + acc += first; + } +} + +template +void concatenate(std::string& acc, T first, Args... args) { + if constexpr (is_numeric_except_char) { + acc += std::to_string(first); + } else { + acc += first; + } + concatenate(acc, args...); +} + +inline int get_work_per_thread(Dtype dtype) { + return std::max(1, 8 / dtype.size()); +} +inline int get_work_per_thread(Dtype dtype, size_t size) { + constexpr size_t wpt_threshold = 1 << 16; + return size < wpt_threshold ? 1 : std::max(1, 8 / dtype.size()); +} + +inline size_t ceildiv(size_t n, size_t m) { + return (n + m - 1) / m; +} + +inline void check_kernel_threadgroup_size( + const MTL::ComputePipelineState* kernel, + MTL::Size group_dims, + const std::string& name) { + auto max_size = kernel->maxTotalThreadsPerThreadgroup(); + auto requested_size = group_dims.width * group_dims.height * group_dims.depth; + + if (max_size < requested_size) { + std::ostringstream msg; + msg << "Maximum threads per threadgroup is " << max_size + << " but requested " << requested_size << " for kernel " << name << "."; + throw std::runtime_error(msg.str()); + } +} + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/backend/no_gpu/apple_memory.h b/Source/Cxxmlx/include/mlx/backend/no_gpu/apple_memory.h new file mode 100644 index 00000000..7fdc5301 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/no_gpu/apple_memory.h @@ -0,0 +1,16 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace { + +size_t get_memory_size() { + size_t memsize = 0; + size_t length = sizeof(memsize); + sysctlbyname("hw.memsize", &memsize, &length, NULL, 0); + return memsize; +} + +} // namespace diff --git a/Source/Cxxmlx/include/mlx/backend/no_gpu/linux_memory.h b/Source/Cxxmlx/include/mlx/backend/no_gpu/linux_memory.h new file mode 100644 index 00000000..f909edcd --- /dev/null +++ b/Source/Cxxmlx/include/mlx/backend/no_gpu/linux_memory.h @@ -0,0 +1,22 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace { + +size_t get_memory_size() { + struct sysinfo info; + + if (sysinfo(&info) != 0) { + return 0; + } + + size_t total_ram = info.totalram; + total_ram *= info.mem_unit; + + return total_ram; +} + +} // namespace diff --git a/Source/Cxxmlx/include/mlx/compile.h b/Source/Cxxmlx/include/mlx/compile.h new file mode 100644 index 00000000..eba0983e --- /dev/null +++ b/Source/Cxxmlx/include/mlx/compile.h @@ -0,0 +1,45 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include "mlx/api.h" +#include "mlx/array.h" + +namespace mlx::core { + +enum class CompileMode { disabled, no_simplify, no_fuse, enabled }; + +/** Compile takes a function and returns a compiled function. */ +MLX_API std::function(const std::vector&)> compile( + std::function(const std::vector&)> fun, + bool shapeless = false); + +MLX_API std::function(const std::vector&)> compile( + std::vector (*fun)(const std::vector&), + bool shapeless = false); + +// Convert capture-less lambdas to function pointers. +template < + typename F, + typename = std::enable_if_t< + std::is_convertible_v())>>> +std::function(const std::vector&)> compile( + F&& f, + bool shapeless = false) { + return compile(+f, shapeless); +} + +/** Globally disable compilation. + * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also + * be used to disable compilation. + */ +MLX_API void disable_compile(); + +/** Globally enable compilation. + * This will override the environment variable ``MLX_DISABLE_COMPILE``. + */ +MLX_API void enable_compile(); + +/** Set the compiler mode to the given value. */ +MLX_API void set_compile_mode(CompileMode mode); +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/compile_impl.h b/Source/Cxxmlx/include/mlx/compile_impl.h new file mode 100644 index 00000000..238a8b94 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/compile_impl.h @@ -0,0 +1,70 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include + +#include "mlx/api.h" +#include "mlx/array.h" + +namespace mlx::core::detail { + +using ArraysAndExtra = std::pair, std::shared_ptr>; +using ArrayFnWithExtra = + std::function&)>; + +// This is not part of the general C++ API as calling with a bad id is a bad +// idea. +MLX_API std::function(const std::vector&)> compile( + std::function(const std::vector&)> fun, + std::uintptr_t fun_id, + bool shapeless = false, + std::vector constants = {}); + +MLX_API ArrayFnWithExtra compile( + ArrayFnWithExtra fun, + std::uintptr_t fun_id, + bool shapeless, + std::vector constants); + +// Erase cached compile functions +MLX_API void compile_erase(std::uintptr_t fun_id); + +// Clear the compiler cache causing a recompilation of all compiled functions +// when called again. +MLX_API void compile_clear_cache(); + +bool compile_available_for_device(const Device& device); + +std::tuple, std::vector, std::shared_ptr> +compile_trace( + const ArrayFnWithExtra& fun, + const std::vector& inputs, + bool shapeless); + +using ParentsMap = + std::unordered_map>>; + +// Traverses the graph to build a tape and a map of array ids to their parents +std::pair, ParentsMap> compile_dfs( + const std::vector& inputs, + std::vector& outputs, + const std::vector& original_inputs); + +// Simplify the tape. +void compile_simplify( + std::vector& tape, + ParentsMap& parents_map, + std::vector& outputs, + int passes); + +std::vector compile_replace( + const std::vector& tape, + const std::vector& trace_inputs, + const std::vector& trace_outputs, + const std::vector& inputs, + bool shapeless); + +void compile_validate_shapeless(const std::vector& tape); + +} // namespace mlx::core::detail diff --git a/Source/Cxxmlx/include/mlx/device.h b/Source/Cxxmlx/include/mlx/device.h new file mode 100644 index 00000000..f89ad189 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/device.h @@ -0,0 +1,56 @@ +// Copyright © 2023-2025 Apple Inc. + +#pragma once + +#include "mlx/api.h" + +#include +#include +#include + +namespace mlx::core { + +struct MLX_API Device { + enum class DeviceType { + cpu, + gpu, + }; + + static constexpr DeviceType cpu = DeviceType::cpu; + static constexpr DeviceType gpu = DeviceType::gpu; + + Device(DeviceType type, int index = 0) : type(type), index(index) {} + + DeviceType type; + int index; +}; + +MLX_API const Device& default_device(); + +MLX_API void set_default_device(const Device& d); + +MLX_API bool operator==(const Device& lhs, const Device& rhs); +MLX_API bool operator!=(const Device& lhs, const Device& rhs); + +MLX_API bool is_available(const Device& d); + +/** Get the number of available devices for the given device type. */ +MLX_API int device_count(Device::DeviceType type); + +/** + * Get information about a device. + * + * Returns a map of device properties. Keys vary by backend: + * - device_name (string): Device name + * - architecture (string): Architecture identifier + * - total_memory/memory_size (size_t): Total device memory + * - free_memory (size_t): Available memory (CUDA only) + * - uuid (string): Device UUID (CUDA only) + * - pci_bus_id (string): PCI bus ID (CUDA only) + * - compute_capability_major/minor (size_t): Compute capability (CUDA only) + */ +MLX_API const + std::unordered_map>& + device_info(const Device& d = default_device()); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/distributed/distributed.h b/Source/Cxxmlx/include/mlx/distributed/distributed.h new file mode 100644 index 00000000..00c7a80e --- /dev/null +++ b/Source/Cxxmlx/include/mlx/distributed/distributed.h @@ -0,0 +1,61 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include + +#include "mlx/api.h" +#include "mlx/array.h" +#include "mlx/utils.h" + +namespace mlx::core::distributed { + +// Forward declaration of the base group implementation. +namespace detail { +class GroupImpl; +}; + +/* Check if a communication backend is available */ +MLX_API bool is_available(); +MLX_API bool is_available(const std::string& bk); + +/** + * A distributed::Group represents a group of independent mlx processes that + * can communicate. We must also be able to create sub-groups from a group in + * order to define more granular communication. + */ +struct MLX_API Group { + Group(std::shared_ptr group) : group_(std::move(group)) {} + + int rank() const; + int size() const; + + /** + * Split the group according to the provided color. Namely processes that use + * the same color will go to the same group. + * + * The key defines the rank of the processes in the new group. The smaller + * the key the smaller the rank. If the provided key is negative, then the + * rank in the current group is used. + */ + Group split(int color, int key = -1) const; + + const std::shared_ptr& raw_group() const { + return group_; + } + + private: + std::shared_ptr group_{nullptr}; +}; + +/** + * Initialize the distributed backend and return the group containing all + * discoverable processes. + * + * If strict is true then throw an error if we couldn't initialize the + * distributed subsystem. Otherwise simply return a singleton group which will + * render communication operations as no-op. + */ +MLX_API Group init(bool strict = false, const std::string& bk = "any"); + +} // namespace mlx::core::distributed diff --git a/Source/Cxxmlx/include/mlx/distributed/distributed_impl.h b/Source/Cxxmlx/include/mlx/distributed/distributed_impl.h new file mode 100644 index 00000000..d889587a --- /dev/null +++ b/Source/Cxxmlx/include/mlx/distributed/distributed_impl.h @@ -0,0 +1,59 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/distributed/distributed.h" + +namespace mlx::core::distributed::detail { + +/** + * Abstract base class of a distributed group implementation. + */ +class GroupImpl { + public: + virtual ~GroupImpl() {} + + // Choose the stream this communication group can operate on + virtual Stream communication_stream(StreamOrDevice s = {}) = 0; + + // Group operations + virtual int rank() = 0; + virtual int size() = 0; + virtual std::shared_ptr split(int color, int key = -1) = 0; + + // Actual communication operations + virtual void all_sum(const array& input, array& output, Stream stream) = 0; + virtual void all_gather(const array& input, array& output, Stream stream) = 0; + virtual void send(const array& input, int dst, Stream stream) = 0; + virtual void recv(array& out, int src, Stream stream) = 0; + virtual void all_max(const array& input, array& output, Stream stream) = 0; + virtual void all_min(const array& input, array& output, Stream stream) = 0; + virtual void + sum_scatter(const array& input, array& output, Stream stream) = 0; +}; + +/* Define the MLX stream that the communication should happen in. */ +Stream communication_stream(Group group, StreamOrDevice s = {}); + +/* Perform an all reduce sum operation */ +void all_sum(Group group, const array& input, array& output, Stream stream); + +/* Perform an all gather operation */ +void all_gather(Group group, const array& input, array& output, Stream stream); + +/** Send an array to the dst rank */ +void send(Group group, const array& input, int dst, Stream stream); + +/** Recv an array from the src rank */ +void recv(Group group, array& out, int src, Stream stream); + +/** Max reduction */ +void all_max(Group group, const array& input, array& output, Stream stream); + +/** Min reduction */ +void all_min(Group group, const array& input, array& output, Stream stream); + +/** Reduce scatter with average operation */ +void sum_scatter(Group group, const array& input, array& output, Stream stream); + +} // namespace mlx::core::distributed::detail diff --git a/Source/Cxxmlx/include/mlx/distributed/jaccl/jaccl.h b/Source/Cxxmlx/include/mlx/distributed/jaccl/jaccl.h new file mode 100644 index 00000000..d07f9ccc --- /dev/null +++ b/Source/Cxxmlx/include/mlx/distributed/jaccl/jaccl.h @@ -0,0 +1,12 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/distributed/distributed.h" + +namespace mlx::core::distributed::jaccl { + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; + +bool is_available(); +std::shared_ptr init(bool strict = false); + +} // namespace mlx::core::distributed::jaccl diff --git a/Source/Cxxmlx/include/mlx/distributed/jaccl/mesh.h b/Source/Cxxmlx/include/mlx/distributed/jaccl/mesh.h new file mode 100644 index 00000000..ed51361a --- /dev/null +++ b/Source/Cxxmlx/include/mlx/distributed/jaccl/mesh.h @@ -0,0 +1,89 @@ +// Copyright © 2026 Apple Inc. + +#pragma once + +#include "mlx/distributed/distributed_impl.h" +#include "mlx/distributed/jaccl/mesh_impl.h" +#include "mlx/distributed/jaccl/ring_impl.h" +#include "mlx/distributed/jaccl/utils.h" + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; + +namespace mlx::core::distributed::jaccl { + +/** + * The JACCL communication group for a fully connected mesh. We expect one + * connection per peer and it should be the lowest latency communication group + * for small to medium size messages. + * + * Like all JACCL groups it uses a side channel to exchange the necessary + * information and then configure the connections to be ready for RDMA + * operations. + */ +class MeshGroup : public GroupImpl { + public: + MeshGroup( + int rank, + const std::vector& device_names, + const char* coordinator_addr); + + Stream communication_stream(StreamOrDevice s) override { + return to_stream(s, Device::cpu); + } + + int rank() override { + return rank_; + } + + int size() override { + return size_; + } + + void all_sum(const array& input, array& output, Stream stream) override; + void all_max(const array& input, array& output, Stream stream) override; + void all_min(const array& input, array& output, Stream stream) override; + void all_gather(const array& input, array& output, Stream stream) override; + void send(const array& input, int dst, Stream stream) override; + void recv(array& out, int src, Stream stream) override; + + void sum_scatter(const array& input, array& output, Stream stream) override { + throw std::runtime_error("[jaccl] sum_scatter not supported."); + } + + std::shared_ptr split(int color, int key = -1) override { + throw std::runtime_error("[jaccl] Group split not supported."); + } + + private: + template + void all_reduce( + const array& input, + array& output, + Stream stream, + ReduceOp reduce_op); + + /** + * Performs the connection initialization. Namely, after this call all + * Connection objects should have a queue pair in RTS state and all buffers + * should have been allocated. + */ + void initialize(); + + /** + * Allocate all the buffers that we will use in the communication group. + */ + void allocate_buffers(); + + int rank_; + int size_; + SideChannel side_channel_; + std::vector connections_; + std::vector buffers_; + std::vector ring_send_buffers_; + std::vector ring_recv_buffers_; + + MeshImpl mesh_; + RingImpl ring_; +}; + +} // namespace mlx::core::distributed::jaccl diff --git a/Source/Cxxmlx/include/mlx/distributed/jaccl/mesh_impl.h b/Source/Cxxmlx/include/mlx/distributed/jaccl/mesh_impl.h new file mode 100644 index 00000000..fc486a39 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/distributed/jaccl/mesh_impl.h @@ -0,0 +1,358 @@ +// Copyright © 2026 Apple Inc. + +#pragma once + +#include + +#include "mlx/distributed/jaccl/utils.h" + +constexpr int MESH_MAX_PEERS = 8; + +namespace mlx::core::distributed::jaccl { + +class MeshImpl { + public: + MeshImpl( + int rank, + int size, + std::vector& conns, + std::vector& buffers) + : rank_(rank), size_(size), connections_(conns), buffers_(buffers) {} + + MeshImpl() : rank_(0), size_(1) {} + + template + void + all_reduce(const T* in_ptr, T* out_ptr, int64_t size, ReduceOp reduce_op) { + // If not inplace all reduce then copy the input to the output first + if (in_ptr != out_ptr) { + std::memcpy(out_ptr, in_ptr, size * sizeof(T)); + } + + // Fully connected all reduce + T* data = out_ptr; + auto [sz, buffer_size] = buffer_size_from_message(size * sizeof(T)); + int64_t N = buffer_size / sizeof(T); + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE * MESH_MAX_PEERS * 2; + int64_t total = static_cast(size); + int num_peers = size_ - 1; + + // Counters to maintain the state of transfers + int in_flight = 0; + int64_t read_offset = 0; + int completed_send_count[PIPELINE] = {0}; + int completed_recv_begin[MESH_MAX_PEERS] = {0}; + int completed_recv_end[MESH_MAX_PEERS] = {0}; + + // Prefill the pipeline + int buff = 0; + while (read_offset < total && buff < PIPELINE) { + post_recv_all(sz, buff); + std::copy( + data + read_offset, + data + std::min(read_offset + N, total), + send_buffer(sz, buff).begin()); + post_send_all(sz, buff); + + buff++; + in_flight += 2 * num_peers; + read_offset += N; + } + + // Main loop + // + // Keep going until we have no longer data in flight. + while (in_flight > 0) { + // Poll the hardware for completions. + // + // If a send was completed mark how many completions we have received + // for that buffer. If we have sent the buffer to all peers we can + // reuse the buffer so copy the next chunk of data and send it to all. + // + // If a receive is completed then advance the pointer of completed + // receives. + ibv_wc wc[WC_NUM]; + int n = poll(connections_, WC_NUM, wc); + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int buff = (wc[i].wr_id >> 8) & 0xff; + int rank = wc[i].wr_id & 0xff; + + in_flight--; + + if (work_type == SEND_WR && read_offset < total) { + completed_send_count[buff]++; + if (completed_send_count[buff] == num_peers) { + std::copy( + data + read_offset, + data + std::min(read_offset + N, total), + send_buffer(sz, buff).begin()); + post_send_all(sz, buff); + + completed_send_count[buff] = 0; + in_flight += num_peers; + read_offset += N; + } + } + + else if (work_type == RECV_WR) { + completed_recv_end[rank]++; + } + } + + // Process the completed recv + // + // For each rank we have a range of completed recv defined by a begin + // and end inclusive and exlusive in standard C++ fashion. + // + // When there is an unprocessed receive we first check if we have + // finished sending the write location. If so then we reduce in-place + // and then check if there is more to be received and post a recv. + for (int r = 0; r < size_; r++) { + int s = completed_recv_begin[r]; + int e = completed_recv_end[r]; + int w = s * N; + while (w < read_offset && e - s > 0) { + int buff = s % PIPELINE; + reduce_op( + recv_buffer(sz, buff, r).begin(), + data + w, + std::min(N, total - w)); + w += N; + s++; + if (w + (PIPELINE - 1) * N < total) { + recv_from(sz, r, buff); + in_flight++; + } + } + completed_recv_begin[r] = s; + } + } + } + + void all_gather(const char* in_ptr, char* out_ptr, int64_t n_bytes) { + // Copy our data to the appropriate place + std::memcpy(out_ptr + rank_ * n_bytes, in_ptr, n_bytes); + + // Fully connected all gather + char* data = out_ptr; + char* our_data = out_ptr + rank_ * n_bytes; + auto [sz, N] = buffer_size_from_message(n_bytes); + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE * MESH_MAX_PEERS * 2; + int64_t total = static_cast(n_bytes); + int num_peers = size_ - 1; + + // Counters to maintain the state of transfers + int in_flight = 0; + int read_offset = 0; + int completed_send_count[PIPELINE] = {0}; + int write_offset[MESH_MAX_PEERS] = {0}; + + // Prefill the pipeline + int buff = 0; + while (read_offset < total && buff < PIPELINE) { + post_recv_all(sz, buff); + std::copy( + our_data + read_offset, + our_data + std::min(read_offset + N, total), + send_buffer(sz, buff).begin()); + post_send_all(sz, buff); + + buff++; + in_flight += 2 * num_peers; + read_offset += N; + } + + // Main loop + // + // Keep going until we have no longer data in flight. + while (in_flight > 0) { + ibv_wc wc[WC_NUM]; + int n = poll(connections_, WC_NUM, wc); + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int buff = (wc[i].wr_id >> 8) & 0xff; + int rank = wc[i].wr_id & 0xff; + + in_flight--; + + // Send completed. If all sends completed then send the next chunk. + if (work_type == SEND_WR && read_offset < total) { + completed_send_count[buff]++; + if (completed_send_count[buff] == num_peers) { + std::copy( + our_data + read_offset, + our_data + std::min(read_offset + N, total), + send_buffer(sz, buff).begin()); + post_send_all(sz, buff); + + completed_send_count[buff] = 0; + in_flight += num_peers; + read_offset += N; + } + } + + // Recv completed. If we have more chunks then post another recv. + else if (work_type == RECV_WR) { + std::copy( + recv_buffer(sz, buff, rank).begin(), + recv_buffer(sz, buff, rank).begin() + + std::min(N, total - write_offset[rank]), + data + rank * n_bytes + write_offset[rank]); + write_offset[rank] += N; + if (write_offset[rank] + N * (PIPELINE - 1) < total) { + recv_from(sz, rank, buff); + in_flight++; + } + } + } + } + } + + void send(const char* in_ptr, int64_t n_bytes, int dst) { + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE; + auto [sz, N] = buffer_size_from_message(n_bytes); + + int in_flight = 0; + int64_t read_offset = 0; + + // Prefill the pipeline + int buff = 0; + while (read_offset < n_bytes && buff < PIPELINE) { + std::copy( + in_ptr + read_offset, + in_ptr + std::min(read_offset + N, n_bytes), + send_buffer(sz, buff).begin()); + send_to(sz, dst, buff); + + buff++; + read_offset += N; + in_flight++; + } + + // Main loop + while (in_flight > 0) { + // Poll the hardware for completions. + // + // If a send was completed and we have more data to send then go ahead + // and send them. + ibv_wc wc[WC_NUM]; + int n = connections_[dst].poll(WC_NUM, wc); + for (int i = 0; i < n; i++) { + int buff = (wc[i].wr_id >> 8) & 0xff; + int rank = wc[i].wr_id & 0xff; + + in_flight--; + + if (read_offset < n_bytes) { + std::copy( + in_ptr + read_offset, + in_ptr + std::min(read_offset + N, n_bytes), + send_buffer(sz, buff).begin()); + send_to(sz, dst, buff); + + read_offset += N; + in_flight++; + } + } + } + } + + void recv(char* out_ptr, int64_t n_bytes, int src) { + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE; + auto [sz, N] = buffer_size_from_message(n_bytes); + + int in_flight = 0; + int64_t write_offset = 0; + + // Prefill the pipeline + int buff = 0; + while (N * buff < n_bytes && buff < PIPELINE) { + recv_from(sz, src, buff); + + in_flight++; + buff++; + } + + // Main loop + while (in_flight > 0) { + // Poll the hardware for completions. + // + // If a recv was completed copy it to the output and if we have more + // data to fetch post another recv. + ibv_wc wc[WC_NUM]; + int n = connections_[src].poll(WC_NUM, wc); + for (int i = 0; i < n; i++) { + int buff = (wc[i].wr_id >> 8) & 0xff; + int rank = wc[i].wr_id & 0xff; + + in_flight--; + + std::copy( + recv_buffer(sz, buff, src).begin(), + recv_buffer(sz, buff, src).begin() + + std::min(n_bytes - write_offset, static_cast(N)), + out_ptr + write_offset); + write_offset += N; + + if (write_offset + (PIPELINE - 1) * N < n_bytes) { + recv_from(sz, src, buff); + + in_flight++; + } + } + } + } + + private: + void send_to(int sz, int rank, int buff) { + connections_[rank].post_send( + send_buffer(sz, buff), SEND_WR << 16 | buff << 8 | rank); + } + + void recv_from(int sz, int rank, int buff) { + connections_[rank].post_recv( + recv_buffer(sz, buff, rank), RECV_WR << 16 | buff << 8 | rank); + } + + SharedBuffer& send_buffer(int sz, int buff) { + return buffers_[sz * NUM_BUFFERS * size_ + buff * size_ + rank_]; + } + + SharedBuffer& recv_buffer(int sz, int buff, int rank) { + return buffers_[sz * NUM_BUFFERS * size_ + buff * size_ + rank]; + } + + void post_send_all(int sz, int buff) { + auto& b = send_buffer(sz, buff); + int wr_id = SEND_WR << 16 | buff << 8; + for (int i = 0; i < size_; i++) { + if (i == rank_) { + continue; + } + connections_[i].post_send(b, wr_id | i); + } + } + + void post_recv_all(int sz, int buff) { + int b = sz * NUM_BUFFERS * size_ + buff * size_; + int wr_id = RECV_WR << 16 | buff << 8; + for (int i = 0; i < size_; i++) { + if (i == rank_) { + continue; + } + connections_[i].post_recv(buffers_[b + i], wr_id | i); + } + } + + int rank_; + int size_; + std::span connections_; + std::span buffers_; +}; + +} // namespace mlx::core::distributed::jaccl diff --git a/Source/Cxxmlx/include/mlx/distributed/jaccl/ring.h b/Source/Cxxmlx/include/mlx/distributed/jaccl/ring.h new file mode 100644 index 00000000..b3ce2f7b --- /dev/null +++ b/Source/Cxxmlx/include/mlx/distributed/jaccl/ring.h @@ -0,0 +1,89 @@ +// Copyright © 2026 Apple Inc. + +#pragma once + +#include "mlx/distributed/distributed_impl.h" +#include "mlx/distributed/jaccl/ring_impl.h" +#include "mlx/distributed/jaccl/utils.h" + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; + +namespace mlx::core::distributed::jaccl { + +/** + * The JACCL communication group for a ring where each node is connected to its + * two neighboring nodes. It should be the highest bandwidth communication + * group for large messages when many connections per peer are used. + * + * Like all JACCL groups it uses a side channel to exchange the necessary + * information and then configure the connections to be ready for RDMA + * operations. + */ +class RingGroup : public GroupImpl { + public: + RingGroup( + int rank, + int size, + const std::vector& left_devices, + const std::vector& right_devices, + const char* coordinator_addr); + + Stream communication_stream(StreamOrDevice s) override { + return to_stream(s, Device::cpu); + } + + int rank() override { + return rank_; + } + + int size() override { + return size_; + } + + void all_sum(const array& input, array& output, Stream stream) override; + void all_max(const array& input, array& output, Stream stream) override; + void all_min(const array& input, array& output, Stream stream) override; + void all_gather(const array& input, array& output, Stream stream) override; + void send(const array& input, int dst, Stream stream) override; + void recv(array& out, int src, Stream stream) override; + + void sum_scatter(const array& input, array& output, Stream stream) override { + throw std::runtime_error("[jaccl] sum_scatter not supported."); + } + + std::shared_ptr split(int color, int key = -1) override { + throw std::runtime_error("[jaccl] Group split not supported."); + } + + private: + template + void all_reduce( + const array& input, + array& output, + Stream stream, + ReduceOp reduce_op); + + /** + * Performs the connection initialization. Namely, after this call all + * Connection objects should have a queue pair in RTS state and all buffers + * should have been allocated. + */ + void initialize(); + + /** + * Allocate all the buffers that we will use in the communication group. + */ + void allocate_buffers(); + + int rank_; + int size_; + int n_conns_; + SideChannel side_channel_; + std::vector left_; + std::vector right_; + std::vector send_buffers_; + std::vector recv_buffers_; + RingImpl ring_; +}; + +} // namespace mlx::core::distributed::jaccl diff --git a/Source/Cxxmlx/include/mlx/distributed/jaccl/ring_impl.h b/Source/Cxxmlx/include/mlx/distributed/jaccl/ring_impl.h new file mode 100644 index 00000000..ce883d1f --- /dev/null +++ b/Source/Cxxmlx/include/mlx/distributed/jaccl/ring_impl.h @@ -0,0 +1,631 @@ +// Copyright © 2026 Apple Inc. + +#pragma once + +#include + +#include "mlx/distributed/jaccl/utils.h" + +constexpr int RING_MAX_CONNS = 4; + +namespace mlx::core::distributed::jaccl { + +class RingImpl { + public: + RingImpl( + int rank, + int size, + std::vector& left, + std::vector& right, + std::vector& send_buffers, + std::vector& recv_buffers) + : rank_(rank), + size_(size), + n_conns_(left.size()), + left_(left), + right_(right), + send_buffers_(send_buffers), + recv_buffers_(recv_buffers) {} + + RingImpl( + int rank, + int size, + Connection* left_begin, + Connection* right_begin, + size_t n_conns, + std::vector& send_buffers, + std::vector& recv_buffers) + : rank_(rank), + size_(size), + n_conns_(n_conns), + left_(left_begin, n_conns), + right_(right_begin, n_conns), + send_buffers_(send_buffers), + recv_buffers_(recv_buffers) {} + + RingImpl() : rank_(0), size_(1), n_conns_(0) {} + + template + void all_reduce( + const T* in_ptr, + T* out_ptr, + int64_t size, + int n_wires, + ReduceOp reduce_op) { + // If not inplace all reduce then copy the input to the output first + if (in_ptr != out_ptr) { + std::memcpy(out_ptr, in_ptr, size * sizeof(T)); + } + + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE * RING_MAX_CONNS * 2 * MAX_DIR; + int64_t chunk_size = (size + size_ - 1) / size_; + int64_t size_per_wire = + (chunk_size + (MAX_DIR * n_wires) - 1) / (MAX_DIR * n_wires); + auto [sz, N] = buffer_size_from_message(size_per_wire * sizeof(T)); + N /= sizeof(T); + int64_t n_steps = (size_per_wire + N - 1) / N; + + // Counters to maintain the state of transfers + int in_flight = 0; + int64_t chunk_multiple_size = size_ * chunk_size; + int64_t send_offset[MAX_DIR]; + int64_t recv_offset[MAX_DIR]; + int64_t send_limits[MAX_DIR]; + int64_t recv_limits[MAX_DIR]; + int send_count[MAX_DIR * RING_MAX_CONNS] = {0}; + int recv_count[MAX_DIR * RING_MAX_CONNS] = {0}; + send_offset[0] = rank_ * chunk_size; + recv_offset[0] = ((rank_ + size_ - 1) % size_) * chunk_size; + if constexpr (MAX_DIR == 2) { + send_offset[1] = rank_ * chunk_size; + recv_offset[1] = ((rank_ + 1) % size_) * chunk_size; + send_limits[0] = std::min( + n_wires * size_per_wire, std::max(0, size - send_offset[0])); + send_limits[1] = + std::min(chunk_size, std::max(0, size - send_offset[1])); + recv_limits[0] = std::min( + n_wires * size_per_wire, std::max(0, size - recv_offset[0])); + recv_limits[1] = + std::min(chunk_size, std::max(0, size - recv_offset[1])); + } else { + send_limits[0] = + std::min(chunk_size, std::max(0, size - send_offset[0])); + recv_limits[0] = + std::min(chunk_size, std::max(0, size - recv_offset[0])); + } + + // First reduce scatter + // + // Possible perf improvement by not syncing at every step but running ahead + // as needed. + for (int k = 0; k < size_ - 1; k++) { + // Prefill the pipeline + int buff = 0; + while (buff < n_steps && buff < PIPELINE) { + post_recv_all(sz, buff, n_wires); + for (int lr = 0; lr < MAX_DIR; lr++) { + for (int lw = 0; lw < n_wires; lw++) { + int64_t offset = lw * N + + send_count[lr * RING_MAX_CONNS + lw] * n_wires * N + + lr * n_wires * size_per_wire; + std::copy( + out_ptr + send_offset[lr] + offset, + out_ptr + send_offset[lr] + + std::max(offset, std::min(offset + N, send_limits[lr])), + send_buffer(sz, buff, lr, lw).begin()); + send_count[lr * RING_MAX_CONNS + lw]++; + } + } + post_send_all(sz, buff, n_wires); + + buff++; + in_flight += 2 * MAX_DIR * n_wires; + } + + // Main loop + // + // Keep going until we have no longer data in flight. + while (in_flight > 0) { + ibv_wc wc[WC_NUM]; + int n = poll(left_, right_, WC_NUM, wc); + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int buff = (wc[i].wr_id >> 8) & 0xff; + int wire = wc[i].wr_id & 0xff; + int lr = wire / RING_MAX_CONNS; + int lw = wire % RING_MAX_CONNS; + + in_flight--; + + if (work_type == SEND_WR && send_count[wire] < n_steps) { + int64_t offset = lw * N + send_count[wire] * n_wires * N + + lr * n_wires * size_per_wire; + std::copy( + out_ptr + send_offset[lr] + offset, + out_ptr + send_offset[lr] + + std::max(offset, std::min(offset + N, send_limits[lr])), + send_buffer(sz, buff, lr, lw).begin()); + send_to(sz, buff, lr, lw); + in_flight++; + send_count[wire]++; + } + + else if (work_type == RECV_WR) { + int64_t offset = lw * N + recv_count[wire] * n_wires * N + + lr * n_wires * size_per_wire; + reduce_op( + recv_buffer(sz, buff, lr, lw).begin(), + out_ptr + recv_offset[lr] + offset, + std::max(0, std::min(N, recv_limits[lr] - offset))); + recv_count[wire]++; + if (recv_count[wire] + (PIPELINE - 1) < n_steps) { + recv_from(sz, buff, lr, lw); + in_flight++; + } + } + } + } + + send_offset[0] = (send_offset[0] + chunk_multiple_size - chunk_size) % + chunk_multiple_size; + recv_offset[0] = (recv_offset[0] + chunk_multiple_size - chunk_size) % + chunk_multiple_size; + if constexpr (MAX_DIR == 2) { + send_offset[1] = (send_offset[1] + chunk_size) % chunk_multiple_size; + recv_offset[1] = (recv_offset[1] + chunk_size) % chunk_multiple_size; + send_limits[0] = std::min( + n_wires * size_per_wire, + std::max(0, size - send_offset[0])); + send_limits[1] = + std::min(chunk_size, std::max(0, size - send_offset[1])); + recv_limits[0] = std::min( + n_wires * size_per_wire, + std::max(0, size - recv_offset[0])); + recv_limits[1] = + std::min(chunk_size, std::max(0, size - recv_offset[1])); + } else { + send_limits[0] = + std::min(chunk_size, std::max(0, size - send_offset[0])); + recv_limits[0] = + std::min(chunk_size, std::max(0, size - recv_offset[0])); + } + for (int i = 0; i < MAX_DIR * RING_MAX_CONNS; i++) { + send_count[i] = recv_count[i] = 0; + } + } + + // Secondly all gather + // + // The offsets are correct from the scatter reduce + for (int k = 0; k < size_ - 1; k++) { + // Prefill the pipeline + int buff = 0; + while (buff < n_steps && buff < PIPELINE) { + post_recv_all(sz, buff, n_wires); + for (int lr = 0; lr < MAX_DIR; lr++) { + for (int lw = 0; lw < n_wires; lw++) { + int64_t offset = lw * N + + send_count[lr * RING_MAX_CONNS + lw] * n_wires * N + + lr * n_wires * size_per_wire; + std::copy( + out_ptr + send_offset[lr] + offset, + out_ptr + send_offset[lr] + + std::max(offset, std::min(offset + N, send_limits[lr])), + send_buffer(sz, buff, lr, lw).begin()); + send_count[lr * RING_MAX_CONNS + lw]++; + } + } + post_send_all(sz, buff, n_wires); + + buff++; + in_flight += 2 * MAX_DIR * n_wires; + } + + // Main loop + // + // Keep going until we have no longer data in flight. + while (in_flight > 0) { + ibv_wc wc[WC_NUM]; + int n = poll(left_, right_, WC_NUM, wc); + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int buff = (wc[i].wr_id >> 8) & 0xff; + int wire = wc[i].wr_id & 0xff; + int lr = wire / RING_MAX_CONNS; + int lw = wire % RING_MAX_CONNS; + + in_flight--; + + if (work_type == SEND_WR && send_count[wire] < n_steps) { + int64_t offset = lw * N + send_count[wire] * n_wires * N + + lr * n_wires * size_per_wire; + std::copy( + out_ptr + send_offset[lr] + offset, + out_ptr + send_offset[lr] + + std::max(offset, std::min(offset + N, send_limits[lr])), + send_buffer(sz, buff, lr, lw).begin()); + send_to(sz, buff, lr, lw); + in_flight++; + send_count[wire]++; + } + + else if (work_type == RECV_WR) { + int64_t offset = lw * N + recv_count[wire] * n_wires * N + + lr * n_wires * size_per_wire; + std::copy( + recv_buffer(sz, buff, lr, lw).begin(), + recv_buffer(sz, buff, lr, lw).begin() + + std::max(0, std::min(N, recv_limits[lr] - offset)), + out_ptr + recv_offset[lr] + offset); + recv_count[wire]++; + if (recv_count[wire] + (PIPELINE - 1) < n_steps) { + recv_from(sz, buff, lr, lw); + in_flight++; + } + } + } + } + + send_offset[0] = (send_offset[0] + chunk_multiple_size - chunk_size) % + chunk_multiple_size; + recv_offset[0] = (recv_offset[0] + chunk_multiple_size - chunk_size) % + chunk_multiple_size; + if constexpr (MAX_DIR == 2) { + send_offset[1] = (send_offset[1] + chunk_size) % chunk_multiple_size; + recv_offset[1] = (recv_offset[1] + chunk_size) % chunk_multiple_size; + send_limits[0] = std::min( + n_wires * size_per_wire, + std::max(0, size - send_offset[0])); + send_limits[1] = + std::min(chunk_size, std::max(0, size - send_offset[1])); + recv_limits[0] = std::min( + n_wires * size_per_wire, + std::max(0, size - recv_offset[0])); + recv_limits[1] = + std::min(chunk_size, std::max(0, size - recv_offset[1])); + } else { + send_limits[0] = + std::min(chunk_size, std::max(0, size - send_offset[0])); + recv_limits[0] = + std::min(chunk_size, std::max(0, size - recv_offset[0])); + } + for (int i = 0; i < MAX_DIR * RING_MAX_CONNS; i++) { + send_count[i] = recv_count[i] = 0; + } + } + } + + void + all_gather(const char* in_ptr, char* out_ptr, int64_t n_bytes, int n_wires) { + // Copy our data to the appropriate place + std::memcpy(out_ptr + rank_ * n_bytes, in_ptr, n_bytes); + + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE * RING_MAX_CONNS * 2 * 2; + size_t n_bytes_per_wire = (n_bytes + (2 * n_wires) - 1) / (2 * n_wires); + size_t out_bytes = n_bytes * size_; + auto [sz, N] = buffer_size_from_message(n_bytes_per_wire); + int n_steps = (n_bytes_per_wire + N - 1) / N; + + // Counters to maintain the state of transfers + int in_flight = 0; + int64_t send_offset[2]; + int64_t recv_offset[2]; + int64_t limits[2]; + int send_count[2 * RING_MAX_CONNS] = {0}; + int recv_count[2 * RING_MAX_CONNS] = {0}; + send_offset[0] = send_offset[1] = rank_ * n_bytes; + recv_offset[0] = ((rank_ + size_ - 1) % size_) * n_bytes; + recv_offset[1] = ((rank_ + 1) % size_) * n_bytes; + limits[0] = n_wires * n_bytes_per_wire; + limits[1] = n_bytes; + + // Possible perf improvement by not syncing at every step but running ahead + // as needed. + for (int k = 0; k < size_ - 1; k++) { + // Prefill the pipeline + int buff = 0; + while (buff < n_steps && buff < PIPELINE) { + post_recv_all(sz, buff); + for (int lr = 0; lr < 2; lr++) { + for (int lw = 0; lw < n_wires; lw++) { + int64_t offset = lw * N + + send_count[lr * RING_MAX_CONNS + lw] * n_wires * N + + lr * n_wires * n_bytes_per_wire; + std::copy( + out_ptr + send_offset[lr] + offset, + out_ptr + send_offset[lr] + + std::max(offset, std::min(offset + N, limits[lr])), + send_buffer(sz, buff, lr, lw).begin()); + send_count[lr * RING_MAX_CONNS + lw]++; + } + } + post_send_all(sz, buff); + + buff++; + in_flight += 2 * 2 * n_wires; + } + + // Main loop + // + // Keep going until we have no longer data in flight. + while (in_flight > 0) { + ibv_wc wc[WC_NUM]; + int n = poll(left_, right_, WC_NUM, wc); + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int buff = (wc[i].wr_id >> 8) & 0xff; + int wire = wc[i].wr_id & 0xff; + int lr = wire / RING_MAX_CONNS; + int lw = wire % RING_MAX_CONNS; + + in_flight--; + + if (work_type == SEND_WR && send_count[wire] < n_steps) { + int64_t offset = lw * N + send_count[wire] * n_wires * N + + lr * n_wires * n_bytes_per_wire; + std::copy( + out_ptr + send_offset[lr] + offset, + out_ptr + send_offset[lr] + + std::max(offset, std::min(offset + N, limits[lr])), + send_buffer(sz, buff, lr, lw).begin()); + send_to(sz, buff, lr, lw); + in_flight++; + send_count[wire]++; + } + + else if (work_type == RECV_WR) { + int64_t offset = lw * N + recv_count[wire] * n_wires * N + + lr * n_wires * n_bytes_per_wire; + std::copy( + recv_buffer(sz, buff, lr, lw).begin(), + recv_buffer(sz, buff, lr, lw).begin() + + std::max(0, std::min(N, limits[lr] - offset)), + out_ptr + recv_offset[lr] + offset); + recv_count[wire]++; + if (recv_count[wire] + (PIPELINE - 1) < n_steps) { + recv_from(sz, buff, lr, lw); + in_flight++; + } + } + } + } + + send_offset[0] = (send_offset[0] + out_bytes - n_bytes) % out_bytes; + recv_offset[0] = (recv_offset[0] + out_bytes - n_bytes) % out_bytes; + send_offset[1] = (send_offset[1] + n_bytes) % out_bytes; + recv_offset[1] = (recv_offset[1] + n_bytes) % out_bytes; + for (int i = 0; i < 2 * RING_MAX_CONNS; i++) { + send_count[i] = recv_count[i] = 0; + } + } + } + + void send(const char* in_ptr, int64_t n_bytes, int dst, int n_wires) { + int left = (rank_ + size_ - 1) % size_; + + // In the case that size_ == 2 then left == right so we bias send towards + // left and recv towards right so that the selections will be correct for + // the 2 node case. + auto& conns = (dst == left) ? left_ : right_; + int dir = dst == left; + + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE * RING_MAX_CONNS; + + int64_t bytes_per_wire = (n_bytes + n_wires - 1) / n_wires; + auto [sz, N] = buffer_size_from_message(bytes_per_wire); + + int in_flight = 0; + int64_t read_offset[RING_MAX_CONNS]; + int64_t limits[RING_MAX_CONNS]; + for (int lw = 0; lw < n_wires; lw++) { + read_offset[lw] = std::min(lw * bytes_per_wire, n_bytes); + limits[lw] = std::min((lw + 1) * bytes_per_wire, n_bytes); + } + + // Prefill the pipeline + for (int lw = 0; lw < n_wires; lw++) { + int buff = 0; + while (read_offset[lw] < limits[lw] && buff < PIPELINE) { + std::copy( + in_ptr + read_offset[lw], + in_ptr + std::min(read_offset[lw] + N, limits[lw]), + send_buffer(sz, buff, dir, lw).begin()); + send_to(sz, buff, dir, lw); + + buff++; + read_offset[lw] += N; + in_flight++; + } + } + + // Main loop + while (in_flight > 0) { + // Poll the hardware for completions. + // + // If a send was completed and we have more data to send then go ahead + // and send them. + ibv_wc wc[WC_NUM]; + int n = poll(conns, WC_NUM, wc); + for (int i = 0; i < n; i++) { + int buff = (wc[i].wr_id >> 8) & 0xff; + int wire = wc[i].wr_id & 0xff; + int lw = wire % RING_MAX_CONNS; + + in_flight--; + + if (read_offset[lw] < limits[lw]) { + std::copy( + in_ptr + read_offset[lw], + in_ptr + std::min(read_offset[lw] + N, limits[lw]), + send_buffer(sz, buff, dir, lw).begin()); + send_to(sz, buff, dir, lw); + + read_offset[lw] += N; + in_flight++; + } + } + } + } + + void recv(char* out_ptr, int64_t n_bytes, int src, int n_wires) { + int right = (rank_ + 1) % size_; + + // In the case that size_ == 2 then left == right so we bias send towards + // left and recv towards right so that the selections will be correct for + // the 2 node case. + auto& conns = (src == right) ? right_ : left_; + int dir = src == right; + + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE * RING_MAX_CONNS; + + int64_t bytes_per_wire = (n_bytes + n_wires - 1) / n_wires; + auto [sz, N] = buffer_size_from_message(bytes_per_wire); + + int in_flight = 0; + int64_t write_offset[RING_MAX_CONNS]; + int64_t limits[RING_MAX_CONNS]; + for (int lw = 0; lw < n_wires; lw++) { + write_offset[lw] = std::min(lw * bytes_per_wire, n_bytes); + limits[lw] = std::min((lw + 1) * bytes_per_wire, n_bytes); + } + + // Prefill the pipeline + for (int lw = 0; lw < n_wires; lw++) { + int buff = 0; + while (N * buff < limits[lw] && buff < PIPELINE) { + recv_from(sz, buff, dir, lw); + + buff++; + in_flight++; + } + } + + // Main loop + while (in_flight > 0) { + // Poll the hardware for completions. + // + // If a recv was completed copy it to the output and if we have more + // data to fetch post another recv. + ibv_wc wc[WC_NUM]; + int n = poll(conns, WC_NUM, wc); + for (int i = 0; i < n; i++) { + int buff = (wc[i].wr_id >> 8) & 0xff; + int wire = wc[i].wr_id & 0xff; + int lw = wire % RING_MAX_CONNS; + + in_flight--; + + std::copy( + recv_buffer(sz, buff, dir, lw).begin(), + recv_buffer(sz, buff, dir, lw).begin() + + std::max( + 0, std::min(limits[lw] - write_offset[lw], N)), + out_ptr + write_offset[lw]); + write_offset[lw] += N; + + if (write_offset[lw] + (PIPELINE - 1) * N < limits[lw]) { + recv_from(sz, buff, dir, lw); + + in_flight++; + } + } + } + } + + private: + void send_to(int sz, int buff, int left_right, int wire) { + if (left_right) { + left_[wire].post_send( + send_buffer_left(sz, buff, wire), + SEND_WR << 16 | buff << 8 | (RING_MAX_CONNS + wire)); + } else { + right_[wire].post_send( + send_buffer_right(sz, buff, wire), SEND_WR << 16 | buff << 8 | wire); + } + } + + void recv_from(int sz, int buff, int left_right, int wire) { + if (left_right) { + right_[wire].post_recv( + recv_buffer_right(sz, buff, wire), + RECV_WR << 16 | buff << 8 | (RING_MAX_CONNS + wire)); + } else { + left_[wire].post_recv( + recv_buffer_left(sz, buff, wire), RECV_WR << 16 | buff << 8 | wire); + } + } + + SharedBuffer& send_buffer_right(int sz, int buff, int wire) { + return send_buffers_ + [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + wire]; + } + + SharedBuffer& send_buffer_left(int sz, int buff, int wire) { + return send_buffers_ + [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + n_conns_ + + wire]; + } + + SharedBuffer& send_buffer(int sz, int buff, int left_right, int wire) { + return send_buffers_ + [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + + left_right * n_conns_ + wire]; + } + + SharedBuffer& recv_buffer_left(int sz, int buff, int wire) { + return recv_buffers_ + [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + wire]; + } + + SharedBuffer& recv_buffer_right(int sz, int buff, int wire) { + return recv_buffers_ + [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + n_conns_ + + wire]; + } + + SharedBuffer& recv_buffer(int sz, int buff, int left_right, int wire) { + return recv_buffers_ + [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + + left_right * n_conns_ + wire]; + } + + template + void post_recv_all(int sz, int buff, int n_wires) { + for (int lr = 0; lr < MAX_DIR; lr++) { + for (int lw = 0; lw < n_wires; lw++) { + recv_from(sz, buff, lr, lw); + } + } + } + + void post_recv_all(int sz, int buff) { + post_recv_all<2>(sz, buff, n_conns_); + } + + template + void post_send_all(int sz, int buff, int n_wires) { + for (int lr = 0; lr < MAX_DIR; lr++) { + for (int lw = 0; lw < n_wires; lw++) { + send_to(sz, buff, lr, lw); + } + } + } + + void post_send_all(int sz, int buff) { + post_send_all<2>(sz, buff, n_conns_); + } + + int rank_; + int size_; + int n_conns_; + std::span left_; + std::span right_; + std::span send_buffers_; + std::span recv_buffers_; +}; + +} // namespace mlx::core::distributed::jaccl diff --git a/Source/Cxxmlx/include/mlx/distributed/jaccl/utils.h b/Source/Cxxmlx/include/mlx/distributed/jaccl/utils.h new file mode 100644 index 00000000..8faa7740 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/distributed/jaccl/utils.h @@ -0,0 +1,343 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +#include +#include +#include + +#include "mlx/distributed/utils.h" + +constexpr const char* IBV_TAG = "[jaccl]"; +constexpr int SEND_WR = 1; +constexpr int RECV_WR = 2; +constexpr int MAX_SEND_WR = 32; +constexpr int MAX_RECV_WR = 32; +constexpr int BUFFER_SIZES = 8; +constexpr int NUM_BUFFERS = 2; +constexpr int FRAME_SIZE = 4096; + +namespace detail = mlx::core::distributed::detail; + +namespace { + +template +struct is_container : std::false_type {}; + +template +struct is_container< + T, + std::void_t> + : std::true_type {}; + +inline std::pair buffer_size_from_message(int64_t msg) { + if (__builtin_available(macOS 26.3, iOS 26.3, tvOS 26.3, visionOS 26.3, *)) { + for (int k = BUFFER_SIZES - 1; k > 0; k--) { + if (msg >= FRAME_SIZE * (1 << k)) { + return {k, FRAME_SIZE * (1 << k)}; + } + } + } + return {0, FRAME_SIZE}; +} + +} // namespace + +namespace mlx::core::distributed::jaccl { + +/** + * Wrapper for the ibverbs API. + */ +struct IBVWrapper { + IBVWrapper(); + bool is_available() { + return librdma_handle_ != nullptr; + } + + // API + ibv_device** (*get_device_list)(int*); + const char* (*get_device_name)(ibv_device*); + ibv_context* (*open_device)(ibv_device*); + void (*free_device_list)(ibv_device**); + int (*close_device)(ibv_context*); + + ibv_pd* (*alloc_pd)(ibv_context*); + ibv_qp* (*create_qp)(ibv_pd*, ibv_qp_init_attr*); + ibv_cq* (*create_cq)(ibv_context*, int, void*, ibv_comp_channel*, int); + int (*destroy_cq)(ibv_cq*); + int (*destroy_qp)(ibv_qp*); + int (*dealloc_pd)(ibv_pd*); + + int (*query_port)(ibv_context*, uint8_t, ibv_port_attr*); + int (*query_gid)(ibv_context*, uint8_t, int, ibv_gid*); + int (*modify_qp)(ibv_qp*, ibv_qp_attr*, int); + ibv_mr* (*reg_mr)(ibv_pd*, void*, size_t, int); + int (*dereg_mr)(ibv_mr*); + + private: + void* librdma_handle_; +}; + +IBVWrapper& ibv(); + +/** + * Contains the information that defines a destination to a remote device. + * Basically we can compute our own destination and share it with remote hosts + * over the side channel. + */ +struct Destination { + int local_id; + int queue_pair_number; + int packet_sequence_number; + ibv_gid global_identifier; +}; + +/** + * A buffer that can be registered to a number of protection domains. + */ +class SharedBuffer { + public: + SharedBuffer(size_t num_bytes); + SharedBuffer(SharedBuffer&& b); + ~SharedBuffer(); + + SharedBuffer(const SharedBuffer&) = delete; + SharedBuffer& operator=(const SharedBuffer&) = delete; + + void register_to_protection_domain(ibv_pd* protection_domain); + + size_t size() const { + return num_bytes_; + } + + uint32_t local_key(ibv_pd* protection_domain) const { + return memory_regions_.at(protection_domain)->lkey; + } + + ibv_sge to_scatter_gather_entry(ibv_pd* protection_domain) const { + ibv_sge entry; + entry.addr = reinterpret_cast(data_); + entry.length = size(); + entry.lkey = local_key(protection_domain); + return entry; + } + + template + T* data() { + return static_cast(data_); + } + + template + T* begin() { + return static_cast(data_); + } + + template + T* end() { + return static_cast(data_) + size() / sizeof(T); + } + + private: + void* data_; + size_t num_bytes_; + std::unordered_map memory_regions_; +}; + +/** + * Manipulates an RDMA connection. Enables (among other things) + * + * - Creating a queue pair + * - Sending and receiving + * - Checking completion + */ +struct Connection { + ibv_context* ctx; + ibv_pd* protection_domain; + ibv_cq* completion_queue; + ibv_qp* queue_pair; + Destination src; // holds the local information + + Connection(ibv_context* ctx_); + Connection(Connection&& c); + + Connection(const Connection&) = delete; + Connection& operator=(Connection&) = delete; + + ~Connection(); + void allocate_protection_domain(); + void create_completion_queue(int num_entries); + void create_queue_pair(); + + const Destination& info(); + void queue_pair_init(); + void queue_pair_rtr(const Destination& dst); + void queue_pair_rts(); + + void post_send(const SharedBuffer& buff, uint64_t work_request_id) { + ibv_send_wr work_request, *bad_work_request; + + auto entry = buff.to_scatter_gather_entry(protection_domain); + work_request.wr_id = work_request_id; + work_request.sg_list = &entry; + work_request.num_sge = 1; + work_request.opcode = IBV_WR_SEND; + work_request.send_flags = IBV_SEND_SIGNALED; + work_request.next = nullptr; + + if (int status = + ibv_post_send(queue_pair, &work_request, &bad_work_request); + status != 0) { + std::ostringstream msg; + msg << "[jaccl] Send failed with error code " << status; + throw std::invalid_argument(msg.str()); + } + } + + void post_recv(const SharedBuffer& buff, uint64_t work_request_id) { + ibv_recv_wr work_request, *bad_work_request; + + auto entry = buff.to_scatter_gather_entry(protection_domain); + work_request.wr_id = work_request_id; + work_request.sg_list = &entry; + work_request.num_sge = 1; + work_request.next = nullptr; + + if (int status = + ibv_post_recv(queue_pair, &work_request, &bad_work_request); + status != 0) { + std::ostringstream msg; + msg << "[jaccl] Recv failed with error code " << status; + throw std::invalid_argument(msg.str()); + } + } + + int poll(int num_completions, ibv_wc* work_completions) { + return ibv_poll_cq(completion_queue, num_completions, work_completions); + } +}; + +std::vector create_connections( + const std::vector& device_names); + +inline int poll( + std::span connections, + int num_completions, + ibv_wc* work_completions) { + int completions = 0; + for (auto& c : connections) { + if (c.ctx == nullptr) { + continue; + } + if (completions >= num_completions) { + return completions; + } + + int n = ibv_poll_cq( + c.completion_queue, + num_completions - completions, + work_completions + completions); + + completions += n; + } + return completions; +} + +inline int poll( + std::span connections_1, + std::span connections_2, + int num_completions, + ibv_wc* work_completions) { + int completions = 0; + completions += poll(connections_1, num_completions, work_completions); + completions += poll( + connections_2, + num_completions - completions, + work_completions + completions); + return completions; +} + +/** + * Implement a TCP side channel to exchange information about the RDMA + * connections. + * + * Implements a simple all gather where every node sends to rank 0 and rank 0 + * broadcasts to every node. + */ +class SideChannel { + public: + SideChannel(int rank, int size, const char* addr); + SideChannel(SideChannel&& sc); + + SideChannel(const SideChannel&) = delete; + SideChannel& operator=(const SideChannel&) = delete; + + template + std::vector all_gather(const T& v) { + std::vector result(size_); + + // T is a container of stuff like std::vector or std::string + if constexpr (is_container::value) { + using U = typename T::value_type; + + // Share the lengths first and set the communication size to be the + // maximum length of the containers. + auto lengths = all_gather(v.size()); + auto max_len = *std::max_element(lengths.begin(), lengths.end()); + for (auto& s : result) { + s.resize(max_len); + } + + // All gather of length max_len + if (rank_ == 0) { + std::copy(v.begin(), v.end(), result[rank_].begin()); + for (int i = 1; i < size_; i++) { + sockets_[i - 1].recv(IBV_TAG, result[i].data(), sizeof(U) * max_len); + } + for (int i = 1; i < size_; i++) { + for (int j = 0; j < size_; j++) { + sockets_[i - 1].send( + IBV_TAG, result[j].data(), sizeof(U) * max_len); + } + } + } else { + std::copy(v.begin(), v.end(), result[rank_].begin()); + sockets_[0].send(IBV_TAG, result[rank_].data(), sizeof(U) * max_len); + for (int i = 0; i < size_; i++) { + sockets_[0].recv(IBV_TAG, result[i].data(), sizeof(U) * max_len); + } + } + + // Resize the outputs back to the original length + for (int i = 0; i < size_; i++) { + result[i].resize(lengths[i]); + } + } + + // T is a scalar + else { + if (rank_ == 0) { + result[rank_] = v; + for (int i = 1; i < size_; i++) { + sockets_[i - 1].recv(IBV_TAG, &result[i], sizeof(T)); + } + for (int i = 1; i < size_; i++) { + sockets_[i - 1].send(IBV_TAG, result.data(), size_ * sizeof(T)); + } + } else { + sockets_[0].send(IBV_TAG, &v, sizeof(T)); + sockets_[0].recv(IBV_TAG, result.data(), size_ * sizeof(T)); + } + } + + return result; + } + + private: + int rank_; + int size_; + std::vector sockets_; +}; + +} // namespace mlx::core::distributed::jaccl diff --git a/Source/Cxxmlx/include/mlx/distributed/mpi/mpi.h b/Source/Cxxmlx/include/mlx/distributed/mpi/mpi.h new file mode 100644 index 00000000..cd11a478 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/distributed/mpi/mpi.h @@ -0,0 +1,12 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/distributed/distributed.h" + +namespace mlx::core::distributed::mpi { + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; + +bool is_available(); +std::shared_ptr init(bool strict = false); + +} // namespace mlx::core::distributed::mpi diff --git a/Source/Cxxmlx/include/mlx/distributed/mpi/mpi_declarations.h b/Source/Cxxmlx/include/mlx/distributed/mpi/mpi_declarations.h new file mode 100644 index 00000000..99c1a9cb --- /dev/null +++ b/Source/Cxxmlx/include/mlx/distributed/mpi/mpi_declarations.h @@ -0,0 +1,28 @@ +// Copyright © 2024 Apple Inc. + +// Constants + +#define MPI_SUCCESS 0 +#define MPI_ANY_SOURCE -1 +#define MPI_ANY_TAG -1 +#define MPI_IN_PLACE ((void*)1) +#define MPI_MAX_LIBRARY_VERSION_STRING 256 + +// Define all the types that we use so that we don't include which +// causes linker errors on some platforms. +// +// NOTE: We define everything for openmpi. + +typedef void* MPI_Comm; +typedef void* MPI_Datatype; +typedef void* MPI_Op; + +typedef void(MPI_User_function)(void*, void*, int*, MPI_Datatype*); + +typedef struct ompi_status_public_t { + int MPI_SOURCE; + int MPI_TAG; + int MPI_ERROR; + int _cancelled; + size_t _ucount; +} MPI_Status; diff --git a/Source/Cxxmlx/include/mlx/distributed/nccl/nccl.h b/Source/Cxxmlx/include/mlx/distributed/nccl/nccl.h new file mode 100644 index 00000000..5370d2da --- /dev/null +++ b/Source/Cxxmlx/include/mlx/distributed/nccl/nccl.h @@ -0,0 +1,12 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/distributed/distributed.h" + +namespace mlx::core::distributed::nccl { + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; + +bool is_available(); +std::shared_ptr init(bool strict = false); + +} // namespace mlx::core::distributed::nccl diff --git a/Source/Cxxmlx/include/mlx/distributed/ops.h b/Source/Cxxmlx/include/mlx/distributed/ops.h new file mode 100644 index 00000000..e223c5be --- /dev/null +++ b/Source/Cxxmlx/include/mlx/distributed/ops.h @@ -0,0 +1,57 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include + +#include "mlx/api.h" +#include "mlx/distributed/distributed.h" +#include "mlx/utils.h" + +namespace mlx::core::distributed { + +MLX_API array all_sum( + const array& x, + std::optional group = std::nullopt, + StreamOrDevice s = {}); + +MLX_API array all_gather( + const array& x, + std::optional group = std::nullopt, + StreamOrDevice S = {}); + +MLX_API array send( + const array& x, + int dst, + std::optional group = std::nullopt, + StreamOrDevice s = {}); + +MLX_API array recv( + Shape shape, + Dtype dtype, + int src, + std::optional group = std::nullopt, + StreamOrDevice s = {}); + +MLX_API array recv_like( + const array& x, + int src, + std::optional group = std::nullopt, + StreamOrDevice s = {}); + +MLX_API array all_max( + const array& x, + std::optional group = std::nullopt, + StreamOrDevice s = {}); + +MLX_API array all_min( + const array& x, + std::optional group = std::nullopt, + StreamOrDevice s = {}); + +MLX_API array sum_scatter( + const array& x, + std::optional group = std::nullopt, + StreamOrDevice s = {}); + +} // namespace mlx::core::distributed diff --git a/Source/Cxxmlx/include/mlx/distributed/primitives.h b/Source/Cxxmlx/include/mlx/distributed/primitives.h new file mode 100644 index 00000000..18a0d65f --- /dev/null +++ b/Source/Cxxmlx/include/mlx/distributed/primitives.h @@ -0,0 +1,156 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/distributed/distributed.h" +#include "mlx/distributed/distributed_impl.h" +#include "mlx/primitives.h" + +namespace mlx::core::distributed { + +class DistPrimitive : public Primitive { + public: + DistPrimitive(Stream stream, Group group) + : Primitive(stream), group_(group) {} + + const Group& group() const { + return group_; + } + + private: + Group group_; +}; + +class AllReduce : public DistPrimitive { + public: + enum ReduceType { And, Or, Sum, Prod, Min, Max }; + + AllReduce(Stream stream, Group group, ReduceType reduce_type) + : DistPrimitive(stream, group), reduce_type_(reduce_type) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + std::pair, std::vector> vmap( + const std::vector& inputs, + const std::vector& axes) override; + std::vector jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) override; + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + const char* name() const override { + switch (reduce_type_) { + case And: + return "And AllReduce"; + case Or: + return "Or AllReduce"; + case Sum: + return "Sum AllReduce"; + case Prod: + return "Prod AllReduce"; + case Min: + return "Min AllReduce"; + case Max: + return "Max AllReduce"; + } + return ""; + } + + private: + ReduceType reduce_type_; +}; + +class AllGather : public DistPrimitive { + public: + AllGather(Stream stream, Group group) : DistPrimitive(stream, group) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + std::pair, std::vector> vmap( + const std::vector& inputs, + const std::vector& axes) override; + std::vector jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) override; + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_NAME(AllGather); +}; + +class Send : public DistPrimitive { + public: + Send(Stream stream, Group group, int dst) + : DistPrimitive(stream, group), dst_(dst) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + std::pair, std::vector> vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_NAME(Send); + + private: + int dst_; +}; + +class Recv : public DistPrimitive { + public: + Recv(Stream stream, Group group, int src) + : DistPrimitive(stream, group), src_(src) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_NAME(Recv); + + private: + int src_; +}; + +class ReduceScatter : public DistPrimitive { + public: + enum ReduceType { Sum, Min, Max }; + ReduceScatter(Stream stream, Group group, ReduceType reduce_type) + : DistPrimitive(stream, group), reduce_type_(reduce_type) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + const char* name() const override { + switch (reduce_type_) { + case Sum: + return "Sum ReduceScatter"; + case Min: + return "Min ReduceScatter"; + case Max: + return "Max ReduceScatter"; + } + return ""; + } + + private: + ReduceType reduce_type_; +}; +} // namespace mlx::core::distributed diff --git a/Source/Cxxmlx/include/mlx/distributed/reduction_ops.h b/Source/Cxxmlx/include/mlx/distributed/reduction_ops.h new file mode 100644 index 00000000..02777be3 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/distributed/reduction_ops.h @@ -0,0 +1,38 @@ +// Copyright © 2025 Apple Inc. + +namespace mlx::core::distributed::detail { + +template +struct SumOp { + void operator()(const T* input, T* output, size_t N) const { + while (N-- > 0) { + *output += *input; + input++; + output++; + } + } +}; + +template +struct MaxOp { + void operator()(const T* input, T* output, size_t N) const { + while (N-- > 0) { + *output = std::max(*output, *input); + input++; + output++; + } + } +}; + +template +struct MinOp { + void operator()(const T* input, T* output, size_t N) const { + while (N-- > 0) { + *output = std::min(*output, *input); + input++; + output++; + } + } +}; + +} // namespace mlx::core::distributed::detail diff --git a/Source/Cxxmlx/include/mlx/distributed/ring/ring.h b/Source/Cxxmlx/include/mlx/distributed/ring/ring.h new file mode 100644 index 00000000..e0b3fd09 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/distributed/ring/ring.h @@ -0,0 +1,12 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/distributed/distributed.h" + +namespace mlx::core::distributed::ring { + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; + +bool is_available(); +std::shared_ptr init(bool strict = false); + +} // namespace mlx::core::distributed::ring diff --git a/Source/Cxxmlx/include/mlx/distributed/utils.h b/Source/Cxxmlx/include/mlx/distributed/utils.h new file mode 100644 index 00000000..213dd59a --- /dev/null +++ b/Source/Cxxmlx/include/mlx/distributed/utils.h @@ -0,0 +1,67 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +namespace mlx::core::distributed::detail { + +struct address_t { + sockaddr_storage addr; + socklen_t len; + + const sockaddr* get() const { + return (struct sockaddr*)&addr; + } +}; + +/** + * Parse a sockaddr from an ip and port provided as strings. + */ +address_t parse_address(const std::string& ip, const std::string& port); + +/** + * Parse a sockaddr provided as an : string. + */ +address_t parse_address(const std::string& ip_port); + +/** + * Small wrapper over a TCP socket to simplify initiating connections. + */ +class TCPSocket { + public: + TCPSocket(const char* tag); + TCPSocket(const TCPSocket&) = delete; + TCPSocket& operator=(const TCPSocket&) = delete; + TCPSocket(TCPSocket&& s); + TCPSocket& operator=(TCPSocket&&); + ~TCPSocket(); + + void listen(const char* tag, const address_t& addr); + TCPSocket accept(const char* tag); + + void send(const char* tag, const void* data, size_t len); + void recv(const char* tag, void* data, size_t len); + + int detach(); + + operator int() const { + return sock_; + } + + static TCPSocket connect( + const char* tag, + const address_t& addr, + int num_retries = 1, + int wait = 0, + std::function cb = nullptr); + + private: + TCPSocket(int sock); + + int sock_; +}; + +} // namespace mlx::core::distributed::detail diff --git a/Source/Cxxmlx/include/mlx/dtype.h b/Source/Cxxmlx/include/mlx/dtype.h new file mode 100644 index 00000000..744ca587 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/dtype.h @@ -0,0 +1,116 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/api.h" +#include "mlx/types/complex.h" +#include "mlx/types/half_types.h" + +namespace mlx::core { + +struct Dtype { + enum class Val { + bool_, + uint8, + uint16, + uint32, + uint64, + int8, + int16, + int32, + int64, + float16, + float32, + float64, + bfloat16, + complex64, + }; + + enum class Kind { + b, /* bool */ + u, /* unsigned int */ + i, /* signed int */ + f, /* float */ + c, /* complex */ + V, /* void - used for brain float */ + }; + + enum class Category { + complexfloating, + floating, + inexact, + signedinteger, + unsignedinteger, + integer, + number, + generic + }; + + constexpr explicit Dtype(Val val, uint8_t size) : val_(val), size_(size) {} + + constexpr operator Val() const { + return val_; + } + constexpr Val val() const { + return val_; + } + constexpr uint8_t size() const { + return size_; + } + + private: + Val val_; + uint8_t size_; +}; + +inline constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)}; + +inline constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)}; +inline constexpr Dtype uint16{Dtype::Val::uint16, sizeof(uint16_t)}; +inline constexpr Dtype uint32{Dtype::Val::uint32, sizeof(uint32_t)}; +inline constexpr Dtype uint64{Dtype::Val::uint64, sizeof(uint64_t)}; + +inline constexpr Dtype int8{Dtype::Val::int8, sizeof(int8_t)}; +inline constexpr Dtype int16{Dtype::Val::int16, sizeof(int16_t)}; +inline constexpr Dtype int32{Dtype::Val::int32, sizeof(int32_t)}; +inline constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)}; + +inline constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)}; +inline constexpr Dtype float32{Dtype::Val::float32, sizeof(float)}; +inline constexpr Dtype float64{Dtype::Val::float64, sizeof(double)}; +inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)}; +inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)}; + +inline constexpr Dtype::Category complexfloating = + Dtype::Category::complexfloating; +inline constexpr Dtype::Category floating = Dtype::Category::floating; +inline constexpr Dtype::Category inexact = Dtype::Category::inexact; +inline constexpr Dtype::Category signedinteger = Dtype::Category::signedinteger; +inline constexpr Dtype::Category unsignedinteger = + Dtype::Category::unsignedinteger; +inline constexpr Dtype::Category integer = Dtype::Category::integer; +inline constexpr Dtype::Category number = Dtype::Category::number; +inline constexpr Dtype::Category generic = Dtype::Category::generic; + +MLX_API bool issubdtype(const Dtype& a, const Dtype& b); +MLX_API bool issubdtype(const Dtype::Category& a, const Dtype& b); +MLX_API bool issubdtype(const Dtype& a, const Dtype::Category& b); +MLX_API bool issubdtype(const Dtype::Category& a, const Dtype::Category& b); + +MLX_API Dtype promote_types(const Dtype& t1, const Dtype& t2); + +inline uint8_t size_of(const Dtype& t) { + return t.size(); +} + +MLX_API Dtype::Kind kindof(const Dtype& t); + +template +struct MLX_API TypeToDtype { + operator Dtype(); +}; + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/dtype_utils.h b/Source/Cxxmlx/include/mlx/dtype_utils.h new file mode 100644 index 00000000..47c6ed66 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/dtype_utils.h @@ -0,0 +1,119 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +#include "mlx/dtype.h" +#include "mlx/utils.h" + +namespace mlx::core { + +// Return string representation of dtype. +const char* dtype_to_string(Dtype arg); + +#define MLX_INTERNAL_DTYPE_SWITCH_CASE(DTYPE, TYPE) \ + case DTYPE: \ + f(type_identity{}); \ + break + +#define MLX_INTERNAL_DTYPE_SWITCH_INTS() \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(int8, int8_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(int16, int16_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(int32, int32_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(int64, int64_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(uint8, uint8_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(uint16, uint16_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(uint32, uint32_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(uint64, uint64_t) + +#define MLX_INTERNAL_DTYPE_SWITCH_FLOATS() \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(float16, float16_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(bfloat16, bfloat16_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(float32, float); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(float64, double) + +// This already exists in C++20 but in C++20 we can also just use templated +// lambdas which will make this so much nicer. +template +struct type_identity { + using type = T; +}; + +#define MLX_GET_TYPE(x) typename decltype(x)::type +#define MLX_GET_VALUE(x) decltype(x)::value + +template +void dispatch_all_types(Dtype dt, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool); + MLX_INTERNAL_DTYPE_SWITCH_INTS(); + MLX_INTERNAL_DTYPE_SWITCH_FLOATS(); + MLX_INTERNAL_DTYPE_SWITCH_CASE(complex64, complex64_t); + } +} + +template +void dispatch_int_types(Dtype dt, std::string_view tag, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_INTS(); + default: + std::ostringstream msg; + msg << tag << " Only integer types supported but " << dt + << " was provided"; + throw std::invalid_argument(msg.str()); + } +} + +template +void dispatch_float_types(Dtype dt, std::string_view tag, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_FLOATS(); + default: + std::ostringstream msg; + msg << tag << " Only float types supported but " << dt << " was provided"; + throw std::invalid_argument(msg.str()); + } +} + +template +void dispatch_inexact_types(Dtype dt, std::string_view tag, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_FLOATS(); + MLX_INTERNAL_DTYPE_SWITCH_CASE(complex64, complex64_t); + default: + std::ostringstream msg; + msg << tag << " Only inexact (float/complex) types supported but " << dt + << " was provided"; + throw std::invalid_argument(msg.str()); + } +} + +template +void dispatch_int_float_types(Dtype dt, std::string_view tag, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_INTS(); + MLX_INTERNAL_DTYPE_SWITCH_FLOATS(); + default: + std::ostringstream msg; + msg << tag << " Only integer and float types supported but " << dt + << " was provided"; + throw std::invalid_argument(msg.str()); + } +} + +template +void dispatch_real_types(Dtype dt, std::string_view tag, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool); + MLX_INTERNAL_DTYPE_SWITCH_INTS(); + MLX_INTERNAL_DTYPE_SWITCH_FLOATS(); + default: + std::ostringstream msg; + msg << tag << " Only real numbers supported but " << dt + << " was provided"; + throw std::invalid_argument(msg.str()); + } +} + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/einsum.h b/Source/Cxxmlx/include/mlx/einsum.h new file mode 100644 index 00000000..05588f88 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/einsum.h @@ -0,0 +1,23 @@ +// Copyright © 2024 Apple Inc. +#pragma once + +#include +#include +#include + +#include "mlx/api.h" +#include "mlx/array.h" +#include "mlx/utils.h" + +namespace mlx::core { + +MLX_API std::pair>, std::string> einsum_path( + const std::string& subscripts, + const std::vector& operands); + +MLX_API array einsum( + const std::string& subscripts, + const std::vector& operands, + StreamOrDevice s = {}); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/event.h b/Source/Cxxmlx/include/mlx/event.h new file mode 100644 index 00000000..66a6a75d --- /dev/null +++ b/Source/Cxxmlx/include/mlx/event.h @@ -0,0 +1,58 @@ +// Copyright © 2024 Apple Inc. +#pragma once + +#include +#include +#include + +#include "mlx/stream.h" + +namespace mlx::core { + +class Event { + public: + Event() {}; + explicit Event(Stream stream); + + // Wait for the event to be signaled at its current value + void wait(); + + // Wait in the given stream for the event to be signaled at its current value + void wait(Stream stream); + + // Signal the event at its current value in the given stream + void signal(Stream stream); + + // Check if the event has been signaled at its current value + bool is_signaled() const; + + // Check if the event is valid + bool valid() const { + return event_ != nullptr; + } + + uint64_t value() const { + return value_; + } + + void set_value(uint64_t v) { + value_ = v; + } + + const Stream& stream() const { + if (!valid()) { + throw std::runtime_error( + "[Event::stream] Cannot access stream on invalid event."); + } + return stream_; + } + + private: + // Default constructed stream should never be used + // since the event is not yet valid + Stream stream_{0, Device::cpu}; + std::shared_ptr event_{nullptr}; + uint64_t value_{0}; +}; + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/export.h b/Source/Cxxmlx/include/mlx/export.h new file mode 100644 index 00000000..5532f7c8 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/export.h @@ -0,0 +1,137 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include +#include +#include +#include +#include "mlx/api.h" +#include "mlx/array.h" + +namespace mlx::core { + +using Args = std::vector; +using Kwargs = std::unordered_map; + +// Possible types for a Primitive's state +using StateT = std::variant< + bool, + int, + size_t, + float, + double, + Dtype, + Shape, + Strides, + std::vector, + std::vector, + std::vector>, + std::vector>, + std::optional, + std::string>; + +using ExportCallbackInput = std::unordered_map< + std::string, + std::variant< + std::vector>, + std::vector>, + std::vector>, + std::vector, + std::string>>; +using ExportCallback = std::function; + +struct FunctionExporter; + +/** + * Make an exporter to save multiple traces of a given function to + * the same file. + */ +MLX_API FunctionExporter exporter( + const std::string& file, + const std::function(const Args&)>& fun, + bool shapeless = false); + +MLX_API FunctionExporter exporter( + const std::string& file, + const std::function(const Kwargs&)>& fun, + bool shapeless = false); + +MLX_API FunctionExporter exporter( + const std::string& path, + const std::function(const Args&, const Kwargs&)>& fun, + bool shapeless = false); + +/** + * Export a function to a file. + */ +MLX_API void export_function( + const std::string& file, + const std::function(const Args&)>& fun, + const Args& args, + bool shapeless = false); + +MLX_API void export_function( + const std::string& file, + const std::function(const Kwargs&)>& fun, + const Kwargs& kwargs, + bool shapeless = false); + +MLX_API void export_function( + const std::string& file, + const std::function(const Args&, const Kwargs&)>& fun, + const Args& args, + const Kwargs& kwargs, + bool shapeless = false); + +struct ImportedFunction; + +/** + * Import a function from a file. + */ +MLX_API ImportedFunction import_function(const std::string& file); + +/** + * Make an exporter to export multiple traces of a given function with the same + * callback. + */ +MLX_API FunctionExporter exporter( + const ExportCallback& callback, + const std::function(const Args&)>& fun, + bool shapeless = false); + +MLX_API FunctionExporter exporter( + const ExportCallback& callback, + const std::function(const Kwargs&)>& fun, + bool shapeless = false); + +MLX_API FunctionExporter exporter( + const ExportCallback& callback, + const std::function(const Args&, const Kwargs&)>& fun, + bool shapeless = false); + +/** + * Export a function with a callback. + */ +MLX_API void export_function( + const ExportCallback& callback, + const std::function(const Args&)>& fun, + const Args& args, + bool shapeless = false); + +MLX_API void export_function( + const ExportCallback& callback, + const std::function(const Kwargs&)>& fun, + const Kwargs& kwargs, + bool shapeless = false); + +MLX_API void export_function( + const ExportCallback& callback, + const std::function(const Args&, const Kwargs&)>& fun, + const Args& args, + const Kwargs& kwargs, + bool shapeless = false); + +} // namespace mlx::core + +#include "mlx/export_impl.h" diff --git a/Source/Cxxmlx/include/mlx/export_impl.h b/Source/Cxxmlx/include/mlx/export_impl.h new file mode 100644 index 00000000..467a5f0d --- /dev/null +++ b/Source/Cxxmlx/include/mlx/export_impl.h @@ -0,0 +1,99 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/api.h" +#include "mlx/io/load.h" + +#pragma once + +namespace mlx::core { + +struct FunctionTable; + +struct MLX_API FunctionExporter { + void operator()(const std::initializer_list& args) { + this->operator()(Args(args)); + } + void operator()(const Args& args); + void operator()(const Kwargs& kwargs); + void operator()(const Args& args, const Kwargs& kwargs); + + void close(); + + FunctionExporter(const FunctionExporter&) = delete; + FunctionExporter& operator=(const FunctionExporter&) = delete; + FunctionExporter(FunctionExporter&& other) = default; + + private: + friend MLX_API FunctionExporter exporter( + const std::string&, + const std::function(const Args&)>&, + bool shapeless); + + friend MLX_API FunctionExporter exporter( + const std::string&, + const std::function(const Kwargs&)>&, + bool shapeless); + + friend MLX_API FunctionExporter exporter( + const std::string&, + const std::function(const Args&, const Kwargs&)>&, + bool shapeless); + + friend MLX_API FunctionExporter exporter( + const ExportCallback&, + const std::function(const Args&)>&, + bool shapeless); + + friend MLX_API FunctionExporter exporter( + const ExportCallback&, + const std::function(const Kwargs&)>&, + bool shapeless); + + friend MLX_API FunctionExporter exporter( + const ExportCallback&, + const std::function(const Args&, const Kwargs&)>&, + bool shapeless); + + FunctionExporter( + const std::string& file, + std::function(const Args&, const Kwargs&)> fun, + bool shapeless); + + FunctionExporter( + const ExportCallback& callback, + std::function(const Args&, const Kwargs&)> fun, + bool shapeless); + + io::FileWriter os; + ExportCallback callback; + std::function(const Args&, const Kwargs& kwargs)> fun; + void export_function(const Args& args, const Kwargs& kwargs); + void export_with_callback( + const std::vector& inputs, + const std::vector& outputs, + const std::vector& tape, + const std::vector& kwarg_keys); + std::unordered_map constants; + int count{0}; + bool closed{false}; + std::shared_ptr ftable; +}; + +struct MLX_API ImportedFunction { + std::vector operator()( + const std::initializer_list& args) const { + return this->operator()(Args(args)); + } + std::vector operator()(const Args& args) const; + std::vector operator()(const Kwargs& kwargs) const; + std::vector operator()(const Args& args, const Kwargs& kwargs) const; + + private: + ImportedFunction(const std::string& file); + friend MLX_API ImportedFunction import_function(const std::string&); + ImportedFunction(); + + std::shared_ptr ftable; +}; + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/fast.h b/Source/Cxxmlx/include/mlx/fast.h new file mode 100644 index 00000000..1183aba8 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/fast.h @@ -0,0 +1,103 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/api.h" +#include "mlx/utils.h" + +namespace mlx::core::fast { + +MLX_API array rms_norm( + const array& x, + const std::optional& weight, + float eps, + StreamOrDevice s = {}); + +MLX_API array layer_norm( + const array& x, + const std::optional& weight, + const std::optional& bias, + float eps, + StreamOrDevice s = {}); + +MLX_API array rope( + const array& x, + int dims, + bool traditional, + std::optional base, + float scale, + int offset, + const std::optional& freqs = std::nullopt, + StreamOrDevice s = {}); + +MLX_API array rope( + const array& x, + int dims, + bool traditional, + std::optional base, + float scale, + const array& offset, + const std::optional& freqs = std::nullopt, + StreamOrDevice s = {}); + +/** Computes: O = softmax(Q @ K.T) @ V **/ +MLX_API array scaled_dot_product_attention( + const array& queries, + const array& keys, + const array& values, + const float scale, + const std::string& mask_mode = "", + std::optional mask_arr = {}, + const std::optional& sinks = {}, + StreamOrDevice s = {}); + +using TemplateArg = std::variant; +using ScalarArg = std::variant; + +using CustomKernelFunction = std::function( + const std::vector&, + const std::vector&, + const std::vector&, + std::tuple, + std::tuple, + std::vector>, + std::optional, + bool, + StreamOrDevice)>; + +MLX_API CustomKernelFunction metal_kernel( + const std::string& name, + const std::vector& input_names, + const std::vector& output_names, + const std::string& source, + const std::string& header = "", + bool ensure_row_contiguous = true, + bool atomic_outputs = false); + +MLX_API CustomKernelFunction cuda_kernel( + const std::string& name, + const std::vector& input_names, + const std::vector& output_names, + const std::string& source, + const std::string& header = "", + bool ensure_row_contiguous = true, + int shared_memory = 0); + +MLX_API std::vector precompiled_cuda_kernel( + const std::string& name, + const std::string& compiled_source, + const std::vector& inputs, + const std::vector& output_shapes, + const std::vector& output_dtypes, + const std::vector& scalars, + std::tuple grid, + std::tuple threadgroup, + int shared_memory = 0, + std::optional init_value = std::nullopt, + bool ensure_row_contiguous = false, + StreamOrDevice s = {}); + +} // namespace mlx::core::fast diff --git a/Source/Cxxmlx/include/mlx/fast_primitives.h b/Source/Cxxmlx/include/mlx/fast_primitives.h new file mode 100644 index 00000000..44348308 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/fast_primitives.h @@ -0,0 +1,427 @@ +// Copyright © 2024 Apple Inc. + +#include +#include + +#include "mlx/primitives.h" + +namespace mlx::core::fast { + +// Custom primitive accepts a fallback function which it uses for +// transformations. Transformations are virtual so that derived classes may +// override the default behavior. +class Custom : public Primitive { + public: + explicit Custom( + Stream stream, + std::function(std::vector)> fallback) + : Primitive(stream), fallback_(std::move(fallback)) {} + + virtual std::pair, std::vector> vmap( + const std::vector& inputs, + const std::vector& axes) override; + + virtual std::vector jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) override; + + virtual std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + protected: + std::function(std::vector)> fallback_; +}; + +class RMSNorm : public Custom { + public: + RMSNorm( + Stream stream, + std::function(std::vector)> fallback, + float eps) + : Custom(stream, std::move(fallback)), eps_(eps) {} + + static bool use_fallback(Stream stream); + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("NYI"); + } + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_NAME(RMSNorm) + bool is_equivalent(const Primitive& other) const override; + DEFINE_INPUT_OUTPUT_SHAPE() + + auto state() const { + return std::make_pair(nullptr, eps_); + } + + private: + float eps_; +}; + +class RMSNormVJP : public Custom { + public: + RMSNormVJP( + Stream stream, + std::function(std::vector)> fallback, + float eps) + : Custom(stream, std::move(fallback)), eps_(eps) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("NYI"); + } + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_NAME(RMSNormVJP) + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_pair(nullptr, eps_); + } + + private: + float eps_; +}; + +class LayerNorm : public Custom { + public: + LayerNorm( + Stream stream, + std::function(std::vector)> fallback, + float eps) + : Custom(stream, std::move(fallback)), eps_(eps) {} + + static bool use_fallback(Stream s); + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("NYI"); + } + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_NAME(LayerNorm) + bool is_equivalent(const Primitive& other) const override; + DEFINE_INPUT_OUTPUT_SHAPE() + auto state() const { + return std::make_pair(nullptr, eps_); + } + + private: + float eps_; +}; + +class LayerNormVJP : public Custom { + public: + LayerNormVJP( + Stream stream, + std::function(std::vector)> fallback, + float eps) + : Custom(stream, std::move(fallback)), eps_(eps) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("NYI"); + } + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_NAME(LayerNormVJP) + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_pair(nullptr, eps_); + } + + private: + float eps_; +}; + +class RoPE : public Custom { + public: + RoPE( + Stream stream, + std::function(std::vector)> fallback, + int dims, + bool traditional, + float base, + float scale, + bool forward) + : Custom(stream, std::move(fallback)), + dims_(dims), + traditional_(traditional), + base_(base), + scale_(scale), + forward_(forward) {} + + static bool use_fallback(Stream s); + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("NYI"); + } + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_NAME(RoPE) + bool is_equivalent(const Primitive& other) const override; + DEFINE_INPUT_OUTPUT_SHAPE() + auto state() const { + return std::make_tuple( + nullptr, dims_, traditional_, base_, scale_, forward_); + } + + private: + int dims_; + bool traditional_; + float base_; + float scale_; + bool forward_; +}; + +class ScaledDotProductAttention : public Custom { + public: + ScaledDotProductAttention( + Stream stream, + std::function(std::vector)> fallback, + float scale, + bool do_causal, + bool has_sinks, + bool output_logsumexp) + : Custom(stream, std::move(fallback)), + scale_(scale), + do_causal_(do_causal), + has_sinks_(has_sinks), + output_logsumexp_(output_logsumexp) {} + + static bool use_fallback( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal, + bool is_training, + bool output_logsumexp, + Stream s); + static bool supports_bool_mask(); + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("NYI"); + } + + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + bool is_equivalent(const Primitive& other) const override; + + DEFINE_NAME(ScaledDotProductAttention); + DEFINE_INPUT_OUTPUT_SHAPE() + auto state() const { + return std::make_tuple( + nullptr, scale_, do_causal_, has_sinks_, output_logsumexp_); + } + + private: + float scale_; + bool do_causal_; + bool has_sinks_; + bool output_logsumexp_; +}; + +class ScaledDotProductAttentionVJP : public Custom { + public: + ScaledDotProductAttentionVJP( + Stream stream, + std::function(std::vector)> fallback, + float scale, + bool do_causal, + bool has_sinks) + : Custom(stream, std::move(fallback)), + scale_(scale), + do_causal_(do_causal), + has_sinks_(has_sinks) {} + + static bool use_fallback(const array& q, Stream s); + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("NYI"); + } + + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_NAME(ScaledDotProductAttentionVJP); + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_tuple(nullptr, scale_, do_causal_, has_sinks_); + } + + private: + float scale_; + bool do_causal_; + bool has_sinks_; +}; + +class ConvertFP8 : public Primitive { + public: + explicit ConvertFP8(Stream stream, bool to_fp8) + : Primitive(stream), to_fp8_(to_fp8) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + const char* name() const override { + if (to_fp8_) { + return "ToFP8"; + } else { + return "FromFP8"; + } + } + bool state() const { + return to_fp8_; + }; + + bool is_equivalent(const Primitive& other) const override; + DEFINE_INPUT_OUTPUT_SHAPE(); + + private: + bool to_fp8_; +}; + +class Quantize : public Custom { + public: + explicit Quantize( + Stream stream, + std::function(std::vector)> fallback, + int group_size, + int bits, + QuantizationMode mode, + bool dequantize) + : Custom(stream, std::move(fallback)), + group_size_(group_size), + bits_(bits), + mode_(mode), + dequantize_(dequantize) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_NAME(Quantize); + + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return std::make_tuple(nullptr, group_size_, bits_, mode_, dequantize_); + } + + private: + int group_size_; + int bits_; + QuantizationMode mode_; + bool dequantize_; +}; + +using ScalarArg = std::variant; + +class CustomKernel : public Primitive { + public: + CustomKernel( + Stream stream, + std::string name, + std::string source, + std::tuple grid, + std::tuple threadgroup, + std::vector> shape_infos, + bool ensure_row_contiguous, + std::optional init_value, + std::vector scalar_arguments, + bool is_precompiled, + int shared_memory) + : Primitive(stream), + name_(std::move(name)), + source_(std::move(source)), + grid_(grid), + threadgroup_(threadgroup), + shape_infos_(std::move(shape_infos)), + ensure_row_contiguous_(ensure_row_contiguous), + init_value_(init_value), + scalar_arguments_(std::move(scalar_arguments)), + is_precompiled_(is_precompiled), + shared_memory_(shared_memory) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("Custom kernels only run on GPU."); + } + + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_NAME(CustomKernel); + auto state() const { + return std::make_tuple( + name_, + source_, + grid_, + threadgroup_, + shape_infos_, + ensure_row_contiguous_, + init_value_, + scalar_arguments_, + is_precompiled_, + shared_memory_); + } + + private: + std::string name_; + std::string source_; + std::tuple grid_; + std::tuple threadgroup_; + std::vector> shape_infos_; + bool ensure_row_contiguous_; + std::optional init_value_; + std::vector scalar_arguments_; + bool is_precompiled_; + int shared_memory_; +}; + +} // namespace mlx::core::fast diff --git a/Source/Cxxmlx/include/mlx/fence.h b/Source/Cxxmlx/include/mlx/fence.h new file mode 100644 index 00000000..0ececdb6 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/fence.h @@ -0,0 +1,39 @@ +// Copyright © 2024 Apple Inc. + +#include + +#include "mlx/array.h" + +namespace mlx::core { + +/* A fence to be used for synchronizing work between streams. + * + * Calls to `wait` wait in the given stream until all previous calls to update + * are complete on their given stream. + * + * The array passed to `update` is computed and visible after the call to + * `wait` returns. The array passed to `wait` will not be read until all + * previous calls to `update` have completed. + * + * Note, calls to `update` should always be from the same thread or explicitly + * synchronized so that they occur in sequence. Calls to `wait` can be on any + * thread. + * + * For the Metal back-end the fence supports slow (default) and fast mode. + * Fast mode requires setting the environment variable + * `MLX_METAL_FAST_SYNCH=1`. Fast mode also requires Metal 3.2+ (macOS 15+, + * iOS 18+). + */ +class Fence { + public: + Fence() {}; + explicit Fence(Stream stream); + + void update(Stream stream, const array& x, bool cross_device); + void wait(Stream stream, const array& x); + + private: + std::shared_ptr fence_{nullptr}; +}; + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/fft.h b/Source/Cxxmlx/include/mlx/fft.h new file mode 100644 index 00000000..9abf2b18 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/fft.h @@ -0,0 +1,159 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +#include "array.h" +#include "device.h" +#include "mlx/api.h" +#include "utils.h" + +namespace mlx::core::fft { + +/** Compute the n-dimensional Fourier Transform. */ +MLX_API array fftn( + const array& a, + const Shape& n, + const std::vector& axes, + StreamOrDevice s = {}); +MLX_API array +fftn(const array& a, const std::vector& axes, StreamOrDevice s = {}); +MLX_API array fftn(const array& a, StreamOrDevice s = {}); + +/** Compute the n-dimensional inverse Fourier Transform. */ +MLX_API array ifftn( + const array& a, + const Shape& n, + const std::vector& axes, + StreamOrDevice s = {}); +MLX_API array +ifftn(const array& a, const std::vector& axes, StreamOrDevice s = {}); +MLX_API array ifftn(const array& a, StreamOrDevice s = {}); + +/** Compute the one-dimensional Fourier Transform. */ +inline array fft(const array& a, int n, int axis, StreamOrDevice s = {}) { + return fftn(a, {n}, {axis}, s); +} +inline array fft(const array& a, int axis = -1, StreamOrDevice s = {}) { + return fftn(a, {axis}, s); +} + +/** Compute the one-dimensional inverse Fourier Transform. */ +inline array ifft(const array& a, int n, int axis, StreamOrDevice s = {}) { + return ifftn(a, {n}, {axis}, s); +} +inline array ifft(const array& a, int axis = -1, StreamOrDevice s = {}) { + return ifftn(a, {axis}, s); +} + +/** Compute the two-dimensional Fourier Transform. */ +inline array fft2( + const array& a, + const Shape& n, + const std::vector& axes, + StreamOrDevice s = {}) { + return fftn(a, n, axes, s); +} +inline array fft2( + const array& a, + const std::vector& axes = {-2, -1}, + StreamOrDevice s = {}) { + return fftn(a, axes, s); +} + +/** Compute the two-dimensional inverse Fourier Transform. */ +inline array ifft2( + const array& a, + const Shape& n, + const std::vector& axes, + StreamOrDevice s = {}) { + return ifftn(a, n, axes, s); +} +inline array ifft2( + const array& a, + const std::vector& axes = {-2, -1}, + StreamOrDevice s = {}) { + return ifftn(a, axes, s); +} + +/** Compute the n-dimensional Fourier Transform on a real input. */ +MLX_API array rfftn( + const array& a, + const Shape& n, + const std::vector& axes, + StreamOrDevice s = {}); +MLX_API array +rfftn(const array& a, const std::vector& axes, StreamOrDevice s = {}); +MLX_API array rfftn(const array& a, StreamOrDevice s = {}); + +/** Compute the n-dimensional inverse of `rfftn`. */ +MLX_API array irfftn( + const array& a, + const Shape& n, + const std::vector& axes, + StreamOrDevice s = {}); +MLX_API array +irfftn(const array& a, const std::vector& axes, StreamOrDevice s = {}); +MLX_API array irfftn(const array& a, StreamOrDevice s = {}); + +/** Compute the one-dimensional Fourier Transform on a real input. */ +inline array rfft(const array& a, int n, int axis, StreamOrDevice s = {}) { + return rfftn(a, {n}, {axis}, s); +} +inline array rfft(const array& a, int axis = -1, StreamOrDevice s = {}) { + return rfftn(a, {axis}, s); +} +/** Compute the one-dimensional inverse of `rfft`. */ +inline array irfft(const array& a, int n, int axis, StreamOrDevice s = {}) { + return irfftn(a, {n}, {axis}, s); +} +inline array irfft(const array& a, int axis = -1, StreamOrDevice s = {}) { + return irfftn(a, {axis}, s); +} + +/** Compute the two-dimensional Fourier Transform on a real input. */ +inline array rfft2( + const array& a, + const Shape& n, + const std::vector& axes, + StreamOrDevice s = {}) { + return rfftn(a, n, axes, s); +} +inline array rfft2( + const array& a, + const std::vector& axes = {-2, -1}, + StreamOrDevice s = {}) { + return rfftn(a, axes, s); +} + +/** Compute the two-dimensional inverse of `rfft2`. */ +inline array irfft2( + const array& a, + const Shape& n, + const std::vector& axes, + StreamOrDevice s = {}) { + return irfftn(a, n, axes, s); +} +inline array irfft2( + const array& a, + const std::vector& axes = {-2, -1}, + StreamOrDevice s = {}) { + return irfftn(a, axes, s); +} +/** Shift the zero-frequency component to the center of the spectrum. */ +MLX_API array fftshift(const array& a, StreamOrDevice s = {}); + +/** Shift the zero-frequency component to the center of the spectrum along + * specified axes. */ +MLX_API array +fftshift(const array& a, const std::vector& axes, StreamOrDevice s = {}); + +/** The inverse of fftshift. */ +MLX_API array ifftshift(const array& a, StreamOrDevice s = {}); + +/** The inverse of fftshift along specified axes. */ +MLX_API array +ifftshift(const array& a, const std::vector& axes, StreamOrDevice s = {}); + +} // namespace mlx::core::fft diff --git a/Source/Cxxmlx/include/mlx/graph_utils.h b/Source/Cxxmlx/include/mlx/graph_utils.h new file mode 100644 index 00000000..54297c2c --- /dev/null +++ b/Source/Cxxmlx/include/mlx/graph_utils.h @@ -0,0 +1,67 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +#include "mlx/api.h" +#include "mlx/array.h" + +namespace mlx::core { + +struct MLX_API NodeNamer { + std::unordered_map names; + + const std::string& get_name(const array& x); + void set_name(const array& x, std::string n); +}; + +MLX_API void print_graph( + std::ostream& os, + NodeNamer namer, + const std::vector& outputs); + +inline void print_graph(std::ostream& os, const std::vector& outputs) { + print_graph(os, NodeNamer{}, outputs); +} + +template > +inline void print_graph(std::ostream& os, Arrays&&... outputs) { + print_graph( + os, NodeNamer{}, std::vector{std::forward(outputs)...}); +} + +template > +inline void +print_graph(std::ostream& os, NodeNamer namer, Arrays&&... outputs) { + print_graph( + os, + std::move(namer), + std::vector{std::forward(outputs)...}); +} + +MLX_API void export_to_dot( + std::ostream& os, + NodeNamer namer, + const std::vector& outputs); + +inline void export_to_dot(std::ostream& os, const std::vector& outputs) { + export_to_dot(os, NodeNamer{}, outputs); +} + +template > +inline void export_to_dot(std::ostream& os, Arrays&&... outputs) { + export_to_dot( + os, NodeNamer{}, std::vector{std::forward(outputs)...}); +} + +template > +inline void +export_to_dot(std::ostream& os, NodeNamer namer, Arrays&&... outputs) { + export_to_dot( + os, + std::move(namer), + std::vector{std::forward(outputs)...}); +} + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/io.h b/Source/Cxxmlx/include/mlx/io.h new file mode 100644 index 00000000..760f2985 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/io.h @@ -0,0 +1,61 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/api.h" +#include "mlx/array.h" +#include "mlx/io/load.h" +#include "mlx/stream.h" +#include "mlx/utils.h" + +namespace mlx::core { +using GGUFMetaData = + std::variant>; +using GGUFLoad = std::pair< + std::unordered_map, + std::unordered_map>; +using SafetensorsLoad = std::pair< + std::unordered_map, + std::unordered_map>; + +/** Save array to out stream in .npy format */ +MLX_API void save(std::shared_ptr out_stream, array a); + +/** Save array to file in .npy format */ +MLX_API void save(std::string file, array a); + +/** Load array from reader in .npy format */ +MLX_API array +load(std::shared_ptr in_stream, StreamOrDevice s = {}); + +/** Load array from file in .npy format */ +MLX_API array load(std::string file, StreamOrDevice s = {}); + +/** Load array map from .safetensors file format */ +MLX_API SafetensorsLoad +load_safetensors(std::shared_ptr in_stream, StreamOrDevice s = {}); +MLX_API SafetensorsLoad +load_safetensors(const std::string& file, StreamOrDevice s = {}); + +MLX_API void save_safetensors( + std::shared_ptr in_stream, + std::unordered_map, + std::unordered_map metadata = {}); +MLX_API void save_safetensors( + std::string file, + std::unordered_map, + std::unordered_map metadata = {}); + +/** Load array map and metadata from .gguf file format */ + +MLX_API GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {}); + +MLX_API void save_gguf( + std::string file, + std::unordered_map array_map, + std::unordered_map meta_data = {}); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/io/gguf.h b/Source/Cxxmlx/include/mlx/io/gguf.h new file mode 100644 index 00000000..fa5bc458 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/io/gguf.h @@ -0,0 +1,20 @@ +// Copyright © 2023-2024 Apple Inc. +#pragma once + +#include "mlx/io.h" +#include "mlx/primitives.h" +#include "mlx/transforms.h" +#include "mlx/utils.h" + +extern "C" { +#include +} + +namespace mlx::core { + +Shape get_shape(const gguf_tensor& tensor); +void gguf_load_quantized( + std::unordered_map& a, + const gguf_tensor& tensor); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/io/load.h b/Source/Cxxmlx/include/mlx/io/load.h new file mode 100644 index 00000000..0efcb367 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/io/load.h @@ -0,0 +1,175 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include +#include + +#include +#ifdef _MSC_VER +#include +#else +#include +#include +#endif + +#include "mlx/threadpool.h" + +// Strictly we need to operate on files in binary mode (to avoid \r getting +// automatically inserted), but every modern system except for Windows no +// longer differentiates between binary and text files and for them define +// the flag as no-op. +#ifndef O_BINARY +#define O_BINARY 0 +#endif + +namespace mlx::core { + +namespace io { + +ThreadPool& thread_pool(); + +class Reader { + public: + virtual bool is_open() const = 0; + virtual bool good() const = 0; + virtual size_t tell() = 0; // tellp is non-const in iostream + virtual void seek( + int64_t off, + std::ios_base::seekdir way = std::ios_base::beg) = 0; + virtual void read(char* data, size_t n) = 0; + virtual void read(char* data, size_t n, size_t offset) = 0; + virtual std::string label() const = 0; + virtual ~Reader() = default; +}; + +class Writer { + public: + virtual bool is_open() const = 0; + virtual bool good() const = 0; + virtual size_t tell() = 0; + virtual void seek( + int64_t off, + std::ios_base::seekdir way = std::ios_base::beg) = 0; + virtual void write(const char* data, size_t n) = 0; + virtual std::string label() const = 0; + virtual ~Writer() = default; +}; + +class ParallelFileReader : public Reader { + public: + explicit ParallelFileReader(std::string file_path) + : fd_(open(file_path.c_str(), O_RDONLY | O_BINARY)), + label_(std::move(file_path)) {} + + ~ParallelFileReader() override { + close(fd_); + } + + bool is_open() const override { + return fd_ > 0; + } + + bool good() const override { + return is_open(); + } + + size_t tell() override { + return lseek(fd_, 0, SEEK_CUR); + } + + // Warning: do not use this function from multiple threads as + // it advances the file descriptor + void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) + override { + if (way == std::ios_base::beg) { + lseek(fd_, off, 0); + } else { + lseek(fd_, off, SEEK_CUR); + } + } + + // Warning: do not use this function from multiple threads as + // it advances the file descriptor + void read(char* data, size_t n) override; + + void read(char* data, size_t n, size_t offset) override; + + std::string label() const override { + return "file " + label_; + } + + private: + static constexpr size_t batch_size_ = 1 << 25; + static ThreadPool& thread_pool(); + int fd_; + std::string label_; +}; + +class FileWriter : public Writer { + public: + explicit FileWriter() {} + explicit FileWriter(std::string file_path) + : fd_(open( + file_path.c_str(), + O_CREAT | O_WRONLY | O_TRUNC | O_BINARY, + 0644)), + label_(std::move(file_path)) {} + + FileWriter(const FileWriter&) = delete; + FileWriter& operator=(const FileWriter&) = delete; + FileWriter(FileWriter&& other) { + std::swap(fd_, other.fd_); + } + + ~FileWriter() override { + if (fd_ != 0) { + close(fd_); + } + } + + bool is_open() const override { + return fd_ >= 0; + } + + bool good() const override { + return is_open(); + } + + size_t tell() override { + return lseek(fd_, 0, SEEK_CUR); + } + + void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) + override { + if (way == std::ios_base::beg) { + lseek(fd_, off, 0); + } else { + lseek(fd_, off, SEEK_CUR); + } + } + + void write(const char* data, size_t n) override { + while (n != 0) { + auto m = ::write(fd_, data, std::min(n, static_cast(INT32_MAX))); + if (m <= 0) { + std::ostringstream msg; + msg << "[write] Unable to write " << n << " bytes to file."; + throw std::runtime_error(msg.str()); + } + data += m; + n -= m; + } + } + + std::string label() const override { + return "file " + label_; + } + + private: + int fd_{0}; + std::string label_; +}; + +} // namespace io +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/linalg.h b/Source/Cxxmlx/include/mlx/linalg.h new file mode 100644 index 00000000..fe3f83c2 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/linalg.h @@ -0,0 +1,115 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +#include "mlx/api.h" +#include "mlx/array.h" +#include "mlx/device.h" +#include "mlx/ops.h" +#include "mlx/stream.h" + +namespace mlx::core::linalg { + +/** + * Compute vector or matrix norms. + * + * - If axis and ord are both unspecified, computes the 2-norm of flatten(x). + * - If axis is not provided but ord is, then x must be either 1D or 2D. + * - If axis is provided, but ord is not, then the 2-norm (or Frobenius norm + * for matrices) is computed along the given axes. At most 2 axes can be + * specified. + * - If both axis and ord are provided, then the corresponding matrix or vector + * norm is computed. At most 2 axes can be specified. + */ +MLX_API array norm( + const array& a, + const double ord, + const std::optional>& axis = std::nullopt, + bool keepdims = false, + StreamOrDevice s = {}); +inline array norm( + const array& a, + const double ord, + int axis, + bool keepdims = false, + StreamOrDevice s = {}) { + return norm(a, ord, std::vector{axis}, keepdims, s); +} +MLX_API array norm( + const array& a, + const std::string& ord, + const std::optional>& axis = std::nullopt, + bool keepdims = false, + StreamOrDevice s = {}); +inline array norm( + const array& a, + const std::string& ord, + int axis, + bool keepdims = false, + StreamOrDevice s = {}) { + return norm(a, ord, std::vector{axis}, keepdims, s); +} +MLX_API array norm( + const array& a, + const std::optional>& axis = std::nullopt, + bool keepdims = false, + StreamOrDevice s = {}); +inline array +norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) { + return norm(a, std::vector{axis}, keepdims, s); +} + +MLX_API std::pair qr(const array& a, StreamOrDevice s = {}); + +MLX_API std::vector +svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */); +inline std::vector svd(const array& a, StreamOrDevice s = {}) { + return svd(a, true, s); +} + +MLX_API array inv(const array& a, StreamOrDevice s = {}); + +MLX_API array +tri_inv(const array& a, bool upper = false, StreamOrDevice s = {}); + +MLX_API array +cholesky(const array& a, bool upper = false, StreamOrDevice s = {}); + +MLX_API array pinv(const array& a, StreamOrDevice s = {}); + +MLX_API array +cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {}); + +MLX_API std::vector lu(const array& a, StreamOrDevice s = {}); + +MLX_API std::pair lu_factor( + const array& a, + StreamOrDevice s = {}); + +MLX_API array solve(const array& a, const array& b, StreamOrDevice s = {}); + +MLX_API array solve_triangular( + const array& a, + const array& b, + bool upper = false, + StreamOrDevice s = {}); + +/** + * Compute the cross product of two arrays along the given axis. + */ +MLX_API array +cross(const array& a, const array& b, int axis = -1, StreamOrDevice s = {}); + +MLX_API std::pair eig(const array& a, StreamOrDevice s = {}); + +MLX_API array eigvals(const array& a, StreamOrDevice s = {}); + +MLX_API array +eigvalsh(const array& a, std::string UPLO = "L", StreamOrDevice s = {}); + +MLX_API std::pair +eigh(const array& a, std::string UPLO = "L", StreamOrDevice s = {}); + +} // namespace mlx::core::linalg diff --git a/Source/Cxxmlx/include/mlx/memory.h b/Source/Cxxmlx/include/mlx/memory.h new file mode 100644 index 00000000..f4eabc99 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/memory.h @@ -0,0 +1,80 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +#include "mlx/api.h" + +namespace mlx::core { + +/* Get the actively used memory in bytes. + * + * Note, this will not always match memory use reported by the system because + * it does not include cached memory buffers. + * */ +MLX_API size_t get_active_memory(); + +/* Get the peak amount of used memory in bytes. + * + * The maximum memory used recorded from the beginning of the program + * execution or since the last call to reset_peak_memory. + * */ +MLX_API size_t get_peak_memory(); + +/* Reset the peak memory to zero. + * */ +MLX_API void reset_peak_memory(); + +/* Get the cache size in bytes. + * + * The cache includes memory not currently used that has not been returned + * to the system allocator. + * */ +MLX_API size_t get_cache_memory(); + +/* Set the memory limit. + * The memory limit is a guideline for the maximum amount of memory to use + * during graph evaluation. If the memory limit is exceeded and there is no + * more RAM (including swap when available) allocations will result in an + * exception. + * + * When Metal is available the memory limit defaults to 1.5 times the maximum + * recommended working set size reported by the device. + * + * Returns the previous memory limit. + * */ +MLX_API size_t set_memory_limit(size_t limit); + +/* Get the current memory limit. */ +MLX_API size_t get_memory_limit(); + +/* Set the cache limit. + * If using more than the given limit, free memory will be reclaimed + * from the cache on the next allocation. To disable the cache, + * set the limit to 0. + * + * The cache limit defaults to the memory limit. + * + * Returns the previous cache limit. + * */ +MLX_API size_t set_cache_limit(size_t limit); + +/* Clear the memory cache. */ +MLX_API void clear_cache(); + +/* Set the wired size limit. + * + * Note, this function is only useful when using the Metal backend with + * macOS 15.0 or higher. + * + * The wired limit is the total size in bytes of memory that will be kept + * resident. The default value is ``0``. + * + * Setting a wired limit larger than system wired limit is an error. + * + * Returns the previous wired limit. + * */ +MLX_API size_t set_wired_limit(size_t limit); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/mlx.h b/Source/Cxxmlx/include/mlx/mlx.h new file mode 100644 index 00000000..eda7333d --- /dev/null +++ b/Source/Cxxmlx/include/mlx/mlx.h @@ -0,0 +1,25 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/cuda.h" +#include "mlx/backend/gpu/device_info.h" +#include "mlx/backend/metal/metal.h" +#include "mlx/compile.h" +#include "mlx/device.h" +#include "mlx/distributed/distributed.h" +#include "mlx/distributed/ops.h" +#include "mlx/einsum.h" +#include "mlx/export.h" +#include "mlx/fast.h" +#include "mlx/fft.h" +#include "mlx/io.h" +#include "mlx/linalg.h" +#include "mlx/memory.h" +#include "mlx/ops.h" +#include "mlx/random.h" +#include "mlx/stream.h" +#include "mlx/transforms.h" +#include "mlx/utils.h" +#include "mlx/version.h" diff --git a/Source/Cxxmlx/include/mlx/ops.h b/Source/Cxxmlx/include/mlx/ops.h new file mode 100644 index 00000000..74032c01 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/ops.h @@ -0,0 +1,1626 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include + +#include "mlx/api.h" +#include "mlx/array.h" +#include "mlx/device.h" +#include "mlx/stream.h" +#include "mlx/utils.h" + +namespace mlx::core { + +/** + * \defgroup ops Core array operations + * @{ + */ + +/** + * A 1D array of numbers starting at `start` (optional), + * stopping at stop, stepping by `step` (optional). */ +MLX_API array arange( + double start, + double stop, + double step, + Dtype dtype, + StreamOrDevice s = {}); +MLX_API array +arange(double start, double stop, double step, StreamOrDevice s = {}); +MLX_API array +arange(double start, double stop, Dtype dtype, StreamOrDevice s = {}); +MLX_API array arange(double start, double stop, StreamOrDevice s = {}); +MLX_API array arange(double stop, Dtype dtype, StreamOrDevice s = {}); +MLX_API array arange(double stop, StreamOrDevice s = {}); + +MLX_API array arange(int start, int stop, int step, StreamOrDevice s = {}); +MLX_API array arange(int start, int stop, StreamOrDevice s = {}); +MLX_API array arange(int stop, StreamOrDevice s = {}); + +/** A 1D array of `num` evenly spaced numbers in the range `[start, stop]` */ +MLX_API array linspace( + double start, + double stop, + int num = 50, + Dtype dtype = float32, + StreamOrDevice s = {}); + +/** Convert an array to the given data type. */ +MLX_API array astype(array a, Dtype dtype, StreamOrDevice s = {}); + +/** Create a view of an array with the given shape and strides. */ +MLX_API array as_strided( + array a, + Shape shape, + Strides strides, + size_t offset, + StreamOrDevice s = {}); + +/** Copy another array. */ +MLX_API array copy(array a, StreamOrDevice s = {}); + +/** Fill an array of the given shape with the given value(s). */ +MLX_API array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s = {}); +MLX_API array full(Shape shape, array vals, StreamOrDevice s = {}); +template +array full(Shape shape, T val, Dtype dtype, StreamOrDevice s = {}) { + return full(std::move(shape), array(val, dtype), to_stream(s)); +} +template +array full(Shape shape, T val, StreamOrDevice s = {}) { + return full(std::move(shape), array(val), to_stream(s)); +} + +MLX_API array +full_like(const array& a, array vals, Dtype dtype, StreamOrDevice s = {}); +MLX_API array full_like(const array& a, array vals, StreamOrDevice s = {}); +template +array full_like(const array& a, T val, Dtype dtype, StreamOrDevice s = {}) { + return full_like(a, array(val, dtype), dtype, to_stream(s)); +} +template +array full_like(const array& a, T val, StreamOrDevice s = {}) { + return full_like(a, array(val, a.dtype()), to_stream(s)); +} + +/** Fill an array of the given shape with zeros. */ +MLX_API array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s = {}); +inline array zeros(const Shape& shape, StreamOrDevice s = {}) { + return zeros(shape, float32, s); +} +MLX_API array zeros_like(const array& a, StreamOrDevice s = {}); + +/** Fill an array of the given shape with ones. */ +MLX_API array ones(const Shape& shape, Dtype dtype, StreamOrDevice s = {}); +inline array ones(const Shape& shape, StreamOrDevice s = {}) { + return ones(shape, float32, s); +} +MLX_API array ones_like(const array& a, StreamOrDevice s = {}); + +/** Fill an array of the given shape (n,m) with ones in the specified diagonal + * k, and zeros everywhere else. */ +MLX_API array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {}); +inline array eye(int n, Dtype dtype, StreamOrDevice s = {}) { + return eye(n, n, 0, dtype, s); +} +inline array eye(int n, int m, StreamOrDevice s = {}) { + return eye(n, m, 0, float32, s); +} +inline array eye(int n, int m, int k, StreamOrDevice s = {}) { + return eye(n, m, k, float32, s); +} +inline array eye(int n, StreamOrDevice s = {}) { + return eye(n, n, 0, float32, s); +} + +/** Create a square matrix of shape (n,n) of zeros, and ones in the major + * diagonal. */ +MLX_API array identity(int n, Dtype dtype, StreamOrDevice s = {}); +inline array identity(int n, StreamOrDevice s = {}) { + return identity(n, float32, s); +} + +MLX_API array tri(int n, int m, int k, Dtype type, StreamOrDevice s = {}); +inline array tri(int n, Dtype type, StreamOrDevice s = {}) { + return tri(n, n, 0, type, s); +} + +MLX_API array tril(array x, int k = 0, StreamOrDevice s = {}); +MLX_API array triu(array x, int k = 0, StreamOrDevice s = {}); + +/** Reshape an array to the given shape. */ +MLX_API array reshape(const array& a, Shape shape, StreamOrDevice s = {}); + +/** Unflatten the axis to the given shape. */ +MLX_API array +unflatten(const array& a, int axis, Shape shape, StreamOrDevice s = {}); + +/** Flatten the dimensions in the range `[start_axis, end_axis]` . */ +MLX_API array flatten( + const array& a, + int start_axis, + int end_axis = -1, + StreamOrDevice s = {}); + +/** Flatten the array to 1D. */ +MLX_API array flatten(const array& a, StreamOrDevice s = {}); + +/** Multiply the array by the Hadamard matrix of corresponding size. */ +MLX_API array hadamard_transform( + const array& a, + std::optional scale = std::nullopt, + StreamOrDevice s = {}); + +/** Remove singleton dimensions at the given axes. */ +MLX_API array +squeeze(const array& a, const std::vector& axes, StreamOrDevice s = {}); + +/** Remove singleton dimensions at the given axis. */ +MLX_API array squeeze(const array& a, int axis, StreamOrDevice s = {}); + +/** Remove all singleton dimensions. */ +MLX_API array squeeze(const array& a, StreamOrDevice s = {}); + +/** Add a singleton dimension at the given axes. */ +MLX_API array expand_dims( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); + +/** Add a singleton dimension at the given axis. */ +MLX_API array expand_dims(const array& a, int axis, StreamOrDevice s = {}); + +/** Slice an array. */ +MLX_API array slice( + const array& a, + Shape start, + Shape stop, + Shape strides, + StreamOrDevice s = {}); +inline array slice( + const array& a, + std::initializer_list start, + Shape stop, + Shape strides, + StreamOrDevice s = {}) { + return slice(a, Shape(start), std::move(stop), std::move(strides), s); +} + +/** Slice an array with a stride of 1 in each dimension. */ +MLX_API array +slice(const array& a, Shape start, Shape stop, StreamOrDevice s = {}); + +/** Slice an array with dynamic starting indices. */ +MLX_API array slice( + const array& a, + const array& start, + std::vector axes, + Shape slice_size, + StreamOrDevice s = {}); + +/** Update a slice from the source array. */ +MLX_API array slice_update( + const array& src, + const array& update, + Shape start, + Shape stop, + Shape strides, + StreamOrDevice s = {}); + +/** Update a slice from the source array with stride 1 in each dimension. */ +MLX_API array slice_update( + const array& src, + const array& update, + Shape start, + Shape stop, + StreamOrDevice s = {}); + +/** Update a slice from the source array with dynamic starting indices. */ +MLX_API array slice_update( + const array& src, + const array& update, + const array& start, + std::vector axes, + StreamOrDevice s = {}); + +/** Split an array into sub-arrays along a given axis. */ +MLX_API std::vector +split(const array& a, int num_splits, int axis, StreamOrDevice s = {}); +MLX_API std::vector +split(const array& a, int num_splits, StreamOrDevice s = {}); +MLX_API std::vector +split(const array& a, const Shape& indices, int axis, StreamOrDevice s = {}); +MLX_API std::vector +split(const array& a, const Shape& indices, StreamOrDevice s = {}); + +/** A vector of coordinate arrays from coordinate vectors. */ +MLX_API std::vector meshgrid( + const std::vector& arrays, + bool sparse = false, + const std::string& indexing = "xy", + StreamOrDevice s = {}); + +/** + * Clip (limit) the values in an array. + */ +MLX_API array clip( + const array& a, + const std::optional& a_min = std::nullopt, + const std::optional& a_max = std::nullopt, + StreamOrDevice s = {}); + +/** Concatenate arrays along a given axis. */ +MLX_API array +concatenate(std::vector arrays, int axis, StreamOrDevice s = {}); +MLX_API array concatenate(std::vector arrays, StreamOrDevice s = {}); + +/** Stack arrays along a new axis. */ +MLX_API array +stack(const std::vector& arrays, int axis, StreamOrDevice s = {}); +MLX_API array stack(const std::vector& arrays, StreamOrDevice s = {}); + +/** Repeat an array along an axis. */ +MLX_API array +repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {}); +MLX_API array repeat(const array& arr, int repeats, StreamOrDevice s = {}); + +MLX_API array +tile(const array& arr, std::vector reps, StreamOrDevice s = {}); + +/** Permutes the dimensions according to the given axes. */ +MLX_API array +transpose(const array& a, std::vector axes, StreamOrDevice s = {}); +inline array transpose( + const array& a, + std::initializer_list axes, + StreamOrDevice s = {}) { + return transpose(a, std::vector(axes), s); +} + +/** Swap two axes of an array. */ +MLX_API array +swapaxes(const array& a, int axis1, int axis2, StreamOrDevice s = {}); + +/** Move an axis of an array. */ +MLX_API array +moveaxis(const array& a, int source, int destination, StreamOrDevice s = {}); + +/** Pad an array with a constant value */ +MLX_API array +pad(const array& a, + const std::vector& axes, + const Shape& low_pad_size, + const Shape& high_pad_size, + const array& pad_value = array(0), + const std::string& mode = "constant", + StreamOrDevice s = {}); + +/** Pad an array with a constant value along all axes */ +MLX_API array +pad(const array& a, + const std::vector>& pad_width, + const array& pad_value = array(0), + const std::string& mode = "constant", + StreamOrDevice s = {}); +MLX_API array +pad(const array& a, + const std::pair& pad_width, + const array& pad_value = array(0), + const std::string& mode = "constant", + StreamOrDevice s = {}); +MLX_API array +pad(const array& a, + int pad_width, + const array& pad_value = array(0), + const std::string& mode = "constant", + StreamOrDevice s = {}); + +/** Permutes the dimensions in reverse order. */ +MLX_API array transpose(const array& a, StreamOrDevice s = {}); + +/** Broadcast an array to a given shape. */ +MLX_API array +broadcast_to(const array& a, const Shape& shape, StreamOrDevice s = {}); + +/** Broadcast a vector of arrays against one another. */ +MLX_API std::vector broadcast_arrays( + const std::vector& inputs, + StreamOrDevice s = {}); + +/** Returns the bool array with (a == b) element-wise. */ +MLX_API array equal(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator==(const array& a, const array& b) { + return equal(a, b); +} +template +array operator==(T a, const array& b) { + return equal(array(a), b); +} +template +array operator==(const array& a, T b) { + return equal(a, array(b)); +} + +/** Returns the bool array with (a != b) element-wise. */ +MLX_API array not_equal(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator!=(const array& a, const array& b) { + return not_equal(a, b); +} +template +array operator!=(T a, const array& b) { + return not_equal(array(a), b); +} +template +array operator!=(const array& a, T b) { + return not_equal(a, array(b)); +} + +/** Returns bool array with (a > b) element-wise. */ +MLX_API array greater(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator>(const array& a, const array& b) { + return greater(a, b); +} +template +array operator>(T a, const array& b) { + return greater(array(a), b); +} +template +array operator>(const array& a, T b) { + return greater(a, array(b)); +} + +/** Returns bool array with (a >= b) element-wise. */ +MLX_API array +greater_equal(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator>=(const array& a, const array& b) { + return greater_equal(a, b); +} +template +array operator>=(T a, const array& b) { + return greater_equal(array(a), b); +} +template +array operator>=(const array& a, T b) { + return greater_equal(a, array(b)); +} + +/** Returns bool array with (a < b) element-wise. */ +MLX_API array less(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator<(const array& a, const array& b) { + return less(a, b); +} +template +array operator<(T a, const array& b) { + return less(array(a), b); +} +template +array operator<(const array& a, T b) { + return less(a, array(b)); +} + +/** Returns bool array with (a <= b) element-wise. */ +MLX_API array less_equal(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator<=(const array& a, const array& b) { + return less_equal(a, b); +} +template +array operator<=(T a, const array& b) { + return less_equal(array(a), b); +} +template +array operator<=(const array& a, T b) { + return less_equal(a, array(b)); +} + +/** True if two arrays have the same shape and elements. */ +MLX_API array array_equal( + const array& a, + const array& b, + bool equal_nan, + StreamOrDevice s = {}); +inline array +array_equal(const array& a, const array& b, StreamOrDevice s = {}) { + return array_equal(a, b, false, s); +} + +MLX_API array isnan(const array& a, StreamOrDevice s = {}); + +MLX_API array isinf(const array& a, StreamOrDevice s = {}); + +MLX_API array isfinite(const array& a, StreamOrDevice s = {}); + +MLX_API array isposinf(const array& a, StreamOrDevice s = {}); + +MLX_API array isneginf(const array& a, StreamOrDevice s = {}); + +/** Select from x or y depending on condition. */ +MLX_API array where( + const array& condition, + const array& x, + const array& y, + StreamOrDevice s = {}); + +/** Replace NaN and infinities with finite numbers. */ +MLX_API array nan_to_num( + const array& a, + float nan = 0.0f, + const std::optional posinf = std::nullopt, + const std::optional neginf = std::nullopt, + StreamOrDevice s = {}); + +/** True if all elements in the array are true (or non-zero). **/ +MLX_API array all(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array all(const array& a, StreamOrDevice s = {}) { + return all(a, false, to_stream(s)); +} + +/** True if the two arrays are equal within the specified tolerance. */ +MLX_API array allclose( + const array& a, + const array& b, + double rtol = 1e-5, + double atol = 1e-8, + bool equal_nan = false, + StreamOrDevice s = {}); + +/** Returns a boolean array where two arrays are element-wise equal within the + * specified tolerance. */ +MLX_API array isclose( + const array& a, + const array& b, + double rtol = 1e-5, + double atol = 1e-8, + bool equal_nan = false, + StreamOrDevice s = {}); + +/** + * Reduces the input along the given axes. An output value is true + * if all the corresponding inputs are true. + **/ +MLX_API array +all(const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** + * Reduces the input along the given axis. An output value is true + * if all the corresponding inputs are true. + **/ +MLX_API array +all(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}); + +/** True if any elements in the array are true (or non-zero). **/ +MLX_API array any(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array any(const array& a, StreamOrDevice s = {}) { + return any(a, false, to_stream(s)); +} + +/** + * Reduces the input along the given axes. An output value is true + * if any of the corresponding inputs are true. + **/ +MLX_API array +any(const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** + * Reduces the input along the given axis. An output value is true + * if any of the corresponding inputs are true. + **/ +MLX_API array +any(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}); + +/** Sums the elements of an array. */ +MLX_API array sum(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array sum(const array& a, StreamOrDevice s = {}) { + return sum(a, false, to_stream(s)); +} + +/** Sums the elements of an array along the given axes. */ +MLX_API array +sum(const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Sums the elements of an array along the given axis. */ +MLX_API array +sum(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}); + +/** Computes the mean of the elements of an array. */ +MLX_API array mean(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array mean(const array& a, StreamOrDevice s = {}) { + return mean(a, false, to_stream(s)); +} + +/** Computes the mean of the elements of an array along the given axes */ +MLX_API array mean( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Computes the mean of the elements of an array along the given axis */ +MLX_API array +mean(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}); + +/** Computes the median of the elements of an array. */ +MLX_API array median(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array median(const array& a, StreamOrDevice s = {}) { + return median(a, false, to_stream(s)); +} + +/** Computes the median of the elements of an array along the given axes */ +MLX_API array median( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Computes the median of the elements of an array along the given axis */ +MLX_API array +median(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}); + +/** Computes the variance of the elements of an array. */ +MLX_API array +var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {}); +inline array var(const array& a, StreamOrDevice s = {}) { + return var(a, false, 0, to_stream(s)); +} + +/** Computes the variance of the elements of an array along the given + * axes */ +MLX_API array +var(const array& a, + const std::vector& axes, + bool keepdims = false, + int ddof = 0, + StreamOrDevice s = {}); + +/** Computes the variance of the elements of an array along the given + * axis */ +MLX_API array +var(const array& a, + int axis, + bool keepdims = false, + int ddof = 0, + StreamOrDevice s = {}); + +/** Computes the standard deviation of the elements of an array. */ +MLX_API array +std(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {}); +inline array std(const array& a, StreamOrDevice s = {}) { + return std(a, false, 0, to_stream(s)); +} + +/** Computes the standard deviation of the elements of an array along the given + * axes */ +MLX_API array +std(const array& a, + const std::vector& axes, + bool keepdims = false, + int ddof = 0, + StreamOrDevice s = {}); + +/** Computes the standard deviation of the elements of an array along the given + * axis */ +MLX_API array +std(const array& a, + int axis, + bool keepdims = false, + int ddof = 0, + StreamOrDevice s = {}); + +/** The product of all elements of the array. */ +MLX_API array prod(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array prod(const array& a, StreamOrDevice s = {}) { + return prod(a, false, to_stream(s)); +} + +/** The product of the elements of an array along the given axes. */ +MLX_API array prod( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The product of the elements of an array along the given axis. */ +MLX_API array +prod(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}); + +/** The maximum of all elements of the array. */ +MLX_API array max(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array max(const array& a, StreamOrDevice s = {}) { + return max(a, false, to_stream(s)); +} + +/** The maximum of the elements of an array along the given axes. */ +MLX_API array +max(const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The maximum of the elements of an array along the given axis. */ +MLX_API array +max(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}); + +/** The minimum of all elements of the array. */ +MLX_API array min(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array min(const array& a, StreamOrDevice s = {}) { + return min(a, false, to_stream(s)); +} + +/** The minimum of the elements of an array along the given axes. */ +MLX_API array +min(const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The minimum of the elements of an array along the given axis. */ +MLX_API array +min(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}); + +/** Returns the Hanning window of size M. */ +MLX_API array hanning(int M, StreamOrDevice s = {}); + +/** Returns the Hamming window of size M. */ +MLX_API array hamming(int M, StreamOrDevice s = {}); + +/** Returns the bartlett window of size M. */ +MLX_API array bartlett(int M, StreamOrDevice s = {}); + +/** Returns the Blackmann window of size M. */ +MLX_API array blackman(int M, StreamOrDevice s = {}); + +/** Returns the index of the minimum value in the array. */ +MLX_API array argmin(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array argmin(const array& a, StreamOrDevice s = {}) { + return argmin(a, false, s); +} + +/** Returns the indices of the minimum values along a given axis. */ +MLX_API array +argmin(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}); + +/** Returns the index of the maximum value in the array. */ +MLX_API array argmax(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array argmax(const array& a, StreamOrDevice s = {}) { + return argmax(a, false, s); +} + +/** Returns the indices of the maximum values along a given axis. */ +MLX_API array +argmax(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}); + +/** Returns a sorted copy of the flattened array. */ +MLX_API array sort(const array& a, StreamOrDevice s = {}); + +/** Returns a sorted copy of the array along a given axis. */ +MLX_API array sort(const array& a, int axis, StreamOrDevice s = {}); + +/** Returns indices that sort the flattened array. */ +MLX_API array argsort(const array& a, StreamOrDevice s = {}); + +/** Returns indices that sort the array along a given axis. */ +MLX_API array argsort(const array& a, int axis, StreamOrDevice s = {}); + +/** + * Returns a partitioned copy of the flattened array + * such that the smaller kth elements are first. + **/ +MLX_API array partition(const array& a, int kth, StreamOrDevice s = {}); + +/** + * Returns a partitioned copy of the array along a given axis + * such that the smaller kth elements are first. + **/ +MLX_API array +partition(const array& a, int kth, int axis, StreamOrDevice s = {}); + +/** + * Returns indices that partition the flattened array + * such that the smaller kth elements are first. + **/ +MLX_API array argpartition(const array& a, int kth, StreamOrDevice s = {}); + +/** + * Returns indices that partition the array along a given axis + * such that the smaller kth elements are first. + **/ +MLX_API array +argpartition(const array& a, int kth, int axis, StreamOrDevice s = {}); + +/** Returns topk elements of the flattened array. */ +MLX_API array topk(const array& a, int k, StreamOrDevice s = {}); + +/** Returns topk elements of the array along a given axis. */ +MLX_API array topk(const array& a, int k, int axis, StreamOrDevice s = {}); + +/** Cumulative logsumexp of an array. */ +MLX_API array logcumsumexp( + const array& a, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative logsumexp of an array along the given axis. */ +MLX_API array logcumsumexp( + const array& a, + int axis, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** The logsumexp of all elements of the array. */ +MLX_API array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array logsumexp(const array& a, StreamOrDevice s = {}) { + return logsumexp(a, false, to_stream(s)); +} + +/** The logsumexp of the elements of an array along the given axes. */ +MLX_API array logsumexp( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The logsumexp of the elements of an array along the given axis. */ +MLX_API array logsumexp( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Absolute value of elements in an array. */ +MLX_API array abs(const array& a, StreamOrDevice s = {}); + +/** Negate an array. */ +MLX_API array negative(const array& a, StreamOrDevice s = {}); +MLX_API array operator-(const array& a); + +/** The sign of the elements in an array. */ +MLX_API array sign(const array& a, StreamOrDevice s = {}); + +/** Logical not of an array */ +MLX_API array logical_not(const array& a, StreamOrDevice s = {}); + +/** Logical and of two arrays */ +MLX_API array +logical_and(const array& a, const array& b, StreamOrDevice s = {}); +MLX_API array operator&&(const array& a, const array& b); + +/** Logical or of two arrays */ +MLX_API array logical_or(const array& a, const array& b, StreamOrDevice s = {}); +MLX_API array operator||(const array& a, const array& b); + +/** The reciprocal (1/x) of the elements in an array. */ +MLX_API array reciprocal(const array& a, StreamOrDevice s = {}); + +/** Add two arrays. */ +MLX_API array add(const array& a, const array& b, StreamOrDevice s = {}); +MLX_API array operator+(const array& a, const array& b); +template +array operator+(T a, const array& b) { + return add(array(a), b); +} +template +array operator+(const array& a, T b) { + return add(a, array(b)); +} + +/** Subtract two arrays. */ +MLX_API array subtract(const array& a, const array& b, StreamOrDevice s = {}); +MLX_API array operator-(const array& a, const array& b); +template +array operator-(T a, const array& b) { + return subtract(array(a), b); +} +template +array operator-(const array& a, T b) { + return subtract(a, array(b)); +} + +/** Multiply two arrays. */ +MLX_API array multiply(const array& a, const array& b, StreamOrDevice s = {}); +MLX_API array operator*(const array& a, const array& b); +template +array operator*(T a, const array& b) { + return multiply(array(a), b); +} +template +array operator*(const array& a, T b) { + return multiply(a, array(b)); +} + +/** Divide two arrays. */ +MLX_API array divide(const array& a, const array& b, StreamOrDevice s = {}); +MLX_API array operator/(const array& a, const array& b); +MLX_API array operator/(double a, const array& b); +MLX_API array operator/(const array& a, double b); + +/** Compute the element-wise quotient and remainder. */ +MLX_API std::vector +divmod(const array& a, const array& b, StreamOrDevice s = {}); + +/** Compute integer division. Equivalent to doing floor(a / x). */ +MLX_API array +floor_divide(const array& a, const array& b, StreamOrDevice s = {}); + +/** Compute the element-wise remainder of division */ +MLX_API array remainder(const array& a, const array& b, StreamOrDevice s = {}); +MLX_API array operator%(const array& a, const array& b); +template +array operator%(T a, const array& b) { + return remainder(array(a), b); +} +template +array operator%(const array& a, T b) { + return remainder(a, array(b)); +} + +/** Element-wise maximum between two arrays. */ +MLX_API array maximum(const array& a, const array& b, StreamOrDevice s = {}); + +/** Element-wise minimum between two arrays. */ +MLX_API array minimum(const array& a, const array& b, StreamOrDevice s = {}); + +/** Floor the element of an array. **/ +MLX_API array floor(const array& a, StreamOrDevice s = {}); + +/** Ceil the element of an array. **/ +MLX_API array ceil(const array& a, StreamOrDevice s = {}); + +/** Square the elements of an array. */ +MLX_API array square(const array& a, StreamOrDevice s = {}); + +/** Exponential of the elements of an array. */ +MLX_API array exp(const array& a, StreamOrDevice s = {}); + +/** Sine of the elements of an array */ +MLX_API array sin(const array& a, StreamOrDevice s = {}); + +/** Cosine of the elements of an array */ +MLX_API array cos(const array& a, StreamOrDevice s = {}); + +/** Tangent of the elements of an array */ +MLX_API array tan(const array& a, StreamOrDevice s = {}); + +/** Arc Sine of the elements of an array */ +MLX_API array arcsin(const array& a, StreamOrDevice s = {}); + +/** Arc Cosine of the elements of an array */ +MLX_API array arccos(const array& a, StreamOrDevice s = {}); + +/** Arc Tangent of the elements of an array */ +MLX_API array arctan(const array& a, StreamOrDevice s = {}); + +/** Inverse tangent of the ratio of two arrays */ +MLX_API array arctan2(const array& a, const array& b, StreamOrDevice s = {}); + +/** Hyperbolic Sine of the elements of an array */ +MLX_API array sinh(const array& a, StreamOrDevice s = {}); + +/** Hyperbolic Cosine of the elements of an array */ +MLX_API array cosh(const array& a, StreamOrDevice s = {}); + +/** Hyperbolic Tangent of the elements of an array */ +MLX_API array tanh(const array& a, StreamOrDevice s = {}); + +/** Inverse Hyperbolic Sine of the elements of an array */ +MLX_API array arcsinh(const array& a, StreamOrDevice s = {}); + +/** Inverse Hyperbolic Cosine of the elements of an array */ +MLX_API array arccosh(const array& a, StreamOrDevice s = {}); + +/** Inverse Hyperbolic Tangent of the elements of an array */ +MLX_API array arctanh(const array& a, StreamOrDevice s = {}); + +/** Convert the elements of an array from Radians to Degrees **/ +MLX_API array degrees(const array& a, StreamOrDevice s = {}); + +/** Convert the elements of an array from Degrees to Radians **/ +MLX_API array radians(const array& a, StreamOrDevice s = {}); + +/** Natural logarithm of the elements of an array. */ +MLX_API array log(const array& a, StreamOrDevice s = {}); + +/** Log base 2 of the elements of an array. */ +MLX_API array log2(const array& a, StreamOrDevice s = {}); + +/** Log base 10 of the elements of an array. */ +MLX_API array log10(const array& a, StreamOrDevice s = {}); + +/** Natural logarithm of one plus elements in the array: `log(1 + a)`. */ +MLX_API array log1p(const array& a, StreamOrDevice s = {}); + +/** Log-add-exp of one elements in the array: `log(exp(a) + exp(b))`. */ +MLX_API array logaddexp(const array& a, const array& b, StreamOrDevice s = {}); + +/** Element-wise logistic sigmoid of the array: `1 / (1 + exp(-x)`. */ +MLX_API array sigmoid(const array& a, StreamOrDevice s = {}); + +/** Computes the error function of the elements of an array. */ +MLX_API array erf(const array& a, StreamOrDevice s = {}); + +/** Computes the inverse error function of the elements of an array. */ +MLX_API array erfinv(const array& a, StreamOrDevice s = {}); + +/** Computes the expm1 function of the elements of an array. */ +MLX_API array expm1(const array& a, StreamOrDevice s = {}); + +/** Stop the flow of gradients. */ +MLX_API array stop_gradient(const array& a, StreamOrDevice s = {}); + +/** Round a floating point number */ +MLX_API array round(const array& a, int decimals, StreamOrDevice s = {}); +inline array round(const array& a, StreamOrDevice s = {}) { + return round(a, 0, s); +} + +/** Matrix-matrix multiplication. */ +MLX_API array matmul(const array& a, const array& b, StreamOrDevice s = {}); + +/** Gather array entries given indices and slices */ +MLX_API array gather( + const array& a, + const std::vector& indices, + const std::vector& axes, + const Shape& slice_sizes, + StreamOrDevice s = {}); +inline array gather( + const array& a, + const array& indices, + int axis, + const Shape& slice_sizes, + StreamOrDevice s = {}) { + return gather(a, {indices}, std::vector{axis}, slice_sizes, s); +} + +/** Compute the Kronecker product of two arrays. */ +MLX_API array kron(const array& a, const array& b, StreamOrDevice s = {}); + +/** Take array slices at the given indices of the specified axis. */ +MLX_API array +take(const array& a, const array& indices, int axis, StreamOrDevice s = {}); +MLX_API array take(const array& a, int index, int axis, StreamOrDevice s = {}); + +/** Take array entries at the given indices treating the array as flattened. */ +MLX_API array take(const array& a, const array& indices, StreamOrDevice s = {}); +MLX_API array take(const array& a, int index, StreamOrDevice s = {}); + +/** Take array entries given indices along the axis */ +MLX_API array take_along_axis( + const array& a, + const array& indices, + int axis, + StreamOrDevice s = {}); + +/** Put the values into the array at the given indices along the axis */ +MLX_API array put_along_axis( + const array& a, + const array& indices, + const array& values, + int axis, + StreamOrDevice s = {}); + +/** Add the values into the array at the given indices along the axis */ +MLX_API array scatter_add_axis( + const array& a, + const array& indices, + const array& values, + int axis, + StreamOrDevice s = {}); + +/** Scatter updates to the given indices. + * + * The parameters ``indices`` and ``axes`` determine the locations of ``a`` + * that are updated with the values in ``updates``. Assuming 1-d ``indices`` + * for simplicity, ``indices[i]`` are the indices on axis ``axes[i]`` to which + * the values in ``updates`` will be applied. Note each array in + * ``indices`` is assigned to a corresponding axis and hence ``indices.size() == + * axes.size()``. If an index/axis pair is not provided then indices along that + * axis are assumed to be zero. + * + * Note the rank of ``updates`` must be equal to the sum of the rank of the + * broadcasted ``indices`` and the rank of ``a``. In other words, assuming the + * arrays in ``indices`` have the same shape, ``updates.ndim() == + * indices[0].ndim() + a.ndim()``. The leading dimensions of ``updates`` + * correspond to the indices, and the remaining ``a.ndim()`` dimensions are the + * values that will be applied to the given location in ``a``. + * + * For example: + * + * @code + * auto in = zeros({4, 4}, float32); + * auto indices = array({2}); + * auto updates = reshape(arange(1, 3, float32), {1, 1, 2}); + * std::vector axes{0}; + * + * auto out = scatter(in, {indices}, updates, axes); + * @endcode + * + * will produce: + * + * @code + * array([[0, 0, 0, 0], + * [0, 0, 0, 0], + * [1, 2, 0, 0], + * [0, 0, 0, 0]], dtype=float32) + * @endcode + * + * This scatters the two-element row vector ``[1, 2]`` starting at the ``(2, + * 0)`` position of ``a``. + * + * Adding another element to ``indices`` will scatter into another location of + * ``a``. We also have to add an another update for the new index: + * + * @code + * auto in = zeros({4, 4}, float32); + * auto indices = array({2, 0}); + * auto updates = reshape(arange(1, 5, float32), {2, 1, 2}); + * std::vector axes{0}; + * + * auto out = scatter(in, {indices}, updates, axes): + * @endcode + * + * will produce: + * + * @code + * array([[3, 4, 0, 0], + * [0, 0, 0, 0], + * [1, 2, 0, 0], + * [0, 0, 0, 0]], dtype=float32) + * @endcode + * + * To control the scatter location on an additional axis, add another index + * array to ``indices`` and another axis to ``axes``: + * + * @code + * auto in = zeros({4, 4}, float32); + * auto indices = std::vector{array({2, 0}), array({1, 2})}; + * auto updates = reshape(arange(1, 5, float32), {2, 1, 2}); + * std::vector axes{0, 1}; + * + * auto out = scatter(in, indices, updates, axes); + * @endcode + * + * will produce: + * + * @code + * array([[0, 0, 3, 4], + * [0, 0, 0, 0], + * [0, 1, 2, 0], + * [0, 0, 0, 0]], dtype=float32) + * @endcode + * + * Items in indices are broadcasted together. This means: + * + * @code + * auto indices = std::vector{array({2, 0}), array({1})}; + * @endcode + * + * is equivalent to: + * + * @code + * auto indices = std::vector{array({2, 0}), array({1, 1})}; + * @endcode + * + * Note, ``scatter`` does not perform bounds checking on the indices and + * updates. Out-of-bounds accesses on ``a`` are undefined and typically result + * in unintended or invalid memory writes. + */ +MLX_API array scatter( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s = {}); +inline array scatter( + const array& a, + const array& indices, + const array& updates, + int axis, + StreamOrDevice s = {}) { + return scatter(a, {indices}, updates, std::vector{axis}, s); +} + +/** Scatter and add updates to given indices */ +MLX_API array scatter_add( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s = {}); +inline array scatter_add( + const array& a, + const array& indices, + const array& updates, + int axis, + StreamOrDevice s = {}) { + return scatter_add(a, {indices}, updates, std::vector{axis}, s); +} + +/** Scatter and prod updates to given indices */ +MLX_API array scatter_prod( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s = {}); +inline array scatter_prod( + const array& a, + const array& indices, + const array& updates, + int axis, + StreamOrDevice s = {}) { + return scatter_prod(a, {indices}, updates, std::vector{axis}, s); +} + +/** Scatter and max updates to given linear indices */ +MLX_API array scatter_max( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s = {}); +inline array scatter_max( + const array& a, + const array& indices, + const array& updates, + int axis, + StreamOrDevice s = {}) { + return scatter_max(a, {indices}, updates, std::vector{axis}, s); +} +/** Scatter and min updates to given linear indices */ +MLX_API array scatter_min( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s = {}); +inline array scatter_min( + const array& a, + const array& indices, + const array& updates, + int axis, + StreamOrDevice s = {}) { + return scatter_min(a, {indices}, updates, std::vector{axis}, s); +} + +MLX_API array masked_scatter( + const array& a, + const array& mask, + const array& src, + StreamOrDevice s = {}); + +/** Square root the elements of an array. */ +MLX_API array sqrt(const array& a, StreamOrDevice s = {}); + +/** Square root and reciprocal the elements of an array. */ +MLX_API array rsqrt(const array& a, StreamOrDevice s = {}); + +/** Softmax of an array. */ +MLX_API array softmax( + const array& a, + const std::vector& axes, + bool precise = false, + StreamOrDevice s = {}); + +/** Softmax of an array. */ +MLX_API array +softmax(const array& a, bool precise = false, StreamOrDevice s = {}); + +/** Softmax of an array. */ +inline array +softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) { + return softmax(a, std::vector{axis}, precise, s); +} + +/** Raise elements of a to the power of b element-wise */ +MLX_API array power(const array& a, const array& b, StreamOrDevice s = {}); + +/** Cumulative sum of an array. */ +MLX_API array cumsum( + const array& a, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative sum of an array along the given axis. */ +MLX_API array cumsum( + const array& a, + int axis, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative product of an array. */ +MLX_API array cumprod( + const array& a, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative product of an array along the given axis. */ +MLX_API array cumprod( + const array& a, + int axis, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative max of an array. */ +MLX_API array cummax( + const array& a, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative max of an array along the given axis. */ +MLX_API array cummax( + const array& a, + int axis, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative min of an array. */ +MLX_API array cummin( + const array& a, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative min of an array along the given axis. */ +MLX_API array cummin( + const array& a, + int axis, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** General convolution with a filter */ +MLX_API array conv_general( + array input, + array weight, + std::vector stride = {}, + std::vector padding_lo = {}, + std::vector padding_hi = {}, + std::vector kernel_dilation = {}, + std::vector input_dilation = {}, + int groups = 1, + bool flip = false, + StreamOrDevice s = {}); + +/** General convolution with a filter */ +inline array conv_general( + const array& input, + const array& weight, + std::vector stride = {}, + std::vector padding = {}, + std::vector kernel_dilation = {}, + std::vector input_dilation = {}, + int groups = 1, + bool flip = false, + StreamOrDevice s = {}) { + return conv_general( + /* const array& input = */ input, + /* const array& weight = */ weight, + /* std::vector stride = */ stride, + /* std::vector padding_lo = */ padding, + /* std::vector padding_hi = */ padding, + /* std::vector kernel_dilation = */ kernel_dilation, + /* std::vector input_dilation = */ input_dilation, + /* int groups = */ groups, + /* bool flip = */ flip, + /* StreamOrDevice s = */ s); +} + +/** 1D convolution with a filter */ +MLX_API array conv1d( + const array& input, + const array& weight, + int stride = 1, + int padding = 0, + int dilation = 1, + int groups = 1, + StreamOrDevice s = {}); + +/** 2D convolution with a filter */ +MLX_API array conv2d( + const array& input, + const array& weight, + const std::pair& stride = {1, 1}, + const std::pair& padding = {0, 0}, + const std::pair& dilation = {1, 1}, + int groups = 1, + StreamOrDevice s = {}); + +/** 3D convolution with a filter */ +MLX_API array conv3d( + const array& input, + const array& weight, + const std::tuple& stride = {1, 1, 1}, + const std::tuple& padding = {0, 0, 0}, + const std::tuple& dilation = {1, 1, 1}, + int groups = 1, + StreamOrDevice s = {}); + +/** 1D transposed convolution with a filter */ +MLX_API array conv_transpose1d( + const array& input, + const array& weight, + int stride = 1, + int padding = 0, + int dilation = 1, + int output_padding = 0, + int groups = 1, + StreamOrDevice s = {}); + +/** 2D transposed convolution with a filter */ +MLX_API array conv_transpose2d( + const array& input, + const array& weight, + const std::pair& stride = {1, 1}, + const std::pair& padding = {0, 0}, + const std::pair& dilation = {1, 1}, + const std::pair& output_padding = {0, 0}, + int groups = 1, + StreamOrDevice s = {}); + +/** 3D transposed convolution with a filter */ +MLX_API array conv_transpose3d( + const array& input, + const array& weight, + const std::tuple& stride = {1, 1, 1}, + const std::tuple& padding = {0, 0, 0}, + const std::tuple& dilation = {1, 1, 1}, + const std::tuple& output_padding = {0, 0, 0}, + int groups = 1, + StreamOrDevice s = {}); + +/** Quantized matmul multiplies x with a quantized matrix w*/ +MLX_API array quantized_matmul( + array x, + array w, + array scales, + std::optional biases = std::nullopt, + bool transpose = true, + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt, + const std::string& mode = "affine", + StreamOrDevice s = {}); + +/** Quantize a matrix along its last axis */ +MLX_API std::vector quantize( + const array& w, + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt, + const std::string& mode = "affine", + const std::optional& global_scale = std::nullopt, + StreamOrDevice s = {}); + +/** Dequantize a matrix produced by quantize() */ +MLX_API array dequantize( + const array& w, + const array& scales, + const std::optional& biases = std::nullopt, + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt, + const std::string& mode = "affine", + const std::optional& global_scale = std::nullopt, + std::optional dtype = std::nullopt, + StreamOrDevice s = {}); + +MLX_API array qqmm( + array x, // input activations + array w, // maybe quantized weights + const std::optional w_scales = std::nullopt, // optional scales if w + // is quantized + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt, + const std::string& mode = "nvfp4", + const std::optional global_scale_x = std::nullopt, + const std::optional global_scale_w = std::nullopt, + StreamOrDevice s = {}); + +/** Convert an E4M3 float8 to the given floating point dtype. */ +MLX_API array from_fp8(array x, Dtype dtype, StreamOrDevice s = {}); + +/** Convert a floating point matrix to E4M3 float8. */ +MLX_API array to_fp8(array x, StreamOrDevice s = {}); + +/** Compute matrix products with matrix-level gather. */ +MLX_API array gather_qmm( + const array& x, + const array& w, + const array& scales, + const std::optional& biases = std::nullopt, + std::optional lhs_indices = std::nullopt, + std::optional rhs_indices = std::nullopt, + bool transpose = true, + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt, + const std::string& mode = "affine", + bool sorted_indices = false, + StreamOrDevice s = {}); + +/** Returns a contraction of a and b over multiple dimensions. */ +MLX_API array tensordot( + const array& a, + const array& b, + const int axis = 2, + StreamOrDevice s = {}); + +MLX_API array tensordot( + const array& a, + const array& b, + const std::vector& axes_a, + const std::vector& axes_b, + StreamOrDevice s = {}); + +/** Compute the outer product of two vectors. */ +MLX_API array outer(const array& a, const array& b, StreamOrDevice s = {}); + +/** Compute the inner product of two vectors. */ +MLX_API array inner(const array& a, const array& b, StreamOrDevice s = {}); + +/** Compute D = beta * C + alpha * (A @ B) */ +MLX_API array addmm( + array c, + array a, + array b, + const float& alpha = 1.f, + const float& beta = 1.f, + StreamOrDevice s = {}); + +/** Compute matrix product with block masking */ +MLX_API array block_masked_mm( + array a, + array b, + int block_size, + std::optional mask_out = std::nullopt, + std::optional mask_lhs = std::nullopt, + std::optional mask_rhs = std::nullopt, + StreamOrDevice s = {}); + +/** Compute matrix product with matrix-level gather */ +MLX_API array gather_mm( + array a, + array b, + std::optional lhs_indices = std::nullopt, + std::optional rhs_indices = std::nullopt, + bool sorted_indices = false, + StreamOrDevice s = {}); + +/** + * Compute a matrix product but segment the inner dimension and write the + * result separately for each segment. + */ +MLX_API array +segmented_mm(array a, array b, array segments, StreamOrDevice s = {}); + +/** Extract a diagonal or construct a diagonal array */ +MLX_API array diagonal( + const array& a, + int offset = 0, + int axis1 = 0, + int axis2 = 1, + StreamOrDevice s = {}); + +/** Extract diagonal from a 2d array or create a diagonal matrix. */ +MLX_API array diag(const array& a, int k = 0, StreamOrDevice s = {}); + +/** Return the sum along a specified diagonal in the given array. */ +MLX_API array trace( + const array& a, + int offset, + int axis1, + int axis2, + Dtype dtype, + StreamOrDevice s = {}); +MLX_API array +trace(const array& a, int offset, int axis1, int axis2, StreamOrDevice s = {}); +MLX_API array trace(const array& a, StreamOrDevice s = {}); + +/** + * Implements the identity function but allows injecting dependencies to other + * arrays. This ensures that these other arrays will have been computed + * when the outputs of this function are computed. + */ +MLX_API std::vector depends( + const std::vector& inputs, + const std::vector& dependencies); + +/** convert an array to an atleast ndim array */ +MLX_API array atleast_1d(const array& a, StreamOrDevice s = {}); +MLX_API std::vector atleast_1d( + const std::vector& a, + StreamOrDevice s = {}); +MLX_API array atleast_2d(const array& a, StreamOrDevice s = {}); +MLX_API std::vector atleast_2d( + const std::vector& a, + StreamOrDevice s = {}); +MLX_API array atleast_3d(const array& a, StreamOrDevice s = {}); +MLX_API std::vector atleast_3d( + const std::vector& a, + StreamOrDevice s = {}); + +/** + * Extract the number of elements along some axes as a scalar array. Used to + * allow shape dependent shapeless compilation (pun intended). + */ +MLX_API array number_of_elements( + const array& a, + std::vector axes, + bool inverted, + Dtype dtype = int32, + StreamOrDevice s = {}); + +MLX_API array conjugate(const array& a, StreamOrDevice s = {}); + +/** Bitwise and. */ +MLX_API array +bitwise_and(const array& a, const array& b, StreamOrDevice s = {}); +MLX_API array operator&(const array& a, const array& b); + +/** Bitwise inclusive or. */ +MLX_API array bitwise_or(const array& a, const array& b, StreamOrDevice s = {}); +MLX_API array operator|(const array& a, const array& b); + +/** Bitwise exclusive or. */ +MLX_API array +bitwise_xor(const array& a, const array& b, StreamOrDevice s = {}); +MLX_API array operator^(const array& a, const array& b); + +/** Shift bits to the left. */ +MLX_API array left_shift(const array& a, const array& b, StreamOrDevice s = {}); +MLX_API array operator<<(const array& a, const array& b); + +/** Shift bits to the right. */ +MLX_API array +right_shift(const array& a, const array& b, StreamOrDevice s = {}); +MLX_API array operator>>(const array& a, const array& b); + +/** Invert the bits. */ +MLX_API array bitwise_invert(const array& a, StreamOrDevice s = {}); +MLX_API array operator~(const array& a); + +MLX_API array view(const array& a, const Dtype& dtype, StreamOrDevice s = {}); + +/** Roll elements along an axis and introduce them on the other side */ +MLX_API array roll(const array& a, int shift, StreamOrDevice s = {}); +MLX_API array roll(const array& a, const Shape& shift, StreamOrDevice s = {}); +MLX_API array roll(const array& a, int shift, int axis, StreamOrDevice s = {}); +MLX_API array roll( + const array& a, + int shift, + const std::vector& axes, + StreamOrDevice s = {}); +MLX_API array +roll(const array& a, const Shape& shift, int axis, StreamOrDevice s = {}); +MLX_API array roll( + const array& a, + const Shape& shift, + const std::vector& axes, + StreamOrDevice s = {}); + +/* The real part of a complex array. */ +MLX_API array real(const array& a, StreamOrDevice s = {}); + +/* The imaginary part of a complex array. */ +MLX_API array imag(const array& a, StreamOrDevice s = {}); + +/* Ensure the array's underlying memory is contiguous. */ +MLX_API array +contiguous(const array& a, bool allow_col_major = false, StreamOrDevice s = {}); + +/** @} */ + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/primitives.h b/Source/Cxxmlx/include/mlx/primitives.h new file mode 100644 index 00000000..4091aafc --- /dev/null +++ b/Source/Cxxmlx/include/mlx/primitives.h @@ -0,0 +1,2525 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include + +#include "mlx/api.h" +#include "mlx/array.h" +#include "mlx/device.h" +#include "mlx/io/load.h" +#include "mlx/stream.h" + +#define DEFINE_VMAP() \ + virtual std::pair, std::vector> vmap( \ + const std::vector& inputs, const std::vector& axes) \ + override; + +#define DEFINE_GRADS() \ + std::vector jvp( \ + const std::vector& primals, \ + const std::vector& tangents, \ + const std::vector& argnums) override; \ + \ + std::vector vjp( \ + const std::vector& primals, \ + const std::vector& cotangents, \ + const std::vector& argnums, \ + const std::vector& outputs) override; + +#define DEFINE_NAME(PRIMITIVE) \ + const char* name() const override { \ + return #PRIMITIVE; \ + } + +#define DEFINE_DEFAULT_IS_EQUIVALENT() \ + bool is_equivalent(const Primitive& other) const override { \ + return true; \ + } + +#define DEFINE_INPUT_OUTPUT_SHAPE() \ + std::vector output_shapes(const std::vector& inputs) \ + override { \ + return {inputs[0].shape()}; \ + } + +namespace mlx::core { + +// Abstract base class +class MLX_API Primitive { + public: + explicit Primitive(Stream stream) : stream_(stream) {} + + /** The device the primitive will run on. */ + const Device& device() { + return stream().device; + } + + /** The stream the primitive will run on. */ + const Stream& stream() { + return stream_; + } + + /** + * A primitive must know how to evaluate itself on + * the CPU/GPU for the given inputs and populate the output arrays. + * + * To avoid unnecessary allocations, the evaluation function + * is responsible for allocating space for the array. + */ + virtual void eval_cpu( + const std::vector& inputs, + std::vector& outputs) = 0; + virtual void eval_gpu( + const std::vector& inputs, + std::vector& outputs) = 0; + + /** + * The Jacobian-vector product. + */ + virtual std::vector jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums); + + /** + * The vector-Jacobian product. + */ + virtual std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs); + + /** + * The primitive must know how to vectorize itself across + * the given axes. The output is a pair containing the output arrays + * representing the vectorized computation and the axes which + * corresponds to the vectorized dimensions of each output. + */ + virtual std::pair, std::vector> vmap( + const std::vector& inputs, + const std::vector& axes); + + /** Get the name of primitive. */ + virtual const char* name() const = 0; + + /** Equivalence check defaults to false unless overridden by the primitive */ + virtual bool is_equivalent(const Primitive& other) const { + return false; + } + + /** Get the output shapes of the primitive. This is not required to be + * implemented by derived classes, in which case it will throw. */ + virtual std::vector output_shapes(const std::vector& inputs); + + virtual ~Primitive() = default; + Primitive(const Primitive& other) = delete; + Primitive(Primitive&& other) = delete; + Primitive& operator=(const Primitive& other) = delete; + Primitive& operator=(Primitive&& other) = delete; + + private: + // Every primitive stores the stream it should run in + Stream stream_; +}; + +class MLX_API UnaryPrimitive : public Primitive { + /** + * An abstract base class for a primitive with a single output. + */ + public: + explicit UnaryPrimitive(Stream stream) : Primitive(stream) {} + + virtual void eval_cpu(const std::vector& inputs, array& output) = 0; + virtual void eval_gpu(const std::vector& inputs, array& output) = 0; + + inline void eval_cpu( + const std::vector& inputs, + std::vector& outputs) override { + eval_cpu(inputs, outputs[0]); + } + inline void eval_gpu( + const std::vector& inputs, + std::vector& outputs) override { + eval_gpu(inputs, outputs[0]); + } + + virtual ~UnaryPrimitive() = default; + UnaryPrimitive(const UnaryPrimitive& other) = delete; + UnaryPrimitive(UnaryPrimitive&& other) = delete; + UnaryPrimitive& operator=(const UnaryPrimitive& other) = delete; + UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete; +}; + +enum class QuantizationMode { Affine, Mxfp4, Mxfp8, Nvfp4 }; + +std::string quantization_mode_to_string(QuantizationMode mode); +QuantizationMode string_to_quantization_mode( + const std::string& mode, + std::string_view error_tag = ""); + +class Abs : public UnaryPrimitive { + public: + explicit Abs(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Abs) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class MLX_API Add : public UnaryPrimitive { + public: + explicit Add(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Add) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class AddMM : public UnaryPrimitive { + public: + explicit AddMM(Stream stream, float alpha, float beta) + : UnaryPrimitive(stream), alpha_(alpha), beta_(beta) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_GRADS() + DEFINE_VMAP() + DEFINE_NAME(AddMM) + + bool is_equivalent(const Primitive& other) const override; + std::pair state() const { + return {alpha_, beta_}; + }; + + private: + const float alpha_; + const float beta_; +}; + +class Arange : public UnaryPrimitive { + public: + explicit Arange(Stream stream, double start, double stop, double step) + : UnaryPrimitive(stream), start_(start), stop_(stop), step_(step) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_NAME(Arange) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + std::tuple state() const { + return {start_, stop_, step_}; + }; + + private: + double start_; + double stop_; + double step_; +}; + +class ArcCos : public UnaryPrimitive { + public: + explicit ArcCos(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ArcCos) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class ArcCosh : public UnaryPrimitive { + public: + explicit ArcCosh(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ArcCosh) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class ArcSin : public UnaryPrimitive { + public: + explicit ArcSin(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ArcSin) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class ArcSinh : public UnaryPrimitive { + public: + explicit ArcSinh(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ArcSinh) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class ArcTan : public UnaryPrimitive { + public: + explicit ArcTan(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ArcTan) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class ArcTan2 : public UnaryPrimitive { + public: + explicit ArcTan2(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ArcTan2) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class ArcTanh : public UnaryPrimitive { + public: + explicit ArcTanh(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ArcTanh) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class ArgPartition : public UnaryPrimitive { + public: + explicit ArgPartition(Stream stream, int kth, int axis) + : UnaryPrimitive(stream), kth_(kth), axis_(axis) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ArgPartition) + DEFINE_INPUT_OUTPUT_SHAPE() + bool is_equivalent(const Primitive& other) const override; + std::pair state() const { + return {kth_, axis_}; + }; + + private: + int kth_; + int axis_; +}; + +class MLX_API ArgReduce : public UnaryPrimitive { + public: + enum ReduceType { + ArgMin, + ArgMax, + }; + + explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis) + : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ArgReduce) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + std::pair state() const { + return {reduce_type_, axis_}; + }; + + private: + ReduceType reduce_type_; + int axis_; +}; + +class ArgSort : public UnaryPrimitive { + public: + explicit ArgSort(Stream stream, int axis) + : UnaryPrimitive(stream), axis_(axis) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ArgSort) + DEFINE_INPUT_OUTPUT_SHAPE() + bool is_equivalent(const Primitive& other) const override; + int state() const { + return axis_; + }; + + private: + int axis_; +}; + +class AsType : public UnaryPrimitive { + public: + explicit AsType(Stream stream, Dtype dtype) + : UnaryPrimitive(stream), dtype_(dtype) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(AsType) + DEFINE_INPUT_OUTPUT_SHAPE() + bool is_equivalent(const Primitive& other) const override; + Dtype state() const { + return dtype_; + }; + + private: + Dtype dtype_; +}; + +class AsStrided : public UnaryPrimitive { + public: + explicit AsStrided(Stream stream, Shape shape, Strides strides, size_t offset) + : UnaryPrimitive(stream), + shape_(std::move(shape)), + strides_(std::move(strides)), + offset_(offset) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_GRADS() + DEFINE_NAME(AsStrided) + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_tuple(shape_, strides_, offset_); + } + + private: + Shape shape_; + Strides strides_; + size_t offset_; + + void eval(const std::vector& inputs, array& out); +}; + +class BitwiseBinary : public UnaryPrimitive { + public: + enum Op { And, Or, Xor, LeftShift, RightShift }; + + explicit BitwiseBinary(Stream stream, Op op) + : UnaryPrimitive(stream), op_(op) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + + const char* name() const override { + switch (op_) { + case BitwiseBinary::And: + return "BitwiseAnd"; + case BitwiseBinary::Or: + return "BitwiseOr"; + case BitwiseBinary::Xor: + return "BitwiseXor"; + case BitwiseBinary::LeftShift: + return "LeftShift"; + case BitwiseBinary::RightShift: + return "RightShift"; + } + return ""; + } + + bool is_equivalent(const Primitive& other) const override; + DEFINE_INPUT_OUTPUT_SHAPE() + auto state() const { + return op_; + } + + private: + Op op_; +}; + +class BitwiseInvert : public UnaryPrimitive { + public: + explicit BitwiseInvert(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_NAME(BitwiseInvert) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class BlockMaskedMM : public UnaryPrimitive { + public: + explicit BlockMaskedMM(Stream stream, int block_size) + : UnaryPrimitive(stream), block_size_(block_size) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_NAME(BlockMaskedMM) + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return block_size_; + } + + private: + int block_size_; +}; + +class GatherMM : public UnaryPrimitive { + public: + explicit GatherMM( + Stream stream, + bool left_sorted = false, + bool right_sorted = false) + : UnaryPrimitive(stream), + left_sorted_(left_sorted), + right_sorted_(right_sorted) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_NAME(GatherMM) + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_pair(left_sorted_, right_sorted_); + } + + private: + bool left_sorted_; + bool right_sorted_; +}; + +class SegmentedMM : public UnaryPrimitive { + public: + explicit SegmentedMM(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_NAME(SegmentedMM) +}; + +class BroadcastAxes : public UnaryPrimitive { + public: + explicit BroadcastAxes(Stream stream, std::vector ignore_axes = {}) + : UnaryPrimitive(stream), ignore_axes_(std::move(ignore_axes)) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(BroadcastAxes) + bool is_equivalent(const Primitive& other) const override; + static Shape output_shape( + const std::vector& inputs, + const std::vector& ignore_axes); + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return ignore_axes_; + } + + private: + void eval(const std::vector& inputs, array& out); + std::vector ignore_axes_; +}; + +class Broadcast : public UnaryPrimitive { + public: + explicit Broadcast(Stream stream, const Shape& shape) + : UnaryPrimitive(stream), shape_(shape) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Broadcast) + static Shape output_shape(const std::vector& inputs); + std::vector output_shapes(const std::vector& inputs) override; + bool is_equivalent(const Primitive& other) const override; + Shape state() const { + return shape_; + }; + + private: + Shape shape_; + + void eval(const std::vector& inputs, array& out); +}; + +class Ceil : public UnaryPrimitive { + public: + explicit Ceil(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Ceil) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class MLX_API Compiled : public Primitive { + public: + /* + * The inputs, outputs and tape are either tracers or constants. + * - The tape should not contain the inputs, but it should contain the + * outputs. + * - The tape should also have only one array per primitive for multi-output + * primitives. + * - The constant_ids contains ids of arrays in the input list that are safe + * to treat as scalar constants. + */ + explicit Compiled( + Stream stream, + std::vector inputs, + std::vector outputs, + std::vector tape, + std::unordered_set constant_ids); + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_VMAP() + DEFINE_GRADS() + const char* name() const override; + std::vector output_shapes(const std::vector& inputs) override; + bool is_equivalent(const Primitive& other) const override; + + std::string lib_name() const { + return kernel_lib_; + } + + private: + const std::vector inputs_; + const std::vector outputs_; + const std::vector tape_; + const std::unordered_set constant_ids_; + const std::function is_constant_; + + mutable std::string name_; + std::string kernel_lib_; +}; + +class Concatenate : public UnaryPrimitive { + public: + explicit Concatenate(Stream stream, int axis) + : UnaryPrimitive(stream), axis_(axis) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Concatenate) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return axis_; + } + + private: + int axis_; +}; + +class Conjugate : public UnaryPrimitive { + public: + explicit Conjugate(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_NAME(Conjugate) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Contiguous : public UnaryPrimitive { + public: + explicit Contiguous(Stream stream, bool allow_col_major) + : UnaryPrimitive(stream), allow_col_major_(allow_col_major) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Contiguous) + DEFINE_INPUT_OUTPUT_SHAPE() + + bool is_equivalent(const Primitive& other) const override; + + private: + bool allow_col_major_; +}; + +class Convolution : public UnaryPrimitive { + public: + explicit Convolution( + Stream stream, + const std::vector& kernel_strides, + const std::vector& padding_lo, + const std::vector& padding_hi, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + const int groups = 1, + const bool flip = false) + : UnaryPrimitive(stream), + padding_lo_(padding_lo), + padding_hi_(padding_hi), + kernel_strides_(kernel_strides), + kernel_dilation_(kernel_dilation), + input_dilation_(input_dilation), + groups_(groups), + flip_(flip) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_VMAP() + DEFINE_NAME(Convolution) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return std::make_tuple( + kernel_strides_, + padding_lo_, + padding_hi_, + kernel_dilation_, + input_dilation_, + groups_, + flip_); + } + + static Shape conv_out_shape( + const Shape& in_shape, + const Shape& wt_shape, + const std::vector& strides, + const std::vector& pads_lo, + const std::vector& pads_hi, + const std::vector& kernel_dilation, + const std::vector& input_dilation); + + private: + std::vector padding_lo_; + std::vector padding_hi_; + std::vector kernel_strides_; + std::vector kernel_dilation_; + std::vector input_dilation_; + int groups_; + bool flip_; +}; + +class Copy : public UnaryPrimitive { + public: + explicit Copy(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Copy) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Cos : public UnaryPrimitive { + public: + explicit Cos(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Cos) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Cosh : public UnaryPrimitive { + public: + explicit Cosh(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Cosh) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class CustomTransforms : public Primitive { + public: + explicit CustomTransforms( + Stream stream, + int num_outputs, + std::function( + const std::vector&, + const std::vector&, + const std::vector&)> vjp, + std::function( + const std::vector&, + const std::vector&, + const std::vector&)> jvp, + std::function, std::vector>( + const std::vector&, + const std::vector&)> vmap) + : Primitive(stream), + num_outputs_(num_outputs), + vjp_fun_(std::move(vjp)), + jvp_fun_(std::move(jvp)), + vmap_fun_(std::move(vmap)) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_GRADS(); + DEFINE_VMAP(); + DEFINE_NAME(CustomTransforms); + + private: + void eval(const std::vector& inputs, std::vector& outputs); + + int num_outputs_; + + std::function( + const std::vector&, + const std::vector&, + const std::vector&)> + vjp_fun_; + std::function( + const std::vector&, + const std::vector&, + const std::vector&)> + jvp_fun_; + std::function, std::vector>( + const std::vector&, + const std::vector&)> + vmap_fun_; +}; + +class Depends : public Primitive { + public: + explicit Depends(Stream stream) : Primitive(stream) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotan, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_NAME(Depends); + + private: + void eval(const std::vector& inputs, std::vector& outputs); +}; + +class Divide : public UnaryPrimitive { + public: + explicit Divide(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Divide) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class DivMod : public Primitive { + public: + explicit DivMod(Stream stream) : Primitive(stream) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(DivMod) + DEFINE_DEFAULT_IS_EQUIVALENT() + std::vector output_shapes(const std::vector& inputs) override { + return std::vector{inputs[0].shape(), inputs[0].shape()}; + } +}; + +class Select : public UnaryPrimitive { + public: + explicit Select(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Select) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Remainder : public UnaryPrimitive { + public: + explicit Remainder(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Remainder) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Equal : public UnaryPrimitive { + public: + explicit Equal(Stream stream, bool equal_nan = false) + : UnaryPrimitive(stream), equal_nan_(equal_nan) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() + + const char* name() const override { + if (equal_nan_) { + return "NaNEqual"; + } else { + return "Equal"; + } + } + auto state() const { + return equal_nan_; + }; + + private: + bool equal_nan_; +}; + +class Erf : public UnaryPrimitive { + public: + explicit Erf(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Erf) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class ErfInv : public UnaryPrimitive { + public: + explicit ErfInv(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ErfInv) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class MLX_API Exp : public UnaryPrimitive { + public: + explicit Exp(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Exp) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Expm1 : public UnaryPrimitive { + public: + explicit Expm1(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Expm1) + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class ExpandDims : public UnaryPrimitive { + public: + explicit ExpandDims(Stream stream, std::vector axes) + : UnaryPrimitive(stream), axes_(std::move(axes)) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ExpandDims) + + std::vector output_shapes(const std::vector& inputs) override; + bool is_equivalent(const Primitive& other) const override; + + static Shape output_shape(const array& input, const std::vector& axes); + auto state() const { + return axes_; + } + + private: + void eval(const std::vector& inputs, array& out); + std::vector axes_; +}; + +class FFT : public UnaryPrimitive { + public: + explicit FFT( + Stream stream, + const std::vector& axes, + bool inverse, + bool real) + : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(FFT) + + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_tuple(axes_, inverse_, real_); + } + + private: + std::vector axes_; + bool inverse_; + bool real_; +}; + +class Flatten : public UnaryPrimitive { + public: + explicit Flatten(Stream stream, int start_axis, int end_axis) + : UnaryPrimitive(stream), start_axis_(start_axis), end_axis_(end_axis) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Flatten) + std::vector output_shapes(const std::vector& inputs) override; + bool is_equivalent(const Primitive& other) const override; + + static Shape output_shape(const array& input, int start_axis, int end_axis); + auto state() const { + return std::make_pair(start_axis_, end_axis_); + } + + private: + int start_axis_; + int end_axis_; + void eval(const std::vector& inputs, array& out); +}; + +class Floor : public UnaryPrimitive { + public: + explicit Floor(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Floor) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Full : public UnaryPrimitive { + public: + explicit Full(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Full) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Gather : public UnaryPrimitive { + public: + explicit Gather(Stream stream, std::vector axes, Shape slice_sizes) + : UnaryPrimitive(stream), + axes_(std::move(axes)), + slice_sizes_(std::move(slice_sizes)) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Gather) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + std::pair, Shape> state() const { + return {axes_, slice_sizes_}; + } + + private: + std::vector axes_; + Shape slice_sizes_; +}; + +class GatherAxis : public UnaryPrimitive { + public: + explicit GatherAxis(Stream stream, int axis) + : UnaryPrimitive(stream), axis_(axis) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(GatherAxis) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return axis_; + } + + private: + int axis_; +}; + +class Greater : public UnaryPrimitive { + public: + explicit Greater(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Greater) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class GreaterEqual : public UnaryPrimitive { + public: + explicit GreaterEqual(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(GreaterEqual) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Hadamard : public UnaryPrimitive { + public: + explicit Hadamard(Stream stream, float scale) + : UnaryPrimitive(stream), scale_(scale) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Hadamard) + DEFINE_INPUT_OUTPUT_SHAPE() + + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return scale_; + } + + private: + float scale_; +}; + +class Imag : public UnaryPrimitive { + public: + explicit Imag(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Imag) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Less : public UnaryPrimitive { + public: + explicit Less(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Less) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class LessEqual : public UnaryPrimitive { + public: + explicit LessEqual(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(LessEqual) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Load : public UnaryPrimitive { + public: + explicit Load( + Stream stream, + std::shared_ptr reader, + size_t offset, + bool swap_endianness = false) + : UnaryPrimitive(stream), + reader_(std::move(reader)), + offset_(offset), + swap_endianness_(swap_endianness) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_NAME(Load) + + private: + std::shared_ptr reader_; + size_t offset_; + bool swap_endianness_; +}; + +class Log : public UnaryPrimitive { + public: + enum Base { two, ten, e }; + + explicit Log(Stream stream, Base base) + : UnaryPrimitive(stream), base_(base) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() + + Base state() const { + return base_; + }; + + const char* name() const override { + switch (base_) { + case e: + return "Log"; + case two: + return "Log2"; + case ten: + return "Log10"; + } + return ""; + } + + private: + Base base_; +}; + +class Log1p : public UnaryPrimitive { + public: + explicit Log1p(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Log1p) + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class LogicalNot : public UnaryPrimitive { + public: + explicit LogicalNot(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(LogicalNot) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class LogicalAnd : public UnaryPrimitive { + public: + explicit LogicalAnd(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(LogicalAnd) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class LogicalOr : public UnaryPrimitive { + public: + explicit LogicalOr(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(LogicalOr) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class LogAddExp : public UnaryPrimitive { + public: + explicit LogAddExp(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(LogAddExp) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class LogSumExp : public UnaryPrimitive { + public: + explicit LogSumExp(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(LogSumExp) + DEFINE_DEFAULT_IS_EQUIVALENT() + std::vector output_shapes(const std::vector& inputs) override; +}; + +class Matmul : public UnaryPrimitive { + public: + explicit Matmul(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_GRADS() + DEFINE_VMAP() + DEFINE_NAME(Matmul) + DEFINE_DEFAULT_IS_EQUIVALENT() + std::vector output_shapes(const std::vector& inputs) override; +}; + +class Maximum : public UnaryPrimitive { + public: + explicit Maximum(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Maximum) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Minimum : public UnaryPrimitive { + public: + explicit Minimum(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Minimum) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Multiply : public UnaryPrimitive { + public: + explicit Multiply(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Multiply) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Negative : public UnaryPrimitive { + public: + explicit Negative(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Negative) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class NotEqual : public UnaryPrimitive { + public: + explicit NotEqual(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(NotEqual) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class NumberOfElements : public UnaryPrimitive { + public: + explicit NumberOfElements( + Stream stream, + std::vector axes, + bool inverted, + Dtype dtype) + : UnaryPrimitive(stream), + axes_(std::move(axes)), + inverted_(inverted), + dtype_(dtype) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_NAME(NumberOfElements) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override { + return {{}}; + } + std::tuple, bool, Dtype> state() const { + return {axes_, inverted_, dtype_}; + } + + private: + std::vector axes_; + bool inverted_; + Dtype dtype_; + + void eval(const std::vector& inputs, array& out); +}; + +class Pad : public UnaryPrimitive { + public: + explicit Pad( + Stream stream, + const std::vector& axes, + const Shape& low_pad_size, + const Shape& high_pad_size) + : UnaryPrimitive(stream), + axes_(axes), + low_pad_size_(low_pad_size), + high_pad_size_(high_pad_size) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Pad) + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_tuple(axes_, low_pad_size_, high_pad_size_); + } + + private: + std::vector axes_; + Shape low_pad_size_; + Shape high_pad_size_; +}; + +class Partition : public UnaryPrimitive { + public: + explicit Partition(Stream stream, int kth, int axis) + : UnaryPrimitive(stream), kth_(kth), axis_(axis) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Partition) + DEFINE_INPUT_OUTPUT_SHAPE() + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_pair(kth_, axis_); + }; + + private: + int kth_; + int axis_; +}; + +class Power : public UnaryPrimitive { + public: + explicit Power(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Power) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class QuantizedMatmul : public UnaryPrimitive { + public: + explicit QuantizedMatmul( + Stream stream, + int group_size, + int bits, + QuantizationMode mode, + bool transpose) + : UnaryPrimitive(stream), + group_size_(group_size), + bits_(bits), + mode_(mode), + transpose_(transpose) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(QuantizedMatmul) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return std::make_tuple(group_size_, bits_, mode_, transpose_); + } + + private: + int group_size_; + int bits_; + QuantizationMode mode_; + bool transpose_; +}; + +class QQMatmul : public UnaryPrimitive { + public: + explicit QQMatmul( + Stream stream, + int group_size, + int bits, + QuantizationMode mode) + : UnaryPrimitive(stream), + group_size_(group_size), + bits_(bits), + mode_(mode) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + // DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(QQMatmul) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return std::make_tuple(group_size_, bits_, mode_); + } + + private: + int group_size_; + int bits_; + QuantizationMode mode_; +}; + +class GatherQMM : public UnaryPrimitive { + public: + explicit GatherQMM( + Stream stream, + int group_size, + int bits, + QuantizationMode mode, + bool transpose, + bool left_sorted = false, + bool right_sorted = false) + : UnaryPrimitive(stream), + group_size_(group_size), + bits_(bits), + mode_(mode), + transpose_(transpose), + left_sorted_(left_sorted), + right_sorted_(right_sorted) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(GatherQMM) + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_tuple( + group_size_, bits_, mode_, transpose_, left_sorted_, right_sorted_); + } + + private: + int group_size_; + int bits_; + QuantizationMode mode_; + bool transpose_; + bool left_sorted_; + bool right_sorted_; +}; + +class RandomBits : public UnaryPrimitive { + public: + explicit RandomBits(Stream stream, const Shape& shape, int width) + : UnaryPrimitive(stream), shape_(shape), width_(width) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_NAME(RandomBits) + bool is_equivalent(const Primitive& other) const override; + std::pair state() const { + return {shape_, width_}; + }; + + private: + Shape shape_; + int width_; +}; + +class Real : public UnaryPrimitive { + public: + explicit Real(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Real) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Reshape : public UnaryPrimitive { + public: + explicit Reshape(Stream stream, const Shape& shape) + : UnaryPrimitive(stream), shape_(shape) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Reshape) + bool is_equivalent(const Primitive& other) const override; + Shape state() const { + return shape_; + }; + static Shape output_shape(const array& input, Shape shape); + std::vector output_shapes(const std::vector& inputs) override; + + private: + Shape shape_; +}; + +class MLX_API Reduce : public UnaryPrimitive { + public: + enum ReduceType { And, Or, Sum, Prod, Min, Max }; + + explicit Reduce( + Stream stream, + ReduceType reduce_type, + const std::vector& axes) + : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS(); + + std::vector output_shapes(const std::vector& inputs) override; + + const char* name() const override { + switch (reduce_type_) { + case And: + return "And"; + case Or: + return "Or"; + case Sum: + return "Sum"; + case Prod: + return "Prod"; + case Min: + return "Min"; + case Max: + return "Max"; + } + return ""; + } + + bool is_equivalent(const Primitive& other) const override; + std::pair> state() const { + return {reduce_type_, axes_}; + }; + + private: + ReduceType reduce_type_; + std::vector axes_; +}; + +class Round : public UnaryPrimitive { + public: + explicit Round(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Round) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Scan : public UnaryPrimitive { + public: + enum ReduceType { Max, Min, Sum, Prod, LogAddExp }; + + explicit Scan( + Stream stream, + ReduceType reduce_type, + int axis, + bool reverse, + bool inclusive) + : UnaryPrimitive(stream), + reduce_type_(reduce_type), + axis_(axis), + reverse_(reverse), + inclusive_(inclusive) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS(); + + const char* name() const override { + switch (reduce_type_) { + case Sum: + return "CumSum"; + case Prod: + return "CumProd"; + case Min: + return "CumMin"; + case Max: + return "CumMax"; + case LogAddExp: + return "CumLogAddExp"; + } + return ""; + } + + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_tuple(reduce_type_, axis_, reverse_, inclusive_); + } + + private: + ReduceType reduce_type_; + int axis_; + bool reverse_; + bool inclusive_; +}; + +class Scatter : public UnaryPrimitive { + public: + enum ReduceType { Max, Min, Sum, Prod, None }; + + explicit Scatter( + Stream stream, + ReduceType reduce_type, + const std::vector& axes) + : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP(); + DEFINE_GRADS(); + + const char* name() const override { + switch (reduce_type_) { + case Sum: + return "Scatter Sum"; + case Prod: + return "Scatter Prod"; + case Min: + return "Scatter Min"; + case Max: + return "Scatter Max"; + case None: + return "Scatter"; + } + return ""; + } + + bool is_equivalent(const Primitive& other) const override; + std::pair> state() const { + return {reduce_type_, axes_}; + }; + + private: + ReduceType reduce_type_; + std::vector axes_; +}; + +class ScatterAxis : public UnaryPrimitive { + public: + enum ReduceType { Sum, None }; + + explicit ScatterAxis(Stream stream, ReduceType reduce_type, int axis) + : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + + const char* name() const override { + switch (reduce_type_) { + case Sum: + return "ScatterAxis Sum"; + case None: + return "ScatterAxis"; + } + return ""; + } + + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + std::pair state() const { + return {reduce_type_, axis_}; + } + + private: + ReduceType reduce_type_; + int axis_; +}; + +class MaskedScatter : public UnaryPrimitive { + public: + explicit MaskedScatter(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP(); + DEFINE_GRADS(); + DEFINE_NAME(MaskedScatter); + DEFINE_DEFAULT_IS_EQUIVALENT(); + DEFINE_INPUT_OUTPUT_SHAPE(); +}; + +class Sigmoid : public UnaryPrimitive { + public: + explicit Sigmoid(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Sigmoid) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Sign : public UnaryPrimitive { + public: + explicit Sign(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Sign) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Sin : public UnaryPrimitive { + public: + explicit Sin(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Sin) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Sinh : public UnaryPrimitive { + public: + explicit Sinh(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Sinh) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Slice : public UnaryPrimitive { + public: + explicit Slice( + Stream stream, + const Shape& start_indices, + const Shape& end_indices, + const Shape& strides) + : UnaryPrimitive(stream), + start_indices_(start_indices), + end_indices_(end_indices), + strides_(strides) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Slice) + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_tuple(start_indices_, end_indices_, strides_); + } + + private: + Shape start_indices_; + Shape end_indices_; + Shape strides_; +}; + +class SliceUpdate : public UnaryPrimitive { + public: + explicit SliceUpdate( + Stream stream, + const Shape& start_indices, + const Shape& end_indices, + const Shape& strides) + : UnaryPrimitive(stream), + start_indices_(start_indices), + end_indices_(end_indices), + strides_(strides) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(SliceUpdate) + bool is_equivalent(const Primitive& other) const override; + DEFINE_INPUT_OUTPUT_SHAPE() + auto state() const { + return std::make_tuple(start_indices_, end_indices_, strides_); + } + + private: + Shape start_indices_; + Shape end_indices_; + Shape strides_; +}; + +class DynamicSlice : public UnaryPrimitive { + public: + explicit DynamicSlice(Stream stream, std::vector axes, Shape slice_size) + : UnaryPrimitive(stream), + axes_(std::move(axes)), + slice_size_(std::move(slice_size)) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(DynamicSlice) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return std::make_pair(axes_, slice_size_); + } + + private: + std::vector axes_; + Shape slice_size_; +}; + +class DynamicSliceUpdate : public UnaryPrimitive { + public: + explicit DynamicSliceUpdate(Stream stream, std::vector axes) + : UnaryPrimitive(stream), axes_(std::move(axes)) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(DynamicSliceUpdate) + bool is_equivalent(const Primitive& other) const override; + DEFINE_INPUT_OUTPUT_SHAPE() + auto state() const { + return axes_; + } + + private: + std::vector axes_; +}; + +class Softmax : public UnaryPrimitive { + public: + explicit Softmax(Stream stream, bool precise) + : UnaryPrimitive(stream), precise_(precise) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Softmax) + DEFINE_INPUT_OUTPUT_SHAPE() + + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return precise_; + }; + + private: + bool precise_; +}; + +class Sort : public UnaryPrimitive { + public: + explicit Sort(Stream stream, int axis) + : UnaryPrimitive(stream), axis_(axis) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Sort) + DEFINE_INPUT_OUTPUT_SHAPE() + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return axis_; + } + + private: + int axis_; +}; + +class Split : public Primitive { + public: + explicit Split(Stream stream, const Shape& indices, int axis) + : Primitive(stream), indices_(indices), axis_(axis) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Split) + bool is_equivalent(const Primitive& other) const override; + std::pair state() const { + return {indices_, axis_}; + }; + + private: + void eval(const std::vector& inputs, std::vector& outputs); + + Shape indices_; + int axis_; +}; + +class Square : public UnaryPrimitive { + public: + explicit Square(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Square) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Sqrt : public UnaryPrimitive { + public: + explicit Sqrt(Stream stream, bool recip = false) + : UnaryPrimitive(stream), recip_(recip) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_INPUT_OUTPUT_SHAPE() + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return recip_; + } + + const char* name() const override { + if (recip_) { + return "Rsqrt"; + } else { + return "Sqrt"; + } + } + + private: + bool recip_; +}; + +class StopGradient : public UnaryPrimitive { + public: + explicit StopGradient(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_NAME(StopGradient) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Subtract : public UnaryPrimitive { + public: + explicit Subtract(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Subtract) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Squeeze : public UnaryPrimitive { + public: + explicit Squeeze(Stream stream, std::vector axes) + : UnaryPrimitive(stream), axes_(std::move(axes)) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Squeeze) + + std::vector output_shapes(const std::vector& inputs) override; + bool is_equivalent(const Primitive& other) const override; + + static Shape output_shape(const array& input, const std::vector& axes); + auto state() const { + return axes_; + }; + + private: + void eval(const std::vector& inputs, array& out); + std::vector axes_; +}; + +class Tan : public UnaryPrimitive { + public: + explicit Tan(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Tan) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Tanh : public UnaryPrimitive { + public: + explicit Tanh(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Tanh) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Unflatten : public UnaryPrimitive { + public: + explicit Unflatten(Stream stream, int axis, Shape shape) + : UnaryPrimitive(stream), axis_(axis), shape_(std::move(shape)) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Unflatten) + std::vector output_shapes(const std::vector& inputs) override; + bool is_equivalent(const Primitive& other) const override; + + static Shape output_shape(const array& input, int axis, const Shape& shape); + auto state() const { + return std::make_pair(axis_, shape_); + } + + private: + int axis_; + Shape shape_; + void eval(const std::vector& inputs, array& out); +}; + +class View : public UnaryPrimitive { + public: + explicit View(Stream stream, Dtype dtype) + : UnaryPrimitive(stream), dtype_(dtype) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + const char* name() const override; + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return dtype_; + } + + private: + Dtype dtype_; + mutable std::string name_; +}; + +class Transpose : public UnaryPrimitive { + public: + explicit Transpose(Stream stream, const std::vector& axes) + : UnaryPrimitive(stream), axes_(axes) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Transpose) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + std::vector state() const { + return axes_; + }; + + private: + std::vector axes_; + + void eval(const std::vector& inputs, array& out); +}; + +/* QR Factorization primitive. */ +class QRF : public Primitive { + public: + explicit QRF(Stream stream) : Primitive(stream) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_NAME(QRF) +}; + +/* SVD primitive. */ +class SVD : public Primitive { + public: + explicit SVD(Stream stream, bool compute_uv) + : Primitive(stream), compute_uv_(compute_uv) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_VMAP() + DEFINE_NAME(SVD) + auto state() const { + return compute_uv_; + } + + private: + bool compute_uv_; +}; + +/* Matrix inversion primitive. */ +class Inverse : public UnaryPrimitive { + public: + explicit Inverse(Stream stream, bool tri, bool upper) + : UnaryPrimitive(stream), tri_(tri), upper_(upper) {} + + void eval_cpu(const std::vector& inputs, array& output) override; + void eval_gpu(const std::vector& inputs, array& output) override; + + DEFINE_VMAP() + DEFINE_NAME(Inverse) + auto state() const { + return std::make_pair(tri_, upper_); + } + + private: + bool tri_; + bool upper_; +}; + +class Cholesky : public UnaryPrimitive { + public: + explicit Cholesky(Stream stream, bool upper) + : UnaryPrimitive(stream), upper_(upper) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + auto state() const { + return upper_; + } + + DEFINE_VMAP() + DEFINE_NAME(Cholesky) + + private: + bool upper_; +}; + +class Eig : public Primitive { + public: + explicit Eig(Stream stream, bool compute_eigenvectors) + : Primitive(stream), compute_eigenvectors_(compute_eigenvectors) {} + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_VMAP() + DEFINE_NAME(Eig) + + std::vector output_shapes(const std::vector& inputs) override; + + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return compute_eigenvectors_; + } + + private: + bool compute_eigenvectors_; +}; + +class Eigh : public Primitive { + public: + explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors) + : Primitive(stream), + uplo_(std::move(uplo)), + compute_eigenvectors_(compute_eigenvectors) {} + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_VMAP() + DEFINE_NAME(Eigh) + + std::vector output_shapes(const std::vector& inputs) override; + + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_pair(uplo_, compute_eigenvectors_); + } + + private: + std::string uplo_; + bool compute_eigenvectors_; +}; + +/* LU Factorization primitive. */ +class LUF : public Primitive { + public: + explicit LUF(Stream stream) : Primitive(stream) {} + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_NAME(LUF) +}; + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/random.h b/Source/Cxxmlx/include/mlx/random.h new file mode 100644 index 00000000..a23c2557 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/random.h @@ -0,0 +1,283 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/api.h" +#include "mlx/array.h" +#include "mlx/stream.h" +#include "mlx/utils.h" + +namespace mlx::core::random { + +class MLX_API KeySequence { + public: + explicit KeySequence(uint64_t seed); + + void seed(uint64_t seed); + array next(); + + // static default + static KeySequence& default_() { + static KeySequence ks(get_current_time_seed()); + return ks; + } + + private: + array key_; + static uint64_t get_current_time_seed() { + auto now = std::chrono::system_clock::now(); + return std::chrono::duration_cast( + now.time_since_epoch()) + .count(); + } +}; + +/** Get a PRNG key from a seed. */ +MLX_API array key(uint64_t seed); + +/** Seed the default PRNG key. */ +MLX_API void seed(uint64_t seed); + +/** Generate an array with type uint32 filled with random bits. */ +MLX_API array bits( + const Shape& shape, + int width, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); +inline array bits( + const Shape& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return bits(shape, 4, key, s); +} + +/** Split the rng key into a pair of keys. */ +MLX_API std::pair split(const array& key, StreamOrDevice s = {}); + +/** Split the rng key into `num` keys. */ +MLX_API array split(const array& key, int num, StreamOrDevice s = {}); + +/** Generate uniform random numbers between low and high. */ +MLX_API array uniform( + const array& low, + const array& high, + const Shape& shape, + Dtype dtype = float32, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +template +array uniform( + T low, + U high, + const Shape& shape, + Dtype dtype = float32, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return uniform(array(low), array(high), shape, dtype, key, to_stream(s)); +} + +/** Generate uniform random numbers between 0 and 1. */ +MLX_API array uniform( + const Shape& shape, + Dtype dtype, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); +inline array uniform( + const Shape& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return uniform(shape, float32, key, s); +} + +/** Generate samples from the standard normal distribution. */ +MLX_API array normal( + const Shape& shape, + Dtype dtype, + const std::optional& loc, + const std::optional& scale, + const std::optional& key, + StreamOrDevice s = {}); +inline array normal( + const Shape& shape, + Dtype dtype, + const float loc, + const float scale, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + auto loc_ = loc == 0 ? std::nullopt : std::make_optional(array(loc, dtype)); + auto scale_ = + scale == 1 ? std::nullopt : std::make_optional(array(scale, dtype)); + return normal(shape, dtype, loc_, scale_, key, s); +} +inline array normal( + const Shape& shape, + const float loc, + const float scale, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return normal(shape, float32, loc, scale, key, s); +} +inline array normal( + const Shape& shape, + const Dtype dtype, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return normal(shape, dtype, std::nullopt, std::nullopt, key, s); +} +inline array normal( + const Shape& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return normal(shape, float32, std::nullopt, std::nullopt, key, s); +} + +/** Generate samples from a multivariate normal distribution. **/ +MLX_API array multivariate_normal( + const array& mean, + const array& cov, + const Shape& shape, + Dtype dtype, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +/** Generate integer samples uniformly at random */ +MLX_API array randint( + const array& low, + const array& high, + const Shape& shape, + Dtype dtype = int32, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +template +array randint( + T low, + U high, + const Shape& shape, + Dtype dtype = int32, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return randint(array(low), array(high), shape, dtype, key, to_stream(s)); +} + +/** Generate binary variables with probability to be true equal to p */ +MLX_API array bernoulli( + const array& p, + const Shape& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); +MLX_API array bernoulli( + const array& p, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +template +array bernoulli( + T p, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return bernoulli(array(p), key, s); +} + +template +array bernoulli( + T p, + const Shape& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return bernoulli(array(p), shape, key, s); +} + +MLX_API array bernoulli( + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +MLX_API array truncated_normal( + const array& lower, + const array& upper, + const Shape& shape, + Dtype dtype = float32, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +MLX_API array truncated_normal( + const array& lower, + const array& upper, + Dtype dtype = float32, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +MLX_API array gumbel( + const Shape& shape, + Dtype dtype = float32, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +MLX_API array categorical( + const array& logits, + int axis, + const Shape& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +MLX_API array categorical( + const array& logits_, + int axis, + int num_samples, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +MLX_API array categorical( + const array& logits, + int axis = -1, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +/** Generate samples from the laplace distribution. */ +MLX_API array laplace( + const Shape& shape, + Dtype dtype, + const float loc, + const float scale, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); +inline array laplace( + const Shape& shape, + const float loc, + const float scale, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return laplace(shape, float32, loc, scale, key, s); +} +inline array laplace( + const Shape& shape, + const Dtype dtype, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return laplace(shape, dtype, 0.0, 1.0, key, s); +} +inline array laplace( + const Shape& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return laplace(shape, float32, 0.0, 1.0, key, s); +} + +/* Randomly permute the elements of x along the given axis. */ +MLX_API array permutation( + const array& x, + int axis = 0, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +/* A random permutation of `arange(x)` */ +MLX_API array permutation( + int x, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +} // namespace mlx::core::random diff --git a/Source/Cxxmlx/include/mlx/scheduler.h b/Source/Cxxmlx/include/mlx/scheduler.h new file mode 100644 index 00000000..c94044a7 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/scheduler.h @@ -0,0 +1,192 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include +#include +#include +#include +#include + +#include "mlx/api.h" +#include "mlx/backend/gpu/eval.h" +#include "mlx/device.h" +#include "mlx/stream.h" + +namespace mlx::core::scheduler { + +struct StreamThread { + std::mutex mtx; + std::queue> q; + std::condition_variable cond; + bool stop; + std::thread thread; + + StreamThread() : stop(false), thread(&StreamThread::thread_fn, this) {} + + ~StreamThread() { + { + std::lock_guard lk(mtx); + stop = true; + } + cond.notify_one(); + thread.join(); + } + + void thread_fn() { + while (true) { + std::function task; + { + std::unique_lock lk(mtx); + cond.wait(lk, [this] { return !this->q.empty() || this->stop; }); + if (q.empty() && stop) { + return; + } + task = std::move(q.front()); + q.pop(); + } + + task(); + } + } + + template + void enqueue(F&& f) { + { + std::lock_guard lk(mtx); + if (stop) { + throw std::runtime_error( + "Cannot enqueue work after stream is stopped."); + } + q.emplace(std::forward(f)); + } + cond.notify_one(); + } +}; + +class Scheduler { + public: + Scheduler() : n_active_tasks_(0) { + if (is_available(Device::gpu)) { + default_streams_.insert({Device::gpu, new_stream(Device::gpu)}); + } + default_streams_.insert({Device::cpu, new_stream(Device::cpu)}); + } + + // Not copyable or moveable + Scheduler(const Scheduler&) = delete; + Scheduler(Scheduler&&) = delete; + Scheduler& operator=(const Scheduler&) = delete; + Scheduler& operator=(Scheduler&&) = delete; + + Stream new_stream(const Device& d) { + streams_.emplace_back(streams_.size(), d); + if (d == Device::gpu) { + threads_.push_back(nullptr); + gpu::new_stream(streams_.back()); + } else { + threads_.push_back(new StreamThread{}); + } + return streams_.back(); + } + + template + void enqueue(const Stream& stream, F&& f); + + Stream get_default_stream(const Device& d) const { + return default_streams_.at(d.type); + } + Stream get_stream(int index) const { + return streams_.at(index); + } + std::vector get_streams() const { + return streams_; + } + + void set_default_stream(const Stream& s) { + default_streams_.at(s.device.type) = s; + } + + void notify_new_task(const Stream& stream) { + { + std::lock_guard lk(mtx); + n_active_tasks_++; + } + completion_cv.notify_all(); + } + + void notify_task_completion(const Stream& stream) { + { + std::lock_guard lk(mtx); + n_active_tasks_--; + } + completion_cv.notify_all(); + } + + int n_active_tasks() const { + return n_active_tasks_; + } + + void wait_for_one() { + std::unique_lock lk(mtx); + int n_tasks_old = n_active_tasks(); + if (n_tasks_old > 1) { + completion_cv.wait(lk, [this, n_tasks_old] { + return this->n_active_tasks() < n_tasks_old; + }); + } + } + + ~Scheduler() { + for (auto s : streams_) { + try { + synchronize(s); + } catch (const std::runtime_error&) { + // ignore errors if synch fails + } + } + for (auto t : threads_) { + if (t != nullptr) { + delete t; + } + } + } + + private: + int n_active_tasks_; + std::vector threads_; + std::vector streams_; + std::unordered_map default_streams_; + std::condition_variable completion_cv; + std::mutex mtx; +}; + +template +void Scheduler::enqueue(const Stream& stream, F&& f) { + threads_[stream.index]->enqueue(std::forward(f)); +} + +MLX_API Scheduler& scheduler(); + +template +void enqueue(const Stream& stream, F&& f) { + scheduler().enqueue(stream, std::forward(f)); +} + +inline int n_active_tasks() { + return scheduler().n_active_tasks(); +} + +inline void notify_new_task(const Stream& stream) { + scheduler().notify_new_task(stream); +} + +inline void notify_task_completion(const Stream& stream) { + scheduler().notify_task_completion(stream); +} + +inline void wait_for_one() { + scheduler().wait_for_one(); +} + +} // namespace mlx::core::scheduler diff --git a/Source/Cxxmlx/include/mlx/small_vector.h b/Source/Cxxmlx/include/mlx/small_vector.h new file mode 100644 index 00000000..143101c8 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/small_vector.h @@ -0,0 +1,540 @@ +// Copyright © 2025 Apple Inc. +// Copyright © 2018 the V8 project authors. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once + +#include +#include +#include +#include + +namespace mlx::core { + +#if defined(__has_builtin) +#define MLX_HAS_BUILTIN(x) __has_builtin(x) +#else +#define MLX_HAS_BUILTIN(x) 0 +#endif + +#if defined(__has_attribute) +#define MLX_HAS_ATTRIBUTE(x) __has_attribute(x) +#else +#define MLX_HAS_ATTRIBUTE(x) 0 +#endif + +#if MLX_HAS_BUILTIN(__builtin_expect) +#define MLX_LIKELY(condition) (__builtin_expect(!!(condition), 1)) +#define MLX_UNLIKELY(condition) (__builtin_expect(!!(condition), 0)) +#else +#define MLX_LIKELY(condition) (condition) +#define MLX_UNLIKELY(condition) (condition) +#endif + +#if MLX_HAS_ATTRIBUTE(noinline) +#define MLX_NOINLINE __attribute__((noinline)) +#else +#define MLX_NOINLINE +#endif + +template +struct is_iterator : std::false_type {}; + +template +struct is_iterator< + T, + std::void_t< + typename std::iterator_traits::difference_type, + typename std::iterator_traits::iterator_category, + typename std::iterator_traits::pointer, + typename std::iterator_traits::reference, + typename std::iterator_traits::value_type>> : std::true_type {}; + +template +constexpr bool is_iterator_v = is_iterator::value; + +// Minimal SmallVector implementation. Uses inline storage first, switches to +// dynamic storage when it overflows. +// +// Notes: +// * The default inline storage size is MAX_NDIM, as it is mainly used for +// shapes and strides, users should choose a better size for other cases. +// * The data() returns real address even for empty vector. +// * The pointer returned by data() will change after moving the vector as it +// points to the inline storage. +// * For trivial elements the storage will not be default constructed, +// i.e. SmallVector(10) will not be filled with 0 by default. +template > +class SmallVector { + public: + using value_type = T; + using reference = T&; + using const_reference = const T&; + using iterator = T*; + using const_iterator = const T*; + using difference_type = std::ptrdiff_t; + using size_type = std::size_t; + + SmallVector() = default; + + explicit SmallVector(const Allocator& allocator) : allocator_(allocator) {} + + explicit SmallVector(size_t size, const Allocator& allocator = Allocator()) + : allocator_(allocator) { + resize(size); + } + + SmallVector( + size_t size, + const T& initial_value, + const Allocator& allocator = Allocator()) + : allocator_(allocator) { + resize(size, initial_value); + } + + SmallVector( + std::initializer_list init, + const Allocator& allocator = Allocator()) + : allocator_(allocator) { + if (init.size() > capacity()) { + grow(init.size()); + } + assert(capacity() >= init.size()); // sanity check + std::uninitialized_move(init.begin(), init.end(), begin_); + end_ = begin_ + init.size(); + } + + template >> + SmallVector(Iter begin, Iter end, const Allocator& allocator = Allocator()) + : allocator_(allocator) { + size_t size = std::distance(begin, end); + if (size > capacity()) { + grow(size); + } + assert(capacity() >= size); // sanity check + std::uninitialized_copy(begin, end, begin_); + end_ = begin_ + size; + } + + SmallVector(const SmallVector& other) : allocator_(other.allocator_) { + *this = other; + } + SmallVector(const SmallVector& other, const Allocator& allocator) + : allocator_(allocator) { + *this = other; + } + SmallVector(SmallVector&& other) : allocator_(std::move(other.allocator_)) { + *this = std::move(other); + } + SmallVector(SmallVector&& other, const Allocator& allocator) + : allocator_(allocator) { + *this = std::move(other); + } + + ~SmallVector() { + free_storage(); + } + + SmallVector& operator=(const SmallVector& other) { + if (this == &other) { + return *this; + } + size_t other_size = other.size(); + if (capacity() < other_size) { + // Create large-enough heap-allocated storage. + free_storage(); + begin_ = allocator_.allocate(other_size); + end_of_storage_ = begin_ + other_size; + std::uninitialized_copy(other.begin_, other.end_, begin_); + } else if constexpr (kHasTrivialElement) { + std::copy(other.begin_, other.end_, begin_); + } else { + ptrdiff_t to_copy = + std::min(static_cast(other_size), end_ - begin_); + std::copy(other.begin_, other.begin_ + to_copy, begin_); + if (other.begin_ + to_copy < other.end_) { + std::uninitialized_copy( + other.begin_ + to_copy, other.end_, begin_ + to_copy); + } else { + std::destroy_n(begin_ + to_copy, size() - to_copy); + } + } + end_ = begin_ + other_size; + return *this; + } + + SmallVector& operator=(SmallVector&& other) { + if (this == &other) { + return *this; + } + if (other.is_big()) { + free_storage(); + begin_ = other.begin_; + end_ = other.end_; + end_of_storage_ = other.end_of_storage_; + } else { + assert(capacity() >= other.size()); // sanity check + size_t other_size = other.size(); + if constexpr (kHasTrivialElement) { + std::move(other.begin_, other.end_, begin_); + } else { + ptrdiff_t to_move = + std::min(static_cast(other_size), end_ - begin_); + std::move(other.begin_, other.begin_ + to_move, begin_); + if (other.begin_ + to_move < other.end_) { + std::uninitialized_move( + other.begin_ + to_move, other.end_, begin_ + to_move); + } else { + std::destroy_n(begin_ + to_move, size() - to_move); + } + } + end_ = begin_ + other_size; + } + other.reset_to_inline_storage(); + return *this; + } + + bool operator==(const SmallVector& other) const { + if (size() != other.size()) { + return false; + } + return std::equal(begin_, end_, other.begin_); + } + + bool operator!=(const SmallVector& other) const { + return !(*this == other); + } + + T* data() { + return begin_; + } + const T* data() const { + return begin_; + } + + iterator begin() { + return begin_; + } + const_iterator begin() const { + return begin_; + } + + iterator end() { + return end_; + } + const_iterator end() const { + return end_; + } + + const_iterator cbegin() const { + return begin_; + } + + const_iterator cend() const { + return end_; + } + + auto rbegin() { + return std::make_reverse_iterator(end_); + } + auto rbegin() const { + return std::make_reverse_iterator(end_); + } + + auto rend() { + return std::make_reverse_iterator(begin_); + } + auto rend() const { + return std::make_reverse_iterator(begin_); + } + + size_t size() const { + return end_ - begin_; + } + bool empty() const { + return end_ == begin_; + } + size_t capacity() const { + return end_of_storage_ - begin_; + } + + T& front() { + assert(size() != 0); + return begin_[0]; + } + const T& front() const { + assert(size() != 0); + return begin_[0]; + } + + T& back() { + assert(size() != 0); + return end_[-1]; + } + const T& back() const { + assert(size() != 0); + return end_[-1]; + } + + T& at(size_t index) { + if (index >= size()) { + throw std::out_of_range("SmallVector out of range."); + } + return begin_[index]; + } + const T& at(size_t index) const { + return const_cast(this)->at(index); + } + + T& operator[](size_t index) { + assert(size() > index); + return begin_[index]; + } + const T& operator[](size_t index) const { + return const_cast(this)->operator[](index); + } + + template + void emplace_back(Args&&... args) { + if (MLX_UNLIKELY(end_ == end_of_storage_)) { + grow(); + } + void* storage = end_; + end_ += 1; + new (storage) T(std::forward(args)...); + } + + void push_back(T x) { + emplace_back(std::move(x)); + } + + void pop_back(size_t count = 1) { + assert(size() >= count); + end_ -= count; + std::destroy_n(end_, count); + } + + iterator insert(iterator pos, T value) { + return insert(pos, static_cast(1), std::move(value)); + } + + iterator insert(iterator pos, size_t count, T value) { + assert(pos <= end_); + size_t offset = pos - begin_; + size_t old_size = size(); + resize(old_size + count); + pos = begin_ + offset; + iterator old_end = begin_ + old_size; + assert(old_end <= end_); + std::move_backward(pos, old_end, end_); + if constexpr (kHasTrivialElement) { + std::fill_n(pos, count, value); + } else { + std::fill_n(pos + 1, count - 1, value); + *pos = std::move(value); + } + return pos; + } + + template >> + iterator insert(iterator pos, Iter begin, Iter end) { + if constexpr (std::is_same_v, iterator>) { + // The implementation can not take overlapping range. + assert(!(begin >= pos && begin < pos + std::distance(begin, end))); + assert(!(end > pos && end <= pos + std::distance(begin, end))); + } + + assert(pos <= end_); + size_t offset = pos - begin_; + size_t count = std::distance(begin, end); + size_t old_size = size(); + resize(old_size + count); + pos = begin_ + offset; + iterator old_end = begin_ + old_size; + assert(old_end <= end_); + std::move_backward(pos, old_end, end_); + std::copy(begin, end, pos); + return pos; + } + + iterator insert(iterator pos, std::initializer_list values) { + return insert(pos, values.begin(), values.end()); + } + + iterator erase(iterator erase_start, iterator erase_end) { + assert(erase_start >= begin_); + assert(erase_start <= erase_end); + assert(erase_end <= end_); + iterator new_end = std::move(erase_end, end_, erase_start); + std::destroy_n(new_end, std::distance(new_end, end_)); + end_ = new_end; + return erase_start; + } + + iterator erase(iterator pos) { + return erase(pos, pos + 1); + } + + void resize(size_t new_size) { + if (new_size > capacity()) { + grow(new_size); + } + T* new_end = begin_ + new_size; + if constexpr (!kHasTrivialElement) { + if (new_end > end_) { + std::uninitialized_default_construct(end_, new_end); + } else { + std::destroy_n(new_end, end_ - new_end); + } + } + end_ = new_end; + } + + void resize(size_t new_size, const T& initial_value) { + if (new_size > capacity()) { + grow(new_size); + } + T* new_end = begin_ + new_size; + if (new_end > end_) { + std::uninitialized_fill(end_, new_end, initial_value); + } else { + std::destroy_n(new_end, end_ - new_end); + } + end_ = new_end; + } + + void reserve(size_t new_capacity) { + if (new_capacity > capacity()) { + grow(new_capacity); + } + } + + // Clear without reverting back to inline storage. + void clear() { + std::destroy_n(begin_, end_ - begin_); + end_ = begin_; + } + + private: + // Grows the backing store by a factor of two, and at least to {min_capacity}. + // TODO: Move to private after removing external code using this method. + MLX_NOINLINE void grow(size_t min_capacity = 0) { + size_t new_capacity = std::max(min_capacity, 2 * capacity()); + // Round up to power of 2. + new_capacity--; + new_capacity |= new_capacity >> 1; + new_capacity |= new_capacity >> 2; + new_capacity |= new_capacity >> 4; + new_capacity |= new_capacity >> 8; + new_capacity |= new_capacity >> 16; + if constexpr (sizeof(size_t) == sizeof(uint64_t)) { + new_capacity |= new_capacity >> 32; + } + new_capacity++; + + T* new_storage = allocator_.allocate(new_capacity); + if (new_storage == nullptr) { + throw std::bad_alloc(); + } + + size_t in_use = end_ - begin_; + std::uninitialized_move(begin_, end_, new_storage); + free_storage(); + begin_ = new_storage; + end_ = new_storage + in_use; + end_of_storage_ = new_storage + new_capacity; + } + + MLX_NOINLINE void free_storage() { + std::destroy_n(begin_, end_ - begin_); + if (is_big()) { + allocator_.deallocate(begin_, end_of_storage_ - begin_); + } + } + + // Clear and go back to inline storage. Dynamic storage is *not* freed. For + // internal use only. + void reset_to_inline_storage() { + if constexpr (!kHasTrivialElement) { + if (!is_big()) + std::destroy_n(begin_, end_ - begin_); + } + begin_ = inline_storage_begin(); + end_ = begin_; + end_of_storage_ = begin_ + kSize; + } + + bool is_big() const { + return begin_ != inline_storage_begin(); + } + + T* inline_storage_begin() { + return reinterpret_cast(inline_storage_); + } + const T* inline_storage_begin() const { + return reinterpret_cast(inline_storage_); + } + + Allocator allocator_; + + // Invariants: + // 1. The elements in the range between `begin_` (included) and `end_` (not + // included) will be initialized at all times. + // 2. All other elements outside the range, both in the inline storage and in + // the dynamic storage (if it exists), will be uninitialized at all times. + + T* begin_ = inline_storage_begin(); + T* end_ = begin_; + T* end_of_storage_ = begin_ + kSize; + + alignas(T) char inline_storage_[sizeof(T) * kSize]; + + static constexpr bool kHasTrivialElement = + std::is_trivially_copyable::value && + std::is_trivially_destructible::value; +}; + +template +struct is_vector : std::false_type {}; + +template +struct is_vector> : std::true_type {}; + +template +struct is_vector> : std::true_type {}; + +template +inline constexpr bool is_vector_v = is_vector::value; + +#undef MLX_HAS_BUILTIN +#undef MLX_HAS_ATTRIBUTE +#undef MLX_LIKELY +#undef MLX_UNLIKELY +#undef MLX_NOINLINE + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/stream.h b/Source/Cxxmlx/include/mlx/stream.h new file mode 100644 index 00000000..efe0ef1a --- /dev/null +++ b/Source/Cxxmlx/include/mlx/stream.h @@ -0,0 +1,47 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +#include "mlx/api.h" +#include "mlx/device.h" + +namespace mlx::core { + +struct MLX_API Stream { + int index; + Device device; + explicit Stream(int index, Device device) : index(index), device(device) {} +}; + +/** Get the default stream for the given device. */ +MLX_API Stream default_stream(Device d); + +/** Make the stream the default for its device. */ +MLX_API void set_default_stream(Stream s); + +/** Make a new stream on the given device. */ +MLX_API Stream new_stream(Device d); + +/** Get the stream with the given index. */ +MLX_API Stream get_stream(int index); + +/** Get all available streams. */ +MLX_API std::vector get_streams(); + +inline bool operator==(const Stream& lhs, const Stream& rhs) { + return lhs.index == rhs.index; +} + +inline bool operator!=(const Stream& lhs, const Stream& rhs) { + return !(lhs == rhs); +} + +/* Synchronize with the default stream. */ +MLX_API void synchronize(); + +/* Synchronize with the provided stream. */ +MLX_API void synchronize(Stream); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/threadpool.h b/Source/Cxxmlx/include/mlx/threadpool.h new file mode 100644 index 00000000..b0e56d0f --- /dev/null +++ b/Source/Cxxmlx/include/mlx/threadpool.h @@ -0,0 +1,133 @@ +// This code was modified from https://github.com/progschj/ThreadPool +// The original License is copied below: +// +// Copyright (c) 2012 Jakob Progsch, Václav Zeman +// This software is provided 'as-is', without any express or implied +// warranty. In no event will the authors be held liable for any damages +// arising from the use of this software. +// +// Permission is granted to anyone to use this software for any purpose, +// including commercial applications, and to alter it and redistribute it +// freely, subject to the following restrictions: +// +// 1. The origin of this software must not be misrepresented; you must not +// claim that you wrote the original software. If you use this software +// in a product, an acknowledgment in the product documentation would be +// appreciated but is not required. +// +// 2. Altered source versions must be plainly marked as such, and must not be +// misrepresented as being the original software. +// +// 3. This notice may not be removed or altered from any source +// distribution. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { + public: + ThreadPool(size_t); + template + auto enqueue(F&& f, Args&&... args) + -> std::future>; + void resize(size_t); + ~ThreadPool(); + + private: + void stop_and_wait(); + void start_threads(size_t); + + std::vector workers; + std::queue> tasks; + std::mutex queue_mutex; + std::condition_variable condition; + bool stop; +}; + +inline ThreadPool::ThreadPool(size_t threads) : stop(false) { + start_threads(threads); +} + +template +auto ThreadPool::enqueue(F&& f, Args&&... args) + -> std::future> { + using return_type = typename std::invoke_result_t; + + auto task = std::make_shared>( + std::bind(std::forward(f), std::forward(args)...)); + + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex); + + if (stop) { + throw std::runtime_error( + "[ThreadPool::enqueue] Not allowed on stopped ThreadPool"); + } + + tasks.emplace([task]() { (*task)(); }); + } + condition.notify_one(); + return res; +} + +inline void ThreadPool::resize(size_t threads) { + if (workers.size() == threads) { + return; + } + + if (workers.size() > threads) { + stop_and_wait(); + } + start_threads(threads - workers.size()); +} + +inline ThreadPool::~ThreadPool() { + stop_and_wait(); +} + +inline void ThreadPool::stop_and_wait() { + // Stop the current threads and wait until they finish + { + std::unique_lock lock(queue_mutex); + stop = true; + } + condition.notify_all(); + for (std::thread& worker : workers) { + worker.join(); + } + + // Reset the member variables so that the threadpool is reusable + stop = false; + workers.clear(); +} + +inline void ThreadPool::start_threads(size_t threads) { + for (size_t i = 0; i < threads; ++i) { + workers.emplace_back([this] { + for (;;) { + std::function task; + + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait( + lock, [this] { return this->stop || !this->tasks.empty(); }); + if (this->stop && this->tasks.empty()) + return; + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + + task(); + } + }); + } +} diff --git a/Source/Cxxmlx/include/mlx/transforms.h b/Source/Cxxmlx/include/mlx/transforms.h new file mode 100644 index 00000000..1848be79 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/transforms.h @@ -0,0 +1,231 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include + +#include "mlx/api.h" +#include "mlx/array.h" + +namespace mlx::core { + +MLX_API void async_eval(std::vector outputs); + +template > +void async_eval(Arrays&&... outputs) { + async_eval(std::vector{std::forward(outputs)...}); +} + +MLX_API void eval(std::vector outputs); + +template > +void eval(Arrays&&... outputs) { + eval(std::vector{std::forward(outputs)...}); +} + +/** + * Computes the output and vector-Jacobian product (VJP) of a function. + * + * Computes the vector-Jacobian product of the vector of cotangents with the + * Jacobian of the function evaluated at the primals. Returns a pair of + * vectors of output arrays and VJP arrays. + **/ +MLX_API std::pair, std::vector> vjp( + const std::function(const std::vector&)>& fun, + const std::vector& primals, + const std::vector& cotangents); + +/** + * Computes the output and vector-Jacobian product (VJP) of a unary function. + */ +MLX_API std::pair vjp( + const std::function& fun, + const array& primal, + const array& cotangent); + +/** + * Computes the output and Jacobian-vector product (JVP) of a function. + * + * Computes the Jacobian-vector product of the Jacobian of the function + * evaluated at the primals with the vector of tangents. Returns a pair of + * vectors of output arrays and JVP arrays. + **/ +MLX_API std::pair, std::vector> jvp( + const std::function(const std::vector&)>& fun, + const std::vector& primals, + const std::vector& tangents); + +/** + * Computes the output and Jacobian-vector product (JVP) of a unary function. + */ +MLX_API std::pair jvp( + const std::function& fun, + const array& primal, + const array& tangent); + +// Return type of general value_and_grad: a function which takes an input +// vector of arrays and returns a pair of vectors of arrays one for the +// values and one for the gradients wrt the first value. +using ValueAndGradFn = + std::function, std::vector>( + const std::vector&)>; +using SimpleValueAndGradFn = std::function>( + const std::vector&)>; + +/** + * Returns a function which computes the value and gradient of the input + * function with respect to a vector of input arrays. + **/ +MLX_API ValueAndGradFn value_and_grad( + const std::function(const std::vector&)>& fun, + const std::vector& argnums); + +/** + * Returns a function which computes the value and gradient of the input + * function with respect to a single input array. + **/ +ValueAndGradFn inline value_and_grad( + const std::function(const std::vector&)>& fun, + int argnum = 0) { + return value_and_grad(fun, std::vector{argnum}); +} + +/** + * Returns a function which computes the value and gradient of the unary + * input function. + **/ +std::function(const array&)> inline value_and_grad( + const std::function& fun) { + return [fun](auto inputs) { return vjp(fun, inputs, array(1.0f)); }; +} + +SimpleValueAndGradFn inline value_and_grad( + const std::function&)>& fun, + const std::vector& argnums) { + return [fun, argnums](auto inputs) { + auto result = value_and_grad( + [fun](auto inputs) { return std::vector{fun(inputs)}; }, + argnums)(inputs); + + return std::make_pair(result.first[0], result.second); + }; +} + +SimpleValueAndGradFn inline value_and_grad( + const std::function&)>& fun, + int argnum = 0) { + return value_and_grad(fun, std::vector{argnum}); +} + +/** + * Returns a function which computes the gradient of the input function with + * respect to a vector of input arrays. + * + * The function being differentiated takes a vector of arrays and returns an + * array. The vector of `argnums` specifies which the arguments to compute + * the gradient with respect to. At least one argument must be specified. + **/ +std::function(const std::vector&)> inline grad( + const std::function&)>& fun, + const std::vector& argnums) { + auto fn = value_and_grad(fun, argnums); + return [fn](const std::vector& inputs) { return fn(inputs).second; }; +} + +/** + * Returns a function which computes the gradient of the input function with + * respect to a single input array. + * + * The function being differentiated takes a vector of arrays and returns an + * array. The optional `argnum` index specifies which the argument to compute + * the gradient with respect to and defaults to 0. + **/ +std::function(const std::vector&)> inline grad( + const std::function&)>& fun, + int argnum = 0) { + return grad(fun, std::vector{argnum}); +} + +/** + * Returns a function which computes the gradient of the unary input function. + **/ +std::function inline grad( + const std::function& fun) { + auto fn = value_and_grad(fun); + return [fn](const array& input) { return fn(input).second; }; +} + +/** + * Automatically vectorize a unary function over the requested axes. + */ +MLX_API std::function vmap( + const std::function& fun, + int in_axis = 0, + int out_axis = 0); + +/** + * Automatically vectorize a binary function over the requested axes. + */ +MLX_API std::function vmap( + const std::function& fun, + int in_axis_a = 0, + int in_axis_b = 0, + int out_axis = 0); + +/** + * Automatically vectorize a function over the requested axes. + * + * The input function to `vmap` takes as an argument a vector of arrays and + * returns a vector of arrays. Optionally specify the axes to vectorize over + * with `in_axes` and `out_axes`, otherwise a default of 0 is used. + * Returns a vectorized function with the same signature as the input + * function. + */ +MLX_API std::function(const std::vector&)> vmap( + const std::function(const std::vector&)>& fun, + const std::vector& in_axes = {}, + const std::vector& out_axes = {}); + +/** + * Redefine the transformations of `fun` according to the provided functions. + * + * Namely when calling the vjp of `fun` then `fun_vjp` will be called, + * `fun_jvp` for the jvp and `fun_vmap` for vmap. + * + * If any transformation is not provided, then a default one is created by + * calling `vjp`, `jvp` and `vmap` on the function directly. + */ +MLX_API std::function(const std::vector&)> +custom_function( + std::function(const std::vector&)> fun, + std::optional( + const std::vector&, + const std::vector&, + const std::vector&)>> fun_vjp = std::nullopt, + std::optional( + const std::vector&, + const std::vector&, + const std::vector&)>> fun_jvp = std::nullopt, + std::optional, std::vector>( + const std::vector&, + const std::vector&)>> fun_vmap = std::nullopt); + +/** + * Return a function that behaves exactly like `fun` but if the vjp of the + * results is computed `fun_vjp` will be used instead of `vjp(fun, ...)` . + */ +MLX_API std::function(const std::vector&)> custom_vjp( + std::function(const std::vector&)> fun, + std::function( + const std::vector&, + const std::vector&, + const std::vector&)> fun_vjp); + +/** + * Checkpoint the gradient of a function. Namely, discard all intermediate + * state and recalculate it when we need to compute the gradient. + */ +MLX_API std::function(const std::vector&)> checkpoint( + std::function(const std::vector&)> fun); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/transforms_impl.h b/Source/Cxxmlx/include/mlx/transforms_impl.h new file mode 100644 index 00000000..eff458c4 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/transforms_impl.h @@ -0,0 +1,88 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include "mlx/api.h" + +namespace mlx::core::detail { + +MLX_API std::pair, std::vector> vmap_trace( + const std::function(const std::vector&)>& fun, + const std::vector& inputs, + const std::vector& in_axes); + +MLX_API std::vector vmap_replace( + const std::vector& inputs, + const std::vector& s_inputs, + const std::vector& s_outputs, + const std::vector& in_axes, + const std::vector& out_axes); + +// Create an InTracing object during tracing operations to signify to the rest +// of the codebase that we are during tracing so evals should not throw away +// the graph. +struct InTracing { + explicit InTracing(bool dynamic = false, bool grad = false) { + grad_counter += grad; + trace_stack().push_back({dynamic, grad}); + } + ~InTracing() { + grad_counter -= trace_stack().back().second; + trace_stack().pop_back(); + } + + static bool in_tracing() { + return !trace_stack().empty(); + } + static bool in_dynamic_tracing() { + // compile is always and only the outer-most transform + return in_tracing() && trace_stack().front().first; + } + + static bool in_grad_tracing() { + return grad_counter > 0; + } + + private: + static int grad_counter; + static std::vector>& trace_stack(); +}; + +struct RetainGraph { + RetainGraph() { + tracing_counter++; + } + ~RetainGraph() { + tracing_counter--; + } + + static bool retain_graph() { + return tracing_counter > 0; + } + + private: + static int tracing_counter; +}; + +/** Return true if we are currently performing a function transformation in + * order to keep the graph when evaluating tracer arrays. */ +inline bool in_tracing() { + return detail::InTracing::in_tracing(); +} + +/** Return true if we are in a dynamic (shapeless) trace used for compiling or + * exporting graphs with dynamic shapes. */ +inline bool in_dynamic_tracing() { + return detail::InTracing::in_dynamic_tracing(); +} + +/** Return true if we are in a gradient trace (vjp, jvp, etc). */ +inline bool in_grad_tracing() { + return detail::InTracing::in_grad_tracing(); +} + +inline bool retain_graph() { + return detail::RetainGraph::retain_graph(); +} + +} // namespace mlx::core::detail diff --git a/Source/Cxxmlx/include/mlx/types/bf16.h b/Source/Cxxmlx/include/mlx/types/bf16.h new file mode 100644 index 00000000..7feaaa66 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/types/bf16.h @@ -0,0 +1,188 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include +#include +#include +#include + +#define __MLX_BFLOAT_NAN__ 0x7FC0 +#define __MLX_BFLOAT_ONE__ 0x3F80 + +namespace mlx::core { + +namespace { +union float_bits_bf16 { + float f; + uint32_t u; +}; +} // namespace + +struct _MLX_BFloat16 { + uint16_t bits_; + + // Default constructor + _MLX_BFloat16() = default; + + // Default copy constructor + _MLX_BFloat16(_MLX_BFloat16 const&) = default; + + // Appease std::vector for being special + _MLX_BFloat16& operator=(std::vector::reference x) { + bits_ = (x) ? __MLX_BFLOAT_ONE__ : 0; + return (*this); + } + + _MLX_BFloat16& operator=(const float& x) { + return (*this = _MLX_BFloat16(x)); + } + + // From float32 + _MLX_BFloat16(const float& x) { + if (std::isnan(x)) { + bits_ = __MLX_BFLOAT_NAN__; + } else { + // Union + float_bits_bf16 in; + + // Take bits + in.f = x; + + // Round to nearest even + in.u += (in.u >> 16 & 1) + uint32_t(0x7FFF); + + // Take upper 16 bits + bits_ = in.u >> 16; + } + } + + // To float32 + operator float() const { + // Union + float_bits_bf16 out; + + // Upper 16 bits are the data and lower 16 bits are 0s + out.u = ((uint32_t)bits_) << 16; + + return out.f; + } +}; + +#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + inline otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ + inline otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +// Operators +#define bfloat_binop(_op_, _operator_) \ + bfloat_binop_base( \ + _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(_op_, _operator_, float, float, float); \ + bfloat_binop_helper(_op_, _operator_, double, double, double); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, bool, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); + +bfloat_binop(+, operator+); +bfloat_binop(-, operator-); +bfloat_binop(*, operator*); +bfloat_binop(/, operator/); + +#undef bfloat_binop + +// Comparison ops +#define bfloat_compop(__op__, __operator__) \ + bfloat_binop_base( \ + __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(__op__, __operator__, bool, float, float); \ + bfloat_binop_helper(__op__, __operator__, bool, double, double); \ + bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); + +bfloat_compop(>, operator>); +bfloat_compop(<, operator<); +bfloat_compop(>=, operator>=); +bfloat_compop(<=, operator<=); +bfloat_compop(==, operator==); +bfloat_compop(!=, operator!=); + +#undef bfloat_compop + +// Negative +inline _MLX_BFloat16 operator-(_MLX_BFloat16 lhs) { + return -static_cast(lhs); +} + +// Inplace ops +#define bfloat_inplace_op(__op__, __operator__) \ + inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, const float& rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } \ + inline float& __operator__(float& lhs, _MLX_BFloat16 rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } + +bfloat_inplace_op(+, operator+=); +bfloat_inplace_op(-, operator-=); +bfloat_inplace_op(*, operator*=); +bfloat_inplace_op(/, operator/=); + +#undef bfloat_inplace_op + +// Bitwise ops + +#define bfloat_bitop(__op__, __operator__) \ + inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, _MLX_BFloat16 rhs) { \ + _MLX_BFloat16 out; \ + out.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return out; \ + } \ + inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, uint16_t rhs) { \ + _MLX_BFloat16 out; \ + out.bits_ = lhs.bits_ __op__ rhs; \ + return out; \ + } \ + inline _MLX_BFloat16 __operator__(uint16_t lhs, _MLX_BFloat16 rhs) { \ + _MLX_BFloat16 out; \ + out.bits_ = lhs __op__ rhs.bits_; \ + return out; \ + } + +bfloat_bitop(|, operator|); +bfloat_bitop(&, operator&); +bfloat_bitop(^, operator^); + +#undef bfloat_bitop + +#define bfloat_inplace_bitop(__op__, __operator__) \ + inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return lhs; \ + } \ + inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, uint16_t rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs; \ + return lhs; \ + } + +bfloat_inplace_bitop(|, operator|=); +bfloat_inplace_bitop(&, operator&=); +bfloat_inplace_bitop(^, operator^=); + +#undef bfloat_inplace_bitop + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/types/complex.h b/Source/Cxxmlx/include/mlx/types/complex.h new file mode 100644 index 00000000..51101cc9 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/types/complex.h @@ -0,0 +1,113 @@ +// Copyright © 2023 Apple Inc. + +#pragma once +#include +#include "mlx/types/half_types.h" + +namespace mlx::core { + +struct complex64_t; +struct complex128_t; + +template +inline constexpr bool can_convert_to_complex128 = + !std::is_same_v && std::is_convertible_v; + +struct complex128_t : public std::complex { + complex128_t() : std::complex() {}; + complex128_t(double v, double u) : std::complex(v, u) {}; + complex128_t(std::complex v) : std::complex(v) {}; + + template < + typename T, + typename = typename std::enable_if>::type> + complex128_t(T x) : std::complex(x){}; + + operator float() const { + return real(); + }; +}; + +template +inline constexpr bool can_convert_to_complex64 = + !std::is_same_v && std::is_convertible_v; + +struct complex64_t : public std::complex { + complex64_t() : std::complex() {}; + complex64_t(float v, float u) : std::complex(v, u) {}; + complex64_t(std::complex v) : std::complex(v) {}; + + template < + typename T, + typename = typename std::enable_if>::type> + complex64_t(T x) : std::complex(x){}; + + operator float() const { + return real(); + }; +}; + +inline bool operator>=(const complex64_t& a, const complex64_t& b) { + return (a.real() > b.real()) || + (a.real() == b.real() && a.imag() >= b.imag()); +} + +inline bool operator>(const complex64_t& a, const complex64_t& b) { + return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag()); +} + +inline complex64_t operator%(complex64_t a, complex64_t b) { + auto real = a.real() - (b.real() * static_cast(a.real() / b.real())); + auto imag = a.imag() - (b.imag() * static_cast(a.imag() / b.imag())); + if (real != 0 && ((real < 0) != (b.real() < 0))) + real += b.real(); + if (imag != 0 && ((imag < 0) != (b.imag() < 0))) + imag += b.imag(); + return {real, imag}; +} + +inline bool operator<=(const complex64_t& a, const complex64_t& b) { + return operator>=(b, a); +} + +inline bool operator<(const complex64_t& a, const complex64_t& b) { + return operator>(b, a); +} + +inline complex64_t operator-(const complex64_t& v) { + return -static_cast>(v); +} + +// clang-format off +#define complex_binop_helper(_op_, _operator_, itype) \ + inline complex64_t _operator_(itype x, const complex64_t& y) { \ + return static_cast(x) _op_ y; \ + } \ + inline complex64_t _operator_(const complex64_t& x, itype y) { \ + return x _op_ static_cast(y); \ + } + +#define complex_binop(_op_, _operator_) \ + inline complex64_t _operator_(const std::complex& x, const complex64_t& y) { \ + return x _op_ static_cast>(y); \ + } \ + inline complex64_t _operator_(const complex64_t& x, const std::complex& y) { \ + return static_cast>(x) _op_ y; \ + } \ + inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \ + return static_cast>(x) \ + _op_ static_cast>(y); \ + } \ + complex_binop_helper(_op_, _operator_, bool) \ + complex_binop_helper(_op_, _operator_, uint32_t) \ + complex_binop_helper(_op_, _operator_, uint64_t) \ + complex_binop_helper(_op_, _operator_, int32_t) \ + complex_binop_helper(_op_, _operator_, int64_t) \ + complex_binop_helper(_op_, _operator_, float16_t) \ + complex_binop_helper(_op_, _operator_, bfloat16_t) \ + complex_binop_helper(_op_, _operator_, float) +// clang-format on + +complex_binop(+, operator+) + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/types/fp16.h b/Source/Cxxmlx/include/mlx/types/fp16.h new file mode 100644 index 00000000..31b0a78d --- /dev/null +++ b/Source/Cxxmlx/include/mlx/types/fp16.h @@ -0,0 +1,235 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include +#include +#include +#include + +#define __MLX_HALF_NAN__ 0x7D00 +#define __MLX_HALF_ONE__ 0x3C00 + +namespace mlx::core { + +namespace { +union float_bits_fp16 { + float f; + uint32_t u; +}; +} // namespace + +struct _MLX_Float16 { + uint16_t bits_; + + // Default constructor + _MLX_Float16() = default; + + // Default copy constructor + _MLX_Float16(_MLX_Float16 const&) = default; + + // Appease std::vector for being special + _MLX_Float16& operator=(std::vector::reference x) { + bits_ = (x) ? __MLX_HALF_ONE__ : 0; + return (*this); + } + + _MLX_Float16& operator=(const float& x) { + return (*this = _MLX_Float16(x)); + } + + // From float32 + _MLX_Float16(const float& x) : bits_(0) { + // Conversion following + // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h + + // Union + float_bits_fp16 in; + + // Take fp32 bits + in.f = x; + + // Find and take sign bit + uint32_t x_sign_32 = in.u & uint32_t(0x80000000); + uint16_t x_sign_16 = (x_sign_32 >> 16); + + if (std::isnan(x)) { + bits_ = x_sign_16 | uint16_t(__MLX_HALF_NAN__); + } else { + // Union + float_bits_fp16 inf_scale, zero_scale, magic_bits; + + // Find exponent bits and take the max supported by half + uint32_t x_expo_32 = in.u & uint32_t(0x7f800000); + uint32_t max_expo_32 = uint32_t(0x38800000); + x_expo_32 = x_expo_32 < max_expo_32 ? max_expo_32 : x_expo_32; + x_expo_32 += uint32_t(15) << 23; + + // Handle scaling to inf as needed + inf_scale.u = uint32_t(0x77800000); + zero_scale.u = uint32_t(0x08800000); + + // Combine with magic and let addition do rounding + magic_bits.u = x_expo_32; + magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f; + + // Take the lower 5 bits of the exponent + uint32_t x_expo_16 = ((magic_bits.u >> 13) & uint32_t(0x7c00)); + + // Collect the lower 12 bits which have the mantissa + uint32_t x_mant_16 = magic_bits.u & uint32_t(0x0fff); + + // Combine sign, exp and mantissa + bits_ = (x_sign_16 | uint16_t(x_expo_16 + x_mant_16)); + } + } + + // To float32 + operator float() const { + // Conversion following + // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h + + // Union + float_bits_fp16 out; + + uint32_t x_sign_32 = (bits_ << 16) & uint32_t(0x80000000); + uint32_t base = (bits_ << 16); + uint32_t two_base = base + base; + + uint32_t denorm_max = 1u << 27; + if (two_base < denorm_max) { + out.u = uint32_t(126) << 23; // magic mask + out.u |= (two_base >> 17); // Bits from fp16 + out.f -= 0.5f; // magic bias + } else { + out.u = uint32_t(0xE0) << 23; // exponent offset + out.u += (two_base >> 4); // Bits from fp16 + float out_unscaled = out.f; // Store value + out.u = uint32_t(0x7800000); // exponent scale + out.f *= out_unscaled; + } + + // Add sign + out.u |= x_sign_32; + + return out.f; + } +}; + +#define half_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + inline otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define half_binop_helper(__op__, __operator__, otype, itype, ctype) \ + inline otype __operator__(_MLX_Float16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline otype __operator__(itype lhs, _MLX_Float16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +// Operators +#define half_binop(__op__, __operator__) \ + half_binop_base( \ + __op__, __operator__, _MLX_Float16, _MLX_Float16, _MLX_Float16, float); \ + half_binop_helper(__op__, __operator__, float, float, float); \ + half_binop_helper(__op__, __operator__, double, double, double); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, bool, float); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, int32_t, float); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, uint32_t, float); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, int64_t, float); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, uint64_t, float); + +half_binop(+, operator+); +half_binop(-, operator-); +half_binop(*, operator*); +half_binop(/, operator/); + +#undef half_binop + +// Comparison ops +#define half_compop(__op__, __operator__) \ + half_binop_base( \ + __op__, __operator__, bool, _MLX_Float16, _MLX_Float16, float); \ + half_binop_helper(__op__, __operator__, bool, float, float); \ + half_binop_helper(__op__, __operator__, bool, double, double); \ + half_binop_helper(__op__, __operator__, bool, int32_t, float); \ + half_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + half_binop_helper(__op__, __operator__, bool, int64_t, float); \ + half_binop_helper(__op__, __operator__, bool, uint64_t, float); + +half_compop(>, operator>); +half_compop(<, operator<); +half_compop(>=, operator>=); +half_compop(<=, operator<=); +half_compop(==, operator==); +half_compop(!=, operator!=); + +#undef half_compop + +// Negative +inline _MLX_Float16 operator-(_MLX_Float16 lhs) { + return -static_cast(lhs); +} + +// Inplace ops +#define half_inplace_op(__op__, __operator__) \ + inline _MLX_Float16& __operator__(_MLX_Float16& lhs, const float& rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } \ + inline float& __operator__(float& lhs, _MLX_Float16 rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } + +half_inplace_op(+, operator+=); +half_inplace_op(-, operator-=); +half_inplace_op(*, operator*=); +half_inplace_op(/, operator/=); + +#undef half_inplace_op + +// Bitwise ops + +#define half_bitop(__op__, __operator__) \ + inline _MLX_Float16 __operator__(_MLX_Float16 lhs, _MLX_Float16 rhs) { \ + _MLX_Float16 out; \ + out.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return out; \ + } \ + inline _MLX_Float16 __operator__(_MLX_Float16 lhs, uint16_t rhs) { \ + _MLX_Float16 out; \ + out.bits_ = lhs.bits_ __op__ rhs; \ + return out; \ + } \ + inline _MLX_Float16 __operator__(uint16_t lhs, _MLX_Float16 rhs) { \ + _MLX_Float16 out; \ + out.bits_ = lhs __op__ rhs.bits_; \ + return out; \ + } + +half_bitop(|, operator|); +half_bitop(&, operator&); +half_bitop(^, operator^); + +#undef half_bitop + +#define half_inplace_bitop(__op__, __operator__) \ + inline _MLX_Float16& __operator__(_MLX_Float16& lhs, _MLX_Float16 rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return lhs; \ + } \ + inline _MLX_Float16& __operator__(_MLX_Float16& lhs, uint16_t rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs; \ + return lhs; \ + } + +half_inplace_bitop(|, operator|=); +half_inplace_bitop(&, operator&=); +half_inplace_bitop(^, operator^=); + +#undef half_inplace_bitop + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/types/half_types.h b/Source/Cxxmlx/include/mlx/types/half_types.h new file mode 100644 index 00000000..d9d6b9bf --- /dev/null +++ b/Source/Cxxmlx/include/mlx/types/half_types.h @@ -0,0 +1,58 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC + +#include +namespace mlx::core { +using ::float16_t; +} // namespace mlx::core + +#else + +#define ADD_HALF_BINOPS +#include "mlx/types/fp16.h" +namespace mlx::core { +typedef struct _MLX_Float16 float16_t; +} // namespace mlx::core + +#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC + +#ifdef __ARM_FEATURE_BF16 + +#include +namespace mlx::core { +using ::bfloat16_t; +} // namespace mlx::core + +#else + +#define ADD_HALF_BINOPS +#include "mlx/types/bf16.h" +namespace mlx::core { +typedef struct _MLX_BFloat16 bfloat16_t; +} // namespace mlx::core + +#endif // __ARM_FEATURE_BF16 + +#ifdef ADD_HALF_BINOPS +namespace mlx::core { + +// clang-format off +#define fp16_bf16_binop_helper(__op__, __operator__) \ + inline float __operator__(float16_t lhs, bfloat16_t rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline float __operator__(bfloat16_t lhs, float16_t rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +fp16_bf16_binop_helper(+, operator+) +fp16_bf16_binop_helper(-, operator-) +fp16_bf16_binop_helper(*, operator*) +fp16_bf16_binop_helper(/, operator/) +// clang-format on + +} // namespace mlx::core +#endif diff --git a/Source/Cxxmlx/include/mlx/types/limits.h b/Source/Cxxmlx/include/mlx/types/limits.h new file mode 100644 index 00000000..5f2b1e9e --- /dev/null +++ b/Source/Cxxmlx/include/mlx/types/limits.h @@ -0,0 +1,70 @@ +// Copyright © 2024 Apple Inc. +#pragma once + +#include +#include "mlx/types/half_types.h" + +namespace mlx::core { + +template +struct numeric_limits; + +template <> +struct numeric_limits : public std::numeric_limits {}; + +template <> +struct numeric_limits : public std::numeric_limits {}; + +template <> +struct numeric_limits { + private: + union half_or_bits { + uint16_t bits; + float16_t value; + }; + constexpr static float16_t bits_to_half(uint16_t v) { + return half_or_bits{v}.value; + } + + public: + constexpr static float16_t lowest() { + return bits_to_half(0xFBFF); + } + static constexpr float16_t max() { + return bits_to_half(0x7BFF); + } + static constexpr float16_t epsilon() { + return bits_to_half(0x1400); + } + static constexpr float16_t infinity() { + return bits_to_half(0x7C00); + } +}; + +template <> +struct numeric_limits { + private: + union bfloat_or_bits { + uint16_t bits; + bfloat16_t value; + }; + constexpr static bfloat16_t bits_to_bfloat(uint16_t v) { + return bfloat_or_bits{v}.value; + } + + public: + constexpr static bfloat16_t lowest() { + return bits_to_bfloat(0xFF7F); + } + static constexpr bfloat16_t max() { + return bits_to_bfloat(0x7F7F); + } + static constexpr bfloat16_t epsilon() { + return bits_to_bfloat(0x3C00); + } + static constexpr bfloat16_t infinity() { + return bits_to_bfloat(0x7F80); + } +}; + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/utils.h b/Source/Cxxmlx/include/mlx/utils.h new file mode 100644 index 00000000..62aa82b6 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/utils.h @@ -0,0 +1,180 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/api.h" +#include "mlx/array.h" +#include "mlx/device.h" +#include "mlx/dtype.h" +#include "mlx/stream.h" + +namespace mlx::core { + +using StreamOrDevice = std::variant; +MLX_API Stream to_stream(StreamOrDevice s); +MLX_API Stream to_stream(StreamOrDevice s, Device default_); + +struct StreamContext { + public: + StreamContext(StreamOrDevice s) : _stream(default_stream(default_device())) { + if (std::holds_alternative(s)) { + throw std::runtime_error( + "[StreamContext] Invalid argument, please specify a stream or device."); + } + auto _s = to_stream(s); + set_default_device(_s.device); + set_default_stream(_s); + } + + ~StreamContext() { + set_default_device(_stream.device); + set_default_stream(_stream); + } + + private: + Stream _stream; +}; + +struct PrintFormatter { + inline void print(std::ostream& os, bool val); + inline void print(std::ostream& os, int16_t val); + inline void print(std::ostream& os, uint16_t val); + inline void print(std::ostream& os, int32_t val); + inline void print(std::ostream& os, uint32_t val); + inline void print(std::ostream& os, int64_t val); + inline void print(std::ostream& os, uint64_t val); + inline void print(std::ostream& os, float16_t val); + inline void print(std::ostream& os, bfloat16_t val); + inline void print(std::ostream& os, float val); + inline void print(std::ostream& os, double val); + inline void print(std::ostream& os, complex64_t val); + + bool capitalize_bool{false}; +}; + +MLX_API PrintFormatter& get_global_formatter(); + +/** Print the exception and then abort. */ +MLX_API void abort_with_exception(const std::exception& error); + +/** Holds information about floating-point types. */ +struct MLX_API finfo { + explicit finfo(Dtype dtype); + Dtype dtype; + double min; + double max; + double eps; +}; + +/** Holds information about integral types. */ +struct MLX_API iinfo { + explicit iinfo(Dtype dtype); + Dtype dtype; + int64_t min; + uint64_t max; +}; + +/** The type from promoting the arrays' types with one another. */ +inline Dtype result_type(const array& a, const array& b) { + return promote_types(a.dtype(), b.dtype()); +} +inline Dtype result_type(const array& a, const array& b, const array& c) { + return promote_types(result_type(a, b), c.dtype()); +} +MLX_API Dtype result_type(const std::vector& arrays); + +MLX_API Shape broadcast_shapes(const Shape& s1, const Shape& s2); + +/** + * Returns the axis normalized to be in the range [0, ndim). + */ +MLX_API int +normalize_axis_index(int axis, int ndim, const std::string& msg_prefix = ""); + +MLX_API std::ostream& operator<<(std::ostream& os, const Device& d); +MLX_API std::ostream& operator<<(std::ostream& os, const Stream& s); +MLX_API std::ostream& operator<<(std::ostream& os, const Dtype& d); +MLX_API std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k); +MLX_API std::ostream& operator<<(std::ostream& os, array a); +inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) { + return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j"; +} +inline std::ostream& operator<<(std::ostream& os, const float16_t& v) { + return os << static_cast(v); +} +inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) { + return os << static_cast(v); +} + +template >> +inline std::ostream& operator<<(std::ostream& os, const Vec& v) { + os << "("; + for (auto it = v.begin(); it != v.end(); ++it) { + os << *it; + if (it != std::prev(v.end())) { + os << ","; + } + } + os << ")"; + return os; +} + +inline bool is_power_of_2(int n) { + return ((n & (n - 1)) == 0) && n != 0; +} + +inline int next_power_of_2(int n) { + if (is_power_of_2(n)) { + return n; + } + return pow(2, std::ceil(std::log2(n))); +} + +namespace env { + +int get_var(const char* name, int default_value); +std::string get_var(const char* name, const char* default_value); + +inline int bfs_max_width() { + static int bfs_max_width_ = get_var("MLX_BFS_MAX_WIDTH", 20); + return bfs_max_width_; +} + +inline int max_ops_per_buffer(int default_value) { + static int max_ops_per_buffer_ = + get_var("MLX_MAX_OPS_PER_BUFFER", default_value); + return max_ops_per_buffer_; +} + +inline int max_mb_per_buffer(int default_value) { + static int max_mb_per_buffer_ = + get_var("MLX_MAX_MB_PER_BUFFER", default_value); + return max_mb_per_buffer_; +} + +inline bool metal_fast_synch() { + static bool metal_fast_synch = get_var("MLX_METAL_FAST_SYNCH", 0); + return metal_fast_synch; +} + +inline bool enable_tf32() { + static bool enable_tf32_ = get_var("MLX_ENABLE_TF32", 1); + return enable_tf32_; +} + +inline int nccl_timeout(int default_value) { + static int nccl_timeout = get_var("MLX_NCCL_TIMEOUT", default_value); + return nccl_timeout; +} + +inline const std::string& metal_gpu_arch() { + static std::string gpu_arch_ = get_var("MLX_METAL_GPU_ARCH", ""); + return gpu_arch_; +} + +} // namespace env + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/mlx/version.h b/Source/Cxxmlx/include/mlx/version.h new file mode 100644 index 00000000..e0cb7bb3 --- /dev/null +++ b/Source/Cxxmlx/include/mlx/version.h @@ -0,0 +1,22 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/api.h" + +#define MLX_VERSION_MAJOR 0 +#define MLX_VERSION_MINOR 31 +#define MLX_VERSION_PATCH 1 +#define MLX_VERSION_NUMERIC \ + (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH) + +namespace mlx::core { + +/* A string representation of the MLX version in the format + * "major.minor.patch". + * + * For dev builds, the version will include the suffix ".devYYYYMMDD+hash" + */ +MLX_API const char* version(); + +} // namespace mlx::core diff --git a/Source/Cxxmlx/include/module.modulemap b/Source/Cxxmlx/include/module.modulemap new file mode 100644 index 00000000..517083f0 --- /dev/null +++ b/Source/Cxxmlx/include/module.modulemap @@ -0,0 +1,5 @@ +module Cxxmlx { + requires cplusplus + header "Cxxmlx.h" + export * +} diff --git a/Source/Cmlx/json/CMakeLists.txt b/Source/Cxxmlx/json/CMakeLists.txt similarity index 100% rename from Source/Cmlx/json/CMakeLists.txt rename to Source/Cxxmlx/json/CMakeLists.txt diff --git a/Source/Cmlx/json/LICENSE.MIT b/Source/Cxxmlx/json/LICENSE.MIT similarity index 100% rename from Source/Cmlx/json/LICENSE.MIT rename to Source/Cxxmlx/json/LICENSE.MIT diff --git a/Source/Cmlx/json/cmake/config.cmake.in b/Source/Cxxmlx/json/cmake/config.cmake.in similarity index 100% rename from Source/Cmlx/json/cmake/config.cmake.in rename to Source/Cxxmlx/json/cmake/config.cmake.in diff --git a/Source/Cmlx/json/cmake/nlohmann_jsonConfigVersion.cmake.in b/Source/Cxxmlx/json/cmake/nlohmann_jsonConfigVersion.cmake.in similarity index 100% rename from Source/Cmlx/json/cmake/nlohmann_jsonConfigVersion.cmake.in rename to Source/Cxxmlx/json/cmake/nlohmann_jsonConfigVersion.cmake.in diff --git a/Source/Cmlx/json/cmake/pkg-config.pc.in b/Source/Cxxmlx/json/cmake/pkg-config.pc.in similarity index 100% rename from Source/Cmlx/json/cmake/pkg-config.pc.in rename to Source/Cxxmlx/json/cmake/pkg-config.pc.in diff --git a/Source/Cmlx/json/include/nlohmann/adl_serializer.hpp b/Source/Cxxmlx/json/include/nlohmann/adl_serializer.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/adl_serializer.hpp rename to Source/Cxxmlx/json/include/nlohmann/adl_serializer.hpp diff --git a/Source/Cmlx/json/include/nlohmann/byte_container_with_subtype.hpp b/Source/Cxxmlx/json/include/nlohmann/byte_container_with_subtype.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/byte_container_with_subtype.hpp rename to Source/Cxxmlx/json/include/nlohmann/byte_container_with_subtype.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/abi_macros.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/abi_macros.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/abi_macros.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/abi_macros.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/conversions/from_json.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/conversions/from_json.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/conversions/from_json.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/conversions/from_json.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/conversions/to_chars.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/conversions/to_chars.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/conversions/to_chars.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/conversions/to_chars.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/conversions/to_json.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/conversions/to_json.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/conversions/to_json.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/conversions/to_json.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/exceptions.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/exceptions.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/exceptions.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/exceptions.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/hash.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/hash.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/hash.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/hash.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/input/binary_reader.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/input/binary_reader.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/input/binary_reader.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/input/binary_reader.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/input/input_adapters.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/input/input_adapters.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/input/input_adapters.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/input/input_adapters.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/input/json_sax.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/input/json_sax.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/input/json_sax.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/input/json_sax.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/input/lexer.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/input/lexer.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/input/lexer.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/input/lexer.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/input/parser.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/input/parser.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/input/parser.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/input/parser.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/input/position_t.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/input/position_t.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/input/position_t.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/input/position_t.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/iterators/internal_iterator.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/iterators/internal_iterator.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/iterators/internal_iterator.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/iterators/internal_iterator.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/iterators/iter_impl.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/iterators/iter_impl.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/iterators/iter_impl.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/iterators/iter_impl.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/iterators/iteration_proxy.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/iterators/iteration_proxy.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/iterators/iteration_proxy.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/iterators/iteration_proxy.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/iterators/iterator_traits.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/iterators/iterator_traits.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/iterators/iterator_traits.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/iterators/iterator_traits.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/iterators/json_reverse_iterator.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/iterators/json_reverse_iterator.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/iterators/json_reverse_iterator.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/iterators/json_reverse_iterator.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/iterators/primitive_iterator.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/iterators/primitive_iterator.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/iterators/primitive_iterator.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/iterators/primitive_iterator.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/json_custom_base_class.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/json_custom_base_class.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/json_custom_base_class.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/json_custom_base_class.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/json_pointer.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/json_pointer.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/json_pointer.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/json_pointer.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/json_ref.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/json_ref.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/json_ref.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/json_ref.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/macro_scope.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/macro_scope.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/macro_scope.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/macro_scope.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/macro_unscope.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/macro_unscope.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/macro_unscope.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/macro_unscope.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/meta/call_std/begin.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/meta/call_std/begin.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/meta/call_std/begin.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/meta/call_std/begin.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/meta/call_std/end.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/meta/call_std/end.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/meta/call_std/end.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/meta/call_std/end.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/meta/cpp_future.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/meta/cpp_future.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/meta/cpp_future.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/meta/cpp_future.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/meta/detected.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/meta/detected.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/meta/detected.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/meta/detected.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/meta/identity_tag.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/meta/identity_tag.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/meta/identity_tag.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/meta/identity_tag.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/meta/is_sax.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/meta/is_sax.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/meta/is_sax.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/meta/is_sax.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/meta/std_fs.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/meta/std_fs.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/meta/std_fs.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/meta/std_fs.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/meta/type_traits.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/meta/type_traits.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/meta/type_traits.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/meta/type_traits.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/meta/void_t.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/meta/void_t.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/meta/void_t.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/meta/void_t.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/output/binary_writer.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/output/binary_writer.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/output/binary_writer.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/output/binary_writer.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/output/output_adapters.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/output/output_adapters.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/output/output_adapters.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/output/output_adapters.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/output/serializer.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/output/serializer.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/output/serializer.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/output/serializer.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/string_concat.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/string_concat.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/string_concat.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/string_concat.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/string_escape.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/string_escape.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/string_escape.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/string_escape.hpp diff --git a/Source/Cmlx/json/include/nlohmann/detail/value_t.hpp b/Source/Cxxmlx/json/include/nlohmann/detail/value_t.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/detail/value_t.hpp rename to Source/Cxxmlx/json/include/nlohmann/detail/value_t.hpp diff --git a/Source/Cmlx/json/include/nlohmann/json.hpp b/Source/Cxxmlx/json/include/nlohmann/json.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/json.hpp rename to Source/Cxxmlx/json/include/nlohmann/json.hpp diff --git a/Source/Cmlx/json/include/nlohmann/json_fwd.hpp b/Source/Cxxmlx/json/include/nlohmann/json_fwd.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/json_fwd.hpp rename to Source/Cxxmlx/json/include/nlohmann/json_fwd.hpp diff --git a/Source/Cmlx/json/include/nlohmann/ordered_map.hpp b/Source/Cxxmlx/json/include/nlohmann/ordered_map.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/ordered_map.hpp rename to Source/Cxxmlx/json/include/nlohmann/ordered_map.hpp diff --git a/Source/Cmlx/json/include/nlohmann/thirdparty/hedley/hedley.hpp b/Source/Cxxmlx/json/include/nlohmann/thirdparty/hedley/hedley.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/thirdparty/hedley/hedley.hpp rename to Source/Cxxmlx/json/include/nlohmann/thirdparty/hedley/hedley.hpp diff --git a/Source/Cmlx/json/include/nlohmann/thirdparty/hedley/hedley_undef.hpp b/Source/Cxxmlx/json/include/nlohmann/thirdparty/hedley/hedley_undef.hpp similarity index 100% rename from Source/Cmlx/json/include/nlohmann/thirdparty/hedley/hedley_undef.hpp rename to Source/Cxxmlx/json/include/nlohmann/thirdparty/hedley/hedley_undef.hpp diff --git a/Source/Cmlx/json/nlohmann_json.natvis b/Source/Cxxmlx/json/nlohmann_json.natvis similarity index 100% rename from Source/Cmlx/json/nlohmann_json.natvis rename to Source/Cxxmlx/json/nlohmann_json.natvis diff --git a/Source/Cmlx/json/single_include/nlohmann/json.hpp b/Source/Cxxmlx/json/single_include/nlohmann/json.hpp similarity index 100% rename from Source/Cmlx/json/single_include/nlohmann/json.hpp rename to Source/Cxxmlx/json/single_include/nlohmann/json.hpp diff --git a/Source/Cmlx/json/single_include/nlohmann/json_fwd.hpp b/Source/Cxxmlx/json/single_include/nlohmann/json_fwd.hpp similarity index 100% rename from Source/Cmlx/json/single_include/nlohmann/json_fwd.hpp rename to Source/Cxxmlx/json/single_include/nlohmann/json_fwd.hpp diff --git a/Source/Cxxmlx/metal-cpp/Foundation/Foundation.hpp b/Source/Cxxmlx/metal-cpp/Foundation/Foundation.hpp new file mode 100644 index 00000000..31e8fb3c --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/Foundation.hpp @@ -0,0 +1,47 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/Foundation.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSArray.hpp" +#include "NSAutoreleasePool.hpp" +#include "NSBundle.hpp" +#include "NSData.hpp" +#include "NSDate.hpp" +#include "NSDefines.hpp" +#include "NSDictionary.hpp" +#include "NSEnumerator.hpp" +#include "NSError.hpp" +#include "NSLock.hpp" +#include "NSNotification.hpp" +#include "NSNumber.hpp" +#include "NSObject.hpp" +#include "NSPrivate.hpp" +#include "NSProcessInfo.hpp" +#include "NSRange.hpp" +#include "NSSet.hpp" +#include "NSSharedPtr.hpp" +#include "NSString.hpp" +#include "NSTypes.hpp" +#include "NSURL.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSArray.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSArray.hpp new file mode 100644 index 00000000..ea04d1ea --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSArray.hpp @@ -0,0 +1,124 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSArray.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSObject.hpp" +#include "NSTypes.hpp" +#include "NSEnumerator.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +class Array : public Copying +{ +public: + static Array* array(); + static Array* array(const Object* pObject); + static Array* array(const Object* const* pObjects, UInteger count); + + static Array* alloc(); + + Array* init(); + Array* init(const Object* const* pObjects, UInteger count); + Array* init(const class Coder* pCoder); + + template + _Object* object(UInteger index) const; + UInteger count() const; + Enumerator* objectEnumerator() const; +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Array* NS::Array::array() +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSArray), _NS_PRIVATE_SEL(array)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Array* NS::Array::array(const Object* pObject) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSArray), _NS_PRIVATE_SEL(arrayWithObject_), pObject); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Array* NS::Array::array(const Object* const* pObjects, UInteger count) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSArray), _NS_PRIVATE_SEL(arrayWithObjects_count_), pObjects, count); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Array* NS::Array::alloc() +{ + return NS::Object::alloc(_NS_PRIVATE_CLS(NSArray)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Array* NS::Array::init() +{ + return NS::Object::init(); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Array* NS::Array::init(const Object* const* pObjects, UInteger count) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithObjects_count_), pObjects, count); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Array* NS::Array::init(const class Coder* pCoder) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithCoder_), pCoder); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::Array::count() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(count)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE _Object* NS::Array::object(UInteger index) const +{ + return Object::sendMessage<_Object*>(this, _NS_PRIVATE_SEL(objectAtIndex_), index); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Enumerator* NS::Array::objectEnumerator() const +{ + return NS::Object::sendMessage*>(this, _NS_PRIVATE_SEL(objectEnumerator)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSAutoreleasePool.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSAutoreleasePool.hpp new file mode 100644 index 00000000..6d01a465 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSAutoreleasePool.hpp @@ -0,0 +1,83 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSAutoreleasePool.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSObject.hpp" +#include "NSPrivate.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +class AutoreleasePool : public Object +{ +public: + static AutoreleasePool* alloc(); + AutoreleasePool* init(); + + void drain(); + + void addObject(Object* pObject); + + static void showPools(); +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::AutoreleasePool* NS::AutoreleasePool::alloc() +{ + return NS::Object::alloc(_NS_PRIVATE_CLS(NSAutoreleasePool)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::AutoreleasePool* NS::AutoreleasePool::init() +{ + return NS::Object::init(); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::AutoreleasePool::drain() +{ + Object::sendMessage(this, _NS_PRIVATE_SEL(drain)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::AutoreleasePool::addObject(Object* pObject) +{ + Object::sendMessage(this, _NS_PRIVATE_SEL(addObject_), pObject); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::AutoreleasePool::showPools() +{ + Object::sendMessage(_NS_PRIVATE_CLS(NSAutoreleasePool), _NS_PRIVATE_SEL(showPools)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSBundle.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSBundle.hpp new file mode 100644 index 00000000..b9637f51 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSBundle.hpp @@ -0,0 +1,374 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSBundle.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSNotification.hpp" +#include "NSObject.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +_NS_CONST(NotificationName, BundleDidLoadNotification); +_NS_CONST(NotificationName, BundleResourceRequestLowDiskSpaceNotification); + +class String* LocalizedString(const String* pKey, const String*); +class String* LocalizedStringFromTable(const String* pKey, const String* pTbl, const String*); +class String* LocalizedStringFromTableInBundle(const String* pKey, const String* pTbl, const class Bundle* pBdle, const String*); +class String* LocalizedStringWithDefaultValue(const String* pKey, const String* pTbl, const class Bundle* pBdle, const String* pVal, const String*); + +class Bundle : public Referencing +{ +public: + static Bundle* mainBundle(); + + static Bundle* bundle(const class String* pPath); + static Bundle* bundle(const class URL* pURL); + + static class Array* allBundles(); + static class Array* allFrameworks(); + + static Bundle* alloc(); + + Bundle* init(const class String* pPath); + Bundle* init(const class URL* pURL); + + bool load(); + bool unload(); + + bool isLoaded() const; + + bool preflightAndReturnError(class Error** pError) const; + bool loadAndReturnError(class Error** pError); + + class URL* bundleURL() const; + class URL* resourceURL() const; + class URL* executableURL() const; + class URL* URLForAuxiliaryExecutable(const class String* pExecutableName) const; + + class URL* privateFrameworksURL() const; + class URL* sharedFrameworksURL() const; + class URL* sharedSupportURL() const; + class URL* builtInPlugInsURL() const; + class URL* appStoreReceiptURL() const; + + class String* bundlePath() const; + class String* resourcePath() const; + class String* executablePath() const; + class String* pathForAuxiliaryExecutable(const class String* pExecutableName) const; + + class String* privateFrameworksPath() const; + class String* sharedFrameworksPath() const; + class String* sharedSupportPath() const; + class String* builtInPlugInsPath() const; + + class String* bundleIdentifier() const; + class Dictionary* infoDictionary() const; + class Dictionary* localizedInfoDictionary() const; + class Object* objectForInfoDictionaryKey(const class String* pKey); + + class String* localizedString(const class String* pKey, const class String* pValue = nullptr, const class String* pTableName = nullptr) const; +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_PRIVATE_DEF_CONST(NS::NotificationName, BundleDidLoadNotification); +_NS_PRIVATE_DEF_CONST(NS::NotificationName, BundleResourceRequestLowDiskSpaceNotification); + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::LocalizedString(const String* pKey, const String*) +{ + return Bundle::mainBundle()->localizedString(pKey, nullptr, nullptr); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::LocalizedStringFromTable(const String* pKey, const String* pTbl, const String*) +{ + return Bundle::mainBundle()->localizedString(pKey, nullptr, pTbl); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::LocalizedStringFromTableInBundle(const String* pKey, const String* pTbl, const Bundle* pBdl, const String*) +{ + return pBdl->localizedString(pKey, nullptr, pTbl); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::LocalizedStringWithDefaultValue(const String* pKey, const String* pTbl, const Bundle* pBdl, const String* pVal, const String*) +{ + return pBdl->localizedString(pKey, pVal, pTbl); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Bundle* NS::Bundle::mainBundle() +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSBundle), _NS_PRIVATE_SEL(mainBundle)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Bundle* NS::Bundle::bundle(const class String* pPath) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSBundle), _NS_PRIVATE_SEL(bundleWithPath_), pPath); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Bundle* NS::Bundle::bundle(const class URL* pURL) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSBundle), _NS_PRIVATE_SEL(bundleWithURL_), pURL); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Array* NS::Bundle::allBundles() +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSBundle), _NS_PRIVATE_SEL(allBundles)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Array* NS::Bundle::allFrameworks() +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSBundle), _NS_PRIVATE_SEL(allFrameworks)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Bundle* NS::Bundle::alloc() +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSBundle), _NS_PRIVATE_SEL(alloc)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Bundle* NS::Bundle::init(const String* pPath) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithPath_), pPath); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Bundle* NS::Bundle::init(const URL* pURL) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithURL_), pURL); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Bundle::load() +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(load)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Bundle::unload() +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(unload)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Bundle::isLoaded() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(isLoaded)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Bundle::preflightAndReturnError(Error** pError) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(preflightAndReturnError_), pError); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Bundle::loadAndReturnError(Error** pError) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(loadAndReturnError_), pError); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::Bundle::bundleURL() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(bundleURL)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::Bundle::resourceURL() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(resourceURL)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::Bundle::executableURL() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(executableURL)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::Bundle::URLForAuxiliaryExecutable(const String* pExecutableName) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(URLForAuxiliaryExecutable_), pExecutableName); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::Bundle::privateFrameworksURL() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(privateFrameworksURL)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::Bundle::sharedFrameworksURL() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(sharedFrameworksURL)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::Bundle::sharedSupportURL() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(sharedSupportURL)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::Bundle::builtInPlugInsURL() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(builtInPlugInsURL)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::Bundle::appStoreReceiptURL() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(appStoreReceiptURL)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Bundle::bundlePath() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(bundlePath)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Bundle::resourcePath() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(resourcePath)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Bundle::executablePath() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(executablePath)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Bundle::pathForAuxiliaryExecutable(const String* pExecutableName) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(pathForAuxiliaryExecutable_), pExecutableName); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Bundle::privateFrameworksPath() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(privateFrameworksPath)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Bundle::sharedFrameworksPath() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(sharedFrameworksPath)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Bundle::sharedSupportPath() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(sharedSupportPath)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Bundle::builtInPlugInsPath() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(builtInPlugInsPath)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Bundle::bundleIdentifier() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(bundleIdentifier)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Bundle::infoDictionary() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(infoDictionary)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Bundle::localizedInfoDictionary() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(localizedInfoDictionary)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Object* NS::Bundle::objectForInfoDictionaryKey(const String* pKey) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(objectForInfoDictionaryKey_), pKey); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Bundle::localizedString(const String* pKey, const String* pValue /* = nullptr */, const String* pTableName /* = nullptr */) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(localizedStringForKey_value_table_), pKey, pValue, pTableName); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSData.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSData.hpp new file mode 100644 index 00000000..3ad36060 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSData.hpp @@ -0,0 +1,54 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSData.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSObject.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +class Data : public Copying +{ +public: + void* mutableBytes() const; + UInteger length() const; +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void* NS::Data::mutableBytes() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(mutableBytes)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::Data::length() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(length)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSDate.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSDate.hpp new file mode 100644 index 00000000..0a5ec7dd --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSDate.hpp @@ -0,0 +1,53 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSDate.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSObject.hpp" +#include "NSPrivate.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ + +using TimeInterval = double; + +class Date : public Copying +{ +public: + static Date* dateWithTimeIntervalSinceNow(TimeInterval secs); +}; + +} // NS + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Date* NS::Date::dateWithTimeIntervalSinceNow(NS::TimeInterval secs) +{ + return NS::Object::sendMessage(_NS_PRIVATE_CLS(NSDate), _NS_PRIVATE_SEL(dateWithTimeIntervalSinceNow_), secs); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- \ No newline at end of file diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSDefines.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSDefines.hpp new file mode 100644 index 00000000..38bbb56b --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSDefines.hpp @@ -0,0 +1,45 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSDefines.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#define _NS_WEAK_IMPORT __attribute__((weak_import)) +#ifdef METALCPP_SYMBOL_VISIBILITY_HIDDEN +#define _NS_EXPORT __attribute__((visibility("hidden"))) +#else +#define _NS_EXPORT __attribute__((visibility("default"))) +#endif // METALCPP_SYMBOL_VISIBILITY_HIDDEN +#define _NS_EXTERN extern "C" _NS_EXPORT +#define _NS_INLINE inline __attribute__((always_inline)) +#define _NS_PACKED __attribute__((packed)) + +#define _NS_CONST(type, name) _NS_EXTERN type const name +#define _NS_ENUM(type, name) enum name : type +#define _NS_OPTIONS(type, name) \ + using name = type; \ + enum : name + +#define _NS_CAST_TO_UINT(value) static_cast(value) +#define _NS_VALIDATE_SIZE(ns, name) static_assert(sizeof(ns::name) == sizeof(ns##name), "size mismatch " #ns "::" #name) +#define _NS_VALIDATE_ENUM(ns, name) static_assert(_NS_CAST_TO_UINT(ns::name) == _NS_CAST_TO_UINT(ns##name), "value mismatch " #ns "::" #name) + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSDictionary.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSDictionary.hpp new file mode 100644 index 00000000..d4a1519d --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSDictionary.hpp @@ -0,0 +1,128 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSDictionary.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSEnumerator.hpp" +#include "NSObject.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +class Dictionary : public NS::Copying +{ +public: + static Dictionary* dictionary(); + static Dictionary* dictionary(const Object* pObject, const Object* pKey); + static Dictionary* dictionary(const Object* const* pObjects, const Object* const* pKeys, UInteger count); + + static Dictionary* alloc(); + + Dictionary* init(); + Dictionary* init(const Object* const* pObjects, const Object* const* pKeys, UInteger count); + Dictionary* init(const class Coder* pCoder); + + template + Enumerator<_KeyType>* keyEnumerator() const; + + template + _Object* object(const Object* pKey) const; + UInteger count() const; +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Dictionary::dictionary() +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSDictionary), _NS_PRIVATE_SEL(dictionary)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Dictionary::dictionary(const Object* pObject, const Object* pKey) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSDictionary), _NS_PRIVATE_SEL(dictionaryWithObject_forKey_), pObject, pKey); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Dictionary::dictionary(const Object* const* pObjects, const Object* const* pKeys, UInteger count) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSDictionary), _NS_PRIVATE_SEL(dictionaryWithObjects_forKeys_count_), + pObjects, pKeys, count); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Dictionary::alloc() +{ + return NS::Object::alloc(_NS_PRIVATE_CLS(NSDictionary)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Dictionary::init() +{ + return NS::Object::init(); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Dictionary::init(const Object* const* pObjects, const Object* const* pKeys, UInteger count) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithObjects_forKeys_count_), pObjects, pKeys, count); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Dictionary::init(const class Coder* pCoder) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithCoder_), pCoder); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE NS::Enumerator<_KeyType>* NS::Dictionary::keyEnumerator() const +{ + return Object::sendMessage*>(this, _NS_PRIVATE_SEL(keyEnumerator)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE _Object* NS::Dictionary::object(const Object* pKey) const +{ + return Object::sendMessage<_Object*>(this, _NS_PRIVATE_SEL(objectForKey_), pKey); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::Dictionary::count() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(count)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSEnumerator.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSEnumerator.hpp new file mode 100644 index 00000000..5a2500c1 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSEnumerator.hpp @@ -0,0 +1,78 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSEnumerator.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSObject.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +struct FastEnumerationState +{ + unsigned long state; + Object** itemsPtr; + unsigned long* mutationsPtr; + unsigned long extra[5]; +} _NS_PACKED; + +class FastEnumeration : public Referencing +{ +public: + NS::UInteger countByEnumerating(FastEnumerationState* pState, Object** pBuffer, NS::UInteger len); +}; + +template +class Enumerator : public Referencing, FastEnumeration> +{ +public: + _ObjectType* nextObject(); + class Array* allObjects(); +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::FastEnumeration::countByEnumerating(FastEnumerationState* pState, Object** pBuffer, NS::UInteger len) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(countByEnumeratingWithState_objects_count_), pState, pBuffer, len); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE _ObjectType* NS::Enumerator<_ObjectType>::nextObject() +{ + return Object::sendMessage<_ObjectType*>(this, _NS_PRIVATE_SEL(nextObject)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE NS::Array* NS::Enumerator<_ObjectType>::allObjects() +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(allObjects)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSError.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSError.hpp new file mode 100644 index 00000000..ea331d46 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSError.hpp @@ -0,0 +1,173 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSError.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSObject.hpp" +#include "NSPrivate.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +using ErrorDomain = class String*; + +_NS_CONST(ErrorDomain, CocoaErrorDomain); +_NS_CONST(ErrorDomain, POSIXErrorDomain); +_NS_CONST(ErrorDomain, OSStatusErrorDomain); +_NS_CONST(ErrorDomain, MachErrorDomain); + +using ErrorUserInfoKey = class String*; + +_NS_CONST(ErrorUserInfoKey, UnderlyingErrorKey); +_NS_CONST(ErrorUserInfoKey, LocalizedDescriptionKey); +_NS_CONST(ErrorUserInfoKey, LocalizedFailureReasonErrorKey); +_NS_CONST(ErrorUserInfoKey, LocalizedRecoverySuggestionErrorKey); +_NS_CONST(ErrorUserInfoKey, LocalizedRecoveryOptionsErrorKey); +_NS_CONST(ErrorUserInfoKey, RecoveryAttempterErrorKey); +_NS_CONST(ErrorUserInfoKey, HelpAnchorErrorKey); +_NS_CONST(ErrorUserInfoKey, DebugDescriptionErrorKey); +_NS_CONST(ErrorUserInfoKey, LocalizedFailureErrorKey); +_NS_CONST(ErrorUserInfoKey, StringEncodingErrorKey); +_NS_CONST(ErrorUserInfoKey, URLErrorKey); +_NS_CONST(ErrorUserInfoKey, FilePathErrorKey); + +class Error : public Copying +{ +public: + static Error* error(ErrorDomain domain, Integer code, class Dictionary* pDictionary); + + static Error* alloc(); + Error* init(); + Error* init(ErrorDomain domain, Integer code, class Dictionary* pDictionary); + + Integer code() const; + ErrorDomain domain() const; + class Dictionary* userInfo() const; + + class String* localizedDescription() const; + class Array* localizedRecoveryOptions() const; + class String* localizedRecoverySuggestion() const; + class String* localizedFailureReason() const; +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_PRIVATE_DEF_CONST(NS::ErrorDomain, CocoaErrorDomain); +_NS_PRIVATE_DEF_CONST(NS::ErrorDomain, POSIXErrorDomain); +_NS_PRIVATE_DEF_CONST(NS::ErrorDomain, OSStatusErrorDomain); +_NS_PRIVATE_DEF_CONST(NS::ErrorDomain, MachErrorDomain); + +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, UnderlyingErrorKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, LocalizedDescriptionKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, LocalizedFailureReasonErrorKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, LocalizedRecoverySuggestionErrorKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, LocalizedRecoveryOptionsErrorKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, RecoveryAttempterErrorKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, HelpAnchorErrorKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, DebugDescriptionErrorKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, LocalizedFailureErrorKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, StringEncodingErrorKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, URLErrorKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, FilePathErrorKey); + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Error* NS::Error::error(ErrorDomain domain, Integer code, class Dictionary* pDictionary) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSError), _NS_PRIVATE_SEL(errorWithDomain_code_userInfo_), domain, code, pDictionary); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Error* NS::Error::alloc() +{ + return Object::alloc(_NS_PRIVATE_CLS(NSError)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Error* NS::Error::init() +{ + return Object::init(); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Error* NS::Error::init(ErrorDomain domain, Integer code, class Dictionary* pDictionary) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithDomain_code_userInfo_), domain, code, pDictionary); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Integer NS::Error::code() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(code)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::ErrorDomain NS::Error::domain() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(domain)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Error::userInfo() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(userInfo)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Error::localizedDescription() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(localizedDescription)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Array* NS::Error::localizedRecoveryOptions() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(localizedRecoveryOptions)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Error::localizedRecoverySuggestion() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(localizedRecoverySuggestion)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Error::localizedFailureReason() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(localizedFailureReason)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSLock.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSLock.hpp new file mode 100644 index 00000000..01df2194 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSLock.hpp @@ -0,0 +1,118 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSLock.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSObject.hpp" +#include "NSPrivate.hpp" +#include "NSTypes.hpp" +#include "NSDate.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ + +template +class Locking : public _Base +{ +public: + void lock(); + void unlock(); +}; + +class Condition : public Locking +{ +public: + static Condition* alloc(); + + Condition* init(); + + void wait(); + bool waitUntilDate(Date* pLimit); + void signal(); + void broadcast(); +}; + +} // NS + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE void NS::Locking<_Class, _Base>::lock() +{ + NS::Object::sendMessage(this, _NS_PRIVATE_SEL(lock)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE void NS::Locking<_Class, _Base>::unlock() +{ + NS::Object::sendMessage(this, _NS_PRIVATE_SEL(unlock)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Condition* NS::Condition::alloc() +{ + return NS::Object::alloc(_NS_PRIVATE_CLS(NSCondition)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Condition* NS::Condition::init() +{ + return NS::Object::init(); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::Condition::wait() +{ + NS::Object::sendMessage(this, _NS_PRIVATE_SEL(wait)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Condition::waitUntilDate(NS::Date* pLimit) +{ + return NS::Object::sendMessage(this, _NS_PRIVATE_SEL(waitUntilDate_), pLimit); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::Condition::signal() +{ + NS::Object::sendMessage(this, _NS_PRIVATE_SEL(signal)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::Condition::broadcast() +{ + NS::Object::sendMessage(this, _NS_PRIVATE_SEL(broadcast)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- \ No newline at end of file diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSNotification.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSNotification.hpp new file mode 100644 index 00000000..6b5be121 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSNotification.hpp @@ -0,0 +1,110 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSNotification.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSDictionary.hpp" +#include "NSObject.hpp" +#include "NSString.hpp" +#include "NSTypes.hpp" +#include + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +using NotificationName = class String*; + +class Notification : public NS::Referencing +{ +public: + NS::String* name() const; + NS::Object* object() const; + NS::Dictionary* userInfo() const; +}; + +using ObserverBlock = void(^)(Notification*); +using ObserverFunction = std::function; + +class NotificationCenter : public NS::Referencing +{ + public: + static class NotificationCenter* defaultCenter(); + Object* addObserver(NotificationName name, Object* pObj, void* pQueue, ObserverBlock block); + Object* addObserver(NotificationName name, Object* pObj, void* pQueue, ObserverFunction &handler); + void removeObserver(Object* pObserver); + +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Notification::name() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(name)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Object* NS::Notification::object() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(object)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Notification::userInfo() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(userInfo)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::NotificationCenter* NS::NotificationCenter::defaultCenter() +{ + return NS::Object::sendMessage(_NS_PRIVATE_CLS(NSNotificationCenter), _NS_PRIVATE_SEL(defaultCenter)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Object* NS::NotificationCenter::addObserver(NS::NotificationName name, Object* pObj, void* pQueue, NS::ObserverBlock block) +{ + return NS::Object::sendMessage(this, _NS_PRIVATE_SEL(addObserverName_object_queue_block_), name, pObj, pQueue, block); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Object* NS::NotificationCenter::addObserver(NS::NotificationName name, Object* pObj, void* pQueue, NS::ObserverFunction &handler) +{ + __block ObserverFunction blockFunction = handler; + + return addObserver(name, pObj, pQueue, ^(NS::Notification* pNotif) {blockFunction(pNotif);}); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::NotificationCenter::removeObserver(Object* pObserver) +{ + return NS::Object::sendMessage(this, _NS_PRIVATE_SEL(removeObserver_), pObserver); +} + diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSNumber.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSNumber.hpp new file mode 100644 index 00000000..eec7ceac --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSNumber.hpp @@ -0,0 +1,501 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSNumber.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSObjCRuntime.hpp" +#include "NSObject.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +class Value : public Copying +{ +public: + static Value* value(const void* pValue, const char* pType); + static Value* value(const void* pPointer); + + static Value* alloc(); + + Value* init(const void* pValue, const char* pType); + Value* init(const class Coder* pCoder); + + void getValue(void* pValue, UInteger size) const; + const char* objCType() const; + + bool isEqualToValue(Value* pValue) const; + void* pointerValue() const; +}; + +class Number : public Copying +{ +public: + static Number* number(char value); + static Number* number(unsigned char value); + static Number* number(short value); + static Number* number(unsigned short value); + static Number* number(int value); + static Number* number(unsigned int value); + static Number* number(long value); + static Number* number(unsigned long value); + static Number* number(long long value); + static Number* number(unsigned long long value); + static Number* number(float value); + static Number* number(double value); + static Number* number(bool value); + + static Number* alloc(); + + Number* init(const class Coder* pCoder); + Number* init(char value); + Number* init(unsigned char value); + Number* init(short value); + Number* init(unsigned short value); + Number* init(int value); + Number* init(unsigned int value); + Number* init(long value); + Number* init(unsigned long value); + Number* init(long long value); + Number* init(unsigned long long value); + Number* init(float value); + Number* init(double value); + Number* init(bool value); + + char charValue() const; + unsigned char unsignedCharValue() const; + short shortValue() const; + unsigned short unsignedShortValue() const; + int intValue() const; + unsigned int unsignedIntValue() const; + long longValue() const; + unsigned long unsignedLongValue() const; + long long longLongValue() const; + unsigned long long unsignedLongLongValue() const; + float floatValue() const; + double doubleValue() const; + bool boolValue() const; + Integer integerValue() const; + UInteger unsignedIntegerValue() const; + class String* stringValue() const; + + ComparisonResult compare(const Number* pOtherNumber) const; + bool isEqualToNumber(const Number* pNumber) const; + + class String* descriptionWithLocale(const Object* pLocale) const; +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Value* NS::Value::value(const void* pValue, const char* pType) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSValue), _NS_PRIVATE_SEL(valueWithBytes_objCType_), pValue, pType); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Value* NS::Value::value(const void* pPointer) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSValue), _NS_PRIVATE_SEL(valueWithPointer_), pPointer); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Value* NS::Value::alloc() +{ + return NS::Object::alloc(_NS_PRIVATE_CLS(NSValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Value* NS::Value::init(const void* pValue, const char* pType) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithBytes_objCType_), pValue, pType); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Value* NS::Value::init(const class Coder* pCoder) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithCoder_), pCoder); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::Value::getValue(void* pValue, UInteger size) const +{ + Object::sendMessage(this, _NS_PRIVATE_SEL(getValue_size_), pValue, size); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE const char* NS::Value::objCType() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(objCType)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Value::isEqualToValue(Value* pValue) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(isEqualToValue_), pValue); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void* NS::Value::pointerValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(pointerValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(char value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithChar_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(unsigned char value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithUnsignedChar_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(short value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithShort_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(unsigned short value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithUnsignedShort_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(int value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithInt_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(unsigned int value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithUnsignedInt_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(long value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithLong_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(unsigned long value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithUnsignedLong_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(long long value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithLongLong_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(unsigned long long value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithUnsignedLongLong_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(float value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithFloat_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(double value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithDouble_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(bool value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithBool_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::alloc() +{ + return NS::Object::alloc(_NS_PRIVATE_CLS(NSNumber)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(const Coder* pCoder) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithCoder_), pCoder); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(char value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithChar_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(unsigned char value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithUnsignedChar_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(short value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithShort_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(unsigned short value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithUnsignedShort_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(int value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithInt_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(unsigned int value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithUnsignedInt_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(long value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithLong_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(unsigned long value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithUnsignedLong_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(long long value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithLongLong_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(unsigned long long value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithUnsignedLongLong_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(float value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithFloat_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(double value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithDouble_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(bool value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithBool_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE char NS::Number::charValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(charValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE unsigned char NS::Number::unsignedCharValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(unsignedCharValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE short NS::Number::shortValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(shortValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE unsigned short NS::Number::unsignedShortValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(unsignedShortValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE int NS::Number::intValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(intValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE unsigned int NS::Number::unsignedIntValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(unsignedIntValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE long NS::Number::longValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(longValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE unsigned long NS::Number::unsignedLongValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(unsignedLongValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE long long NS::Number::longLongValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(longLongValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE unsigned long long NS::Number::unsignedLongLongValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(unsignedLongLongValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE float NS::Number::floatValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(floatValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE double NS::Number::doubleValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(doubleValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Number::boolValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(boolValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Integer NS::Number::integerValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(integerValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::Number::unsignedIntegerValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(unsignedIntegerValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Number::stringValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(stringValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::ComparisonResult NS::Number::compare(const Number* pOtherNumber) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(compare_), pOtherNumber); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Number::isEqualToNumber(const Number* pNumber) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(isEqualToNumber_), pNumber); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Number::descriptionWithLocale(const Object* pLocale) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(descriptionWithLocale_), pLocale); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSObjCRuntime.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSObjCRuntime.hpp new file mode 100644 index 00000000..9a5364c2 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSObjCRuntime.hpp @@ -0,0 +1,43 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSObjCRuntime.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ + +_NS_ENUM(Integer, ComparisonResult) { + OrderedAscending = -1L, + OrderedSame, + OrderedDescending +}; + +const Integer NotFound = IntegerMax; + +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSObject.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSObject.hpp new file mode 100644 index 00000000..aff8e676 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSObject.hpp @@ -0,0 +1,302 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSObject.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSPrivate.hpp" +#include "NSTypes.hpp" + +#include +#include + +#include + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +template +class _NS_EXPORT Referencing : public _Base +{ +public: + _Class* retain(); + void release(); + + _Class* autorelease(); + + UInteger retainCount() const; +}; + +template +class Copying : public Referencing<_Class, _Base> +{ +public: + _Class* copy() const; +}; + +template +class SecureCoding : public Referencing<_Class, _Base> +{ +}; + +class Object : public Referencing +{ +public: + UInteger hash() const; + bool isEqual(const Object* pObject) const; + + class String* description() const; + class String* debugDescription() const; + +protected: + friend class Referencing; + + template + static _Class* alloc(const char* pClassName); + template + static _Class* alloc(const void* pClass); + template + _Class* init(); + + template + static _Dst bridgingCast(const void* pObj); + static class MethodSignature* methodSignatureForSelector(const void* pObj, SEL selector); + static bool respondsToSelector(const void* pObj, SEL selector); + template + static constexpr bool doesRequireMsgSendStret(); + template + static _Ret sendMessage(const void* pObj, SEL selector, _Args... args); + template + static _Ret sendMessageSafe(const void* pObj, SEL selector, _Args... args); + +private: + Object() = delete; + Object(const Object&) = delete; + ~Object() = delete; + + Object& operator=(const Object&) = delete; +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE _Class* NS::Referencing<_Class, _Base>::retain() +{ + return Object::sendMessage<_Class*>(this, _NS_PRIVATE_SEL(retain)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE void NS::Referencing<_Class, _Base>::release() +{ + Object::sendMessage(this, _NS_PRIVATE_SEL(release)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE _Class* NS::Referencing<_Class, _Base>::autorelease() +{ + return Object::sendMessage<_Class*>(this, _NS_PRIVATE_SEL(autorelease)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE NS::UInteger NS::Referencing<_Class, _Base>::retainCount() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(retainCount)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE _Class* NS::Copying<_Class, _Base>::copy() const +{ + return Object::sendMessage<_Class*>(this, _NS_PRIVATE_SEL(copy)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE _Dst NS::Object::bridgingCast(const void* pObj) +{ +#ifdef __OBJC__ + return (__bridge _Dst)pObj; +#else + return (_Dst)pObj; +#endif // __OBJC__ +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE constexpr bool NS::Object::doesRequireMsgSendStret() +{ +#if (defined(__i386__) || defined(__x86_64__)) + constexpr size_t kStructLimit = (sizeof(std::uintptr_t) << 1); + + return sizeof(_Type) > kStructLimit; +#elif defined(__arm64__) + return false; +#elif defined(__arm__) + constexpr size_t kStructLimit = sizeof(std::uintptr_t); + + return std::is_class_v<_Type> && (sizeof(_Type) > kStructLimit); +#else +#error "Unsupported architecture!" +#endif +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template <> +_NS_INLINE constexpr bool NS::Object::doesRequireMsgSendStret() +{ + return false; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE _Ret NS::Object::sendMessage(const void* pObj, SEL selector, _Args... args) +{ +#if (defined(__i386__) || defined(__x86_64__)) + if constexpr (std::is_floating_point<_Ret>()) + { + using SendMessageProcFpret = _Ret (*)(const void*, SEL, _Args...); + + const SendMessageProcFpret pProc = reinterpret_cast(&objc_msgSend_fpret); + + return (*pProc)(pObj, selector, args...); + } + else +#endif // ( defined( __i386__ ) || defined( __x86_64__ ) ) +#if !defined(__arm64__) + if constexpr (doesRequireMsgSendStret<_Ret>()) + { + using SendMessageProcStret = void (*)(_Ret*, const void*, SEL, _Args...); + + const SendMessageProcStret pProc = reinterpret_cast(&objc_msgSend_stret); + _Ret ret; + + (*pProc)(&ret, pObj, selector, args...); + + return ret; + } + else +#endif // !defined( __arm64__ ) + { + using SendMessageProc = _Ret (*)(const void*, SEL, _Args...); + + const SendMessageProc pProc = reinterpret_cast(&objc_msgSend); + + return (*pProc)(pObj, selector, args...); + } +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::MethodSignature* NS::Object::methodSignatureForSelector(const void* pObj, SEL selector) +{ + return sendMessage(pObj, _NS_PRIVATE_SEL(methodSignatureForSelector_), selector); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Object::respondsToSelector(const void* pObj, SEL selector) +{ + return sendMessage(pObj, _NS_PRIVATE_SEL(respondsToSelector_), selector); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE _Ret NS::Object::sendMessageSafe(const void* pObj, SEL selector, _Args... args) +{ + if ((respondsToSelector(pObj, selector)) || (nullptr != methodSignatureForSelector(pObj, selector))) + { + return sendMessage<_Ret>(pObj, selector, args...); + } + + if constexpr (!std::is_void<_Ret>::value) + { + return _Ret(0); + } +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE _Class* NS::Object::alloc(const char* pClassName) +{ + return sendMessage<_Class*>(objc_lookUpClass(pClassName), _NS_PRIVATE_SEL(alloc)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE _Class* NS::Object::alloc(const void* pClass) +{ + return sendMessage<_Class*>(pClass, _NS_PRIVATE_SEL(alloc)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE _Class* NS::Object::init() +{ + return sendMessage<_Class*>(this, _NS_PRIVATE_SEL(init)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::Object::hash() const +{ + return sendMessage(this, _NS_PRIVATE_SEL(hash)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Object::isEqual(const Object* pObject) const +{ + return sendMessage(this, _NS_PRIVATE_SEL(isEqual_), pObject); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Object::description() const +{ + return sendMessage(this, _NS_PRIVATE_SEL(description)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Object::debugDescription() const +{ + return sendMessageSafe(this, _NS_PRIVATE_SEL(debugDescription)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSPrivate.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSPrivate.hpp new file mode 100644 index 00000000..f8d87004 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSPrivate.hpp @@ -0,0 +1,531 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSPrivate.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#define _NS_PRIVATE_CLS(symbol) (Private::Class::s_k##symbol) +#define _NS_PRIVATE_SEL(accessor) (Private::Selector::s_k##accessor) + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#if defined(NS_PRIVATE_IMPLEMENTATION) + +#include + +namespace NS::Private +{ + template + inline _Type const LoadSymbol(const char* pSymbol) + { + const _Type* pAddress = static_cast<_Type*>(dlsym(RTLD_DEFAULT, pSymbol)); + + return pAddress ? *pAddress : _Type(); + } +} // NS::Private + +#ifdef METALCPP_SYMBOL_VISIBILITY_HIDDEN +#define _NS_PRIVATE_VISIBILITY __attribute__((visibility("hidden"))) +#else +#define _NS_PRIVATE_VISIBILITY __attribute__((visibility("default"))) +#endif // METALCPP_SYMBOL_VISIBILITY_HIDDEN + +#define _NS_PRIVATE_IMPORT __attribute__((weak_import)) + +#ifdef __OBJC__ +#define _NS_PRIVATE_OBJC_LOOKUP_CLASS(symbol) ((__bridge void*)objc_lookUpClass(#symbol)) +#define _NS_PRIVATE_OBJC_GET_PROTOCOL(symbol) ((__bridge void*)objc_getProtocol(#symbol)) +#else +#define _NS_PRIVATE_OBJC_LOOKUP_CLASS(symbol) objc_lookUpClass(#symbol) +#define _NS_PRIVATE_OBJC_GET_PROTOCOL(symbol) objc_getProtocol(#symbol) +#endif // __OBJC__ + +#define _NS_PRIVATE_DEF_CLS(symbol) void* s_k##symbol _NS_PRIVATE_VISIBILITY = _NS_PRIVATE_OBJC_LOOKUP_CLASS(symbol) +#define _NS_PRIVATE_DEF_PRO(symbol) void* s_k##symbol _NS_PRIVATE_VISIBILITY = _NS_PRIVATE_OBJC_GET_PROTOCOL(symbol) +#define _NS_PRIVATE_DEF_SEL(accessor, symbol) SEL s_k##accessor _NS_PRIVATE_VISIBILITY = sel_registerName(symbol) + +#if defined(__MAC_26_0) || defined(__IPHONE_26_0) || defined(__TVOS_26_0) +#define _NS_PRIVATE_DEF_CONST(type, symbol) \ + _NS_EXTERN type const NS##symbol _NS_PRIVATE_IMPORT; \ + type const NS::symbol = (nullptr != &NS##symbol) ? NS##symbol : type() +#else +#define _NS_PRIVATE_DEF_CONST(type, symbol) \ + _NS_EXTERN type const MTL##symbol _NS_PRIVATE_IMPORT; \ + type const NS::symbol = Private::LoadSymbol("NS" #symbol) +#endif + +#else + +#define _NS_PRIVATE_DEF_CLS(symbol) extern void* s_k##symbol +#define _NS_PRIVATE_DEF_PRO(symbol) extern void* s_k##symbol +#define _NS_PRIVATE_DEF_SEL(accessor, symbol) extern SEL s_k##accessor +#define _NS_PRIVATE_DEF_CONST(type, symbol) extern type const NS::symbol + +#endif // NS_PRIVATE_IMPLEMENTATION + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +namespace Private +{ + namespace Class + { + + _NS_PRIVATE_DEF_CLS(NSArray); + _NS_PRIVATE_DEF_CLS(NSAutoreleasePool); + _NS_PRIVATE_DEF_CLS(NSBundle); + _NS_PRIVATE_DEF_CLS(NSCondition); + _NS_PRIVATE_DEF_CLS(NSDate); + _NS_PRIVATE_DEF_CLS(NSDictionary); + _NS_PRIVATE_DEF_CLS(NSError); + _NS_PRIVATE_DEF_CLS(NSNotificationCenter); + _NS_PRIVATE_DEF_CLS(NSNumber); + _NS_PRIVATE_DEF_CLS(NSObject); + _NS_PRIVATE_DEF_CLS(NSProcessInfo); + _NS_PRIVATE_DEF_CLS(NSSet); + _NS_PRIVATE_DEF_CLS(NSString); + _NS_PRIVATE_DEF_CLS(NSURL); + _NS_PRIVATE_DEF_CLS(NSValue); + + } // Class +} // Private +} // MTL + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +namespace Private +{ + namespace Protocol + { + + } // Protocol +} // Private +} // NS + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +namespace Private +{ + namespace Selector + { + + _NS_PRIVATE_DEF_SEL(addObject_, + "addObject:"); + _NS_PRIVATE_DEF_SEL(addObserverName_object_queue_block_, + "addObserverForName:object:queue:usingBlock:"); + _NS_PRIVATE_DEF_SEL(activeProcessorCount, + "activeProcessorCount"); + _NS_PRIVATE_DEF_SEL(allBundles, + "allBundles"); + _NS_PRIVATE_DEF_SEL(allFrameworks, + "allFrameworks"); + _NS_PRIVATE_DEF_SEL(allObjects, + "allObjects"); + _NS_PRIVATE_DEF_SEL(alloc, + "alloc"); + _NS_PRIVATE_DEF_SEL(appStoreReceiptURL, + "appStoreReceiptURL"); + _NS_PRIVATE_DEF_SEL(arguments, + "arguments"); + _NS_PRIVATE_DEF_SEL(array, + "array"); + _NS_PRIVATE_DEF_SEL(arrayWithObject_, + "arrayWithObject:"); + _NS_PRIVATE_DEF_SEL(arrayWithObjects_count_, + "arrayWithObjects:count:"); + _NS_PRIVATE_DEF_SEL(automaticTerminationSupportEnabled, + "automaticTerminationSupportEnabled"); + _NS_PRIVATE_DEF_SEL(autorelease, + "autorelease"); + _NS_PRIVATE_DEF_SEL(beginActivityWithOptions_reason_, + "beginActivityWithOptions:reason:"); + _NS_PRIVATE_DEF_SEL(boolValue, + "boolValue"); + _NS_PRIVATE_DEF_SEL(broadcast, + "broadcast"); + _NS_PRIVATE_DEF_SEL(builtInPlugInsPath, + "builtInPlugInsPath"); + _NS_PRIVATE_DEF_SEL(builtInPlugInsURL, + "builtInPlugInsURL"); + _NS_PRIVATE_DEF_SEL(bundleIdentifier, + "bundleIdentifier"); + _NS_PRIVATE_DEF_SEL(bundlePath, + "bundlePath"); + _NS_PRIVATE_DEF_SEL(bundleURL, + "bundleURL"); + _NS_PRIVATE_DEF_SEL(bundleWithPath_, + "bundleWithPath:"); + _NS_PRIVATE_DEF_SEL(bundleWithURL_, + "bundleWithURL:"); + _NS_PRIVATE_DEF_SEL(caseInsensitiveCompare_, + "caseInsensitiveCompare:"); + _NS_PRIVATE_DEF_SEL(characterAtIndex_, + "characterAtIndex:"); + _NS_PRIVATE_DEF_SEL(charValue, + "charValue"); + _NS_PRIVATE_DEF_SEL(countByEnumeratingWithState_objects_count_, + "countByEnumeratingWithState:objects:count:"); + _NS_PRIVATE_DEF_SEL(cStringUsingEncoding_, + "cStringUsingEncoding:"); + _NS_PRIVATE_DEF_SEL(code, + "code"); + _NS_PRIVATE_DEF_SEL(compare_, + "compare:"); + _NS_PRIVATE_DEF_SEL(copy, + "copy"); + _NS_PRIVATE_DEF_SEL(count, + "count"); + _NS_PRIVATE_DEF_SEL(dateWithTimeIntervalSinceNow_, + "dateWithTimeIntervalSinceNow:"); + _NS_PRIVATE_DEF_SEL(defaultCenter, + "defaultCenter"); + _NS_PRIVATE_DEF_SEL(descriptionWithLocale_, + "descriptionWithLocale:"); + _NS_PRIVATE_DEF_SEL(disableAutomaticTermination_, + "disableAutomaticTermination:"); + _NS_PRIVATE_DEF_SEL(disableSuddenTermination, + "disableSuddenTermination"); + _NS_PRIVATE_DEF_SEL(debugDescription, + "debugDescription"); + _NS_PRIVATE_DEF_SEL(description, + "description"); + _NS_PRIVATE_DEF_SEL(dictionary, + "dictionary"); + _NS_PRIVATE_DEF_SEL(dictionaryWithObject_forKey_, + "dictionaryWithObject:forKey:"); + _NS_PRIVATE_DEF_SEL(dictionaryWithObjects_forKeys_count_, + "dictionaryWithObjects:forKeys:count:"); + _NS_PRIVATE_DEF_SEL(domain, + "domain"); + _NS_PRIVATE_DEF_SEL(doubleValue, + "doubleValue"); + _NS_PRIVATE_DEF_SEL(drain, + "drain"); + _NS_PRIVATE_DEF_SEL(enableAutomaticTermination_, + "enableAutomaticTermination:"); + _NS_PRIVATE_DEF_SEL(enableSuddenTermination, + "enableSuddenTermination"); + _NS_PRIVATE_DEF_SEL(endActivity_, + "endActivity:"); + _NS_PRIVATE_DEF_SEL(environment, + "environment"); + _NS_PRIVATE_DEF_SEL(errorWithDomain_code_userInfo_, + "errorWithDomain:code:userInfo:"); + _NS_PRIVATE_DEF_SEL(executablePath, + "executablePath"); + _NS_PRIVATE_DEF_SEL(executableURL, + "executableURL"); + _NS_PRIVATE_DEF_SEL(fileSystemRepresentation, + "fileSystemRepresentation"); + _NS_PRIVATE_DEF_SEL(fileURLWithPath_, + "fileURLWithPath:"); + _NS_PRIVATE_DEF_SEL(floatValue, + "floatValue"); + _NS_PRIVATE_DEF_SEL(fullUserName, + "fullUserName"); + _NS_PRIVATE_DEF_SEL(getValue_size_, + "getValue:size:"); + _NS_PRIVATE_DEF_SEL(globallyUniqueString, + "globallyUniqueString"); + _NS_PRIVATE_DEF_SEL(hash, + "hash"); + _NS_PRIVATE_DEF_SEL(hasPerformanceProfile_, + "hasPerformanceProfile:"); + _NS_PRIVATE_DEF_SEL(hostName, + "hostName"); + _NS_PRIVATE_DEF_SEL(infoDictionary, + "infoDictionary"); + _NS_PRIVATE_DEF_SEL(init, + "init"); + _NS_PRIVATE_DEF_SEL(initFileURLWithPath_, + "initFileURLWithPath:"); + _NS_PRIVATE_DEF_SEL(initWithBool_, + "initWithBool:"); + _NS_PRIVATE_DEF_SEL(initWithBytes_objCType_, + "initWithBytes:objCType:"); + _NS_PRIVATE_DEF_SEL(initWithBytesNoCopy_length_encoding_freeWhenDone_, + "initWithBytesNoCopy:length:encoding:freeWhenDone:"); + _NS_PRIVATE_DEF_SEL(initWithChar_, + "initWithChar:"); + _NS_PRIVATE_DEF_SEL(initWithCoder_, + "initWithCoder:"); + _NS_PRIVATE_DEF_SEL(initWithCString_encoding_, + "initWithCString:encoding:"); + _NS_PRIVATE_DEF_SEL(initWithDomain_code_userInfo_, + "initWithDomain:code:userInfo:"); + _NS_PRIVATE_DEF_SEL(initWithDouble_, + "initWithDouble:"); + _NS_PRIVATE_DEF_SEL(initWithFloat_, + "initWithFloat:"); + _NS_PRIVATE_DEF_SEL(initWithInt_, + "initWithInt:"); + _NS_PRIVATE_DEF_SEL(initWithLong_, + "initWithLong:"); + _NS_PRIVATE_DEF_SEL(initWithLongLong_, + "initWithLongLong:"); + _NS_PRIVATE_DEF_SEL(initWithObjects_count_, + "initWithObjects:count:"); + _NS_PRIVATE_DEF_SEL(initWithObjects_forKeys_count_, + "initWithObjects:forKeys:count:"); + _NS_PRIVATE_DEF_SEL(initWithPath_, + "initWithPath:"); + _NS_PRIVATE_DEF_SEL(initWithShort_, + "initWithShort:"); + _NS_PRIVATE_DEF_SEL(initWithString_, + "initWithString:"); + _NS_PRIVATE_DEF_SEL(initWithUnsignedChar_, + "initWithUnsignedChar:"); + _NS_PRIVATE_DEF_SEL(initWithUnsignedInt_, + "initWithUnsignedInt:"); + _NS_PRIVATE_DEF_SEL(initWithUnsignedLong_, + "initWithUnsignedLong:"); + _NS_PRIVATE_DEF_SEL(initWithUnsignedLongLong_, + "initWithUnsignedLongLong:"); + _NS_PRIVATE_DEF_SEL(initWithUnsignedShort_, + "initWithUnsignedShort:"); + _NS_PRIVATE_DEF_SEL(initWithURL_, + "initWithURL:"); + _NS_PRIVATE_DEF_SEL(integerValue, + "integerValue"); + _NS_PRIVATE_DEF_SEL(intValue, + "intValue"); + _NS_PRIVATE_DEF_SEL(isDeviceCertified_, + "isDeviceCertifiedFor:"); + _NS_PRIVATE_DEF_SEL(isEqual_, + "isEqual:"); + _NS_PRIVATE_DEF_SEL(isEqualToNumber_, + "isEqualToNumber:"); + _NS_PRIVATE_DEF_SEL(isEqualToString_, + "isEqualToString:"); + _NS_PRIVATE_DEF_SEL(isEqualToValue_, + "isEqualToValue:"); + _NS_PRIVATE_DEF_SEL(isiOSAppOnMac, + "isiOSAppOnMac"); + _NS_PRIVATE_DEF_SEL(isLoaded, + "isLoaded"); + _NS_PRIVATE_DEF_SEL(isLowPowerModeEnabled, + "isLowPowerModeEnabled"); + _NS_PRIVATE_DEF_SEL(isMacCatalystApp, + "isMacCatalystApp"); + _NS_PRIVATE_DEF_SEL(isOperatingSystemAtLeastVersion_, + "isOperatingSystemAtLeastVersion:"); + _NS_PRIVATE_DEF_SEL(keyEnumerator, + "keyEnumerator"); + _NS_PRIVATE_DEF_SEL(length, + "length"); + _NS_PRIVATE_DEF_SEL(lengthOfBytesUsingEncoding_, + "lengthOfBytesUsingEncoding:"); + _NS_PRIVATE_DEF_SEL(load, + "load"); + _NS_PRIVATE_DEF_SEL(loadAndReturnError_, + "loadAndReturnError:"); + _NS_PRIVATE_DEF_SEL(localizedDescription, + "localizedDescription"); + _NS_PRIVATE_DEF_SEL(localizedFailureReason, + "localizedFailureReason"); + _NS_PRIVATE_DEF_SEL(localizedInfoDictionary, + "localizedInfoDictionary"); + _NS_PRIVATE_DEF_SEL(localizedRecoveryOptions, + "localizedRecoveryOptions"); + _NS_PRIVATE_DEF_SEL(localizedRecoverySuggestion, + "localizedRecoverySuggestion"); + _NS_PRIVATE_DEF_SEL(localizedStringForKey_value_table_, + "localizedStringForKey:value:table:"); + _NS_PRIVATE_DEF_SEL(lock, + "lock"); + _NS_PRIVATE_DEF_SEL(longValue, + "longValue"); + _NS_PRIVATE_DEF_SEL(longLongValue, + "longLongValue"); + _NS_PRIVATE_DEF_SEL(mainBundle, + "mainBundle"); + _NS_PRIVATE_DEF_SEL(maximumLengthOfBytesUsingEncoding_, + "maximumLengthOfBytesUsingEncoding:"); + _NS_PRIVATE_DEF_SEL(methodSignatureForSelector_, + "methodSignatureForSelector:"); + _NS_PRIVATE_DEF_SEL(mutableBytes, + "mutableBytes"); + _NS_PRIVATE_DEF_SEL(name, + "name"); + _NS_PRIVATE_DEF_SEL(nextObject, + "nextObject"); + _NS_PRIVATE_DEF_SEL(numberWithBool_, + "numberWithBool:"); + _NS_PRIVATE_DEF_SEL(numberWithChar_, + "numberWithChar:"); + _NS_PRIVATE_DEF_SEL(numberWithDouble_, + "numberWithDouble:"); + _NS_PRIVATE_DEF_SEL(numberWithFloat_, + "numberWithFloat:"); + _NS_PRIVATE_DEF_SEL(numberWithInt_, + "numberWithInt:"); + _NS_PRIVATE_DEF_SEL(numberWithLong_, + "numberWithLong:"); + _NS_PRIVATE_DEF_SEL(numberWithLongLong_, + "numberWithLongLong:"); + _NS_PRIVATE_DEF_SEL(numberWithShort_, + "numberWithShort:"); + _NS_PRIVATE_DEF_SEL(numberWithUnsignedChar_, + "numberWithUnsignedChar:"); + _NS_PRIVATE_DEF_SEL(numberWithUnsignedInt_, + "numberWithUnsignedInt:"); + _NS_PRIVATE_DEF_SEL(numberWithUnsignedLong_, + "numberWithUnsignedLong:"); + _NS_PRIVATE_DEF_SEL(numberWithUnsignedLongLong_, + "numberWithUnsignedLongLong:"); + _NS_PRIVATE_DEF_SEL(numberWithUnsignedShort_, + "numberWithUnsignedShort:"); + _NS_PRIVATE_DEF_SEL(objCType, + "objCType"); + _NS_PRIVATE_DEF_SEL(object, + "object"); + _NS_PRIVATE_DEF_SEL(objectAtIndex_, + "objectAtIndex:"); + _NS_PRIVATE_DEF_SEL(objectEnumerator, + "objectEnumerator"); + _NS_PRIVATE_DEF_SEL(objectForInfoDictionaryKey_, + "objectForInfoDictionaryKey:"); + _NS_PRIVATE_DEF_SEL(objectForKey_, + "objectForKey:"); + _NS_PRIVATE_DEF_SEL(operatingSystem, + "operatingSystem"); + _NS_PRIVATE_DEF_SEL(operatingSystemVersion, + "operatingSystemVersion"); + _NS_PRIVATE_DEF_SEL(operatingSystemVersionString, + "operatingSystemVersionString"); + _NS_PRIVATE_DEF_SEL(pathForAuxiliaryExecutable_, + "pathForAuxiliaryExecutable:"); + _NS_PRIVATE_DEF_SEL(performActivityWithOptions_reason_usingBlock_, + "performActivityWithOptions:reason:usingBlock:"); + _NS_PRIVATE_DEF_SEL(performExpiringActivityWithReason_usingBlock_, + "performExpiringActivityWithReason:usingBlock:"); + _NS_PRIVATE_DEF_SEL(physicalMemory, + "physicalMemory"); + _NS_PRIVATE_DEF_SEL(pointerValue, + "pointerValue"); + _NS_PRIVATE_DEF_SEL(preflightAndReturnError_, + "preflightAndReturnError:"); + _NS_PRIVATE_DEF_SEL(privateFrameworksPath, + "privateFrameworksPath"); + _NS_PRIVATE_DEF_SEL(privateFrameworksURL, + "privateFrameworksURL"); + _NS_PRIVATE_DEF_SEL(processIdentifier, + "processIdentifier"); + _NS_PRIVATE_DEF_SEL(processInfo, + "processInfo"); + _NS_PRIVATE_DEF_SEL(processName, + "processName"); + _NS_PRIVATE_DEF_SEL(processorCount, + "processorCount"); + _NS_PRIVATE_DEF_SEL(rangeOfString_options_, + "rangeOfString:options:"); + _NS_PRIVATE_DEF_SEL(release, + "release"); + _NS_PRIVATE_DEF_SEL(removeObserver_, + "removeObserver:"); + _NS_PRIVATE_DEF_SEL(resourcePath, + "resourcePath"); + _NS_PRIVATE_DEF_SEL(resourceURL, + "resourceURL"); + _NS_PRIVATE_DEF_SEL(respondsToSelector_, + "respondsToSelector:"); + _NS_PRIVATE_DEF_SEL(retain, + "retain"); + _NS_PRIVATE_DEF_SEL(retainCount, + "retainCount"); + _NS_PRIVATE_DEF_SEL(setAutomaticTerminationSupportEnabled_, + "setAutomaticTerminationSupportEnabled:"); + _NS_PRIVATE_DEF_SEL(setProcessName_, + "setProcessName:"); + _NS_PRIVATE_DEF_SEL(sharedFrameworksPath, + "sharedFrameworksPath"); + _NS_PRIVATE_DEF_SEL(sharedFrameworksURL, + "sharedFrameworksURL"); + _NS_PRIVATE_DEF_SEL(sharedSupportPath, + "sharedSupportPath"); + _NS_PRIVATE_DEF_SEL(sharedSupportURL, + "sharedSupportURL"); + _NS_PRIVATE_DEF_SEL(shortValue, + "shortValue"); + _NS_PRIVATE_DEF_SEL(showPools, + "showPools"); + _NS_PRIVATE_DEF_SEL(signal, + "signal"); + _NS_PRIVATE_DEF_SEL(string, + "string"); + _NS_PRIVATE_DEF_SEL(stringValue, + "stringValue"); + _NS_PRIVATE_DEF_SEL(stringWithString_, + "stringWithString:"); + _NS_PRIVATE_DEF_SEL(stringWithCString_encoding_, + "stringWithCString:encoding:"); + _NS_PRIVATE_DEF_SEL(stringByAppendingString_, + "stringByAppendingString:"); + _NS_PRIVATE_DEF_SEL(systemUptime, + "systemUptime"); + _NS_PRIVATE_DEF_SEL(thermalState, + "thermalState"); + _NS_PRIVATE_DEF_SEL(unload, + "unload"); + _NS_PRIVATE_DEF_SEL(unlock, + "unlock"); + _NS_PRIVATE_DEF_SEL(unsignedCharValue, + "unsignedCharValue"); + _NS_PRIVATE_DEF_SEL(unsignedIntegerValue, + "unsignedIntegerValue"); + _NS_PRIVATE_DEF_SEL(unsignedIntValue, + "unsignedIntValue"); + _NS_PRIVATE_DEF_SEL(unsignedLongValue, + "unsignedLongValue"); + _NS_PRIVATE_DEF_SEL(unsignedLongLongValue, + "unsignedLongLongValue"); + _NS_PRIVATE_DEF_SEL(unsignedShortValue, + "unsignedShortValue"); + _NS_PRIVATE_DEF_SEL(URLForAuxiliaryExecutable_, + "URLForAuxiliaryExecutable:"); + _NS_PRIVATE_DEF_SEL(userInfo, + "userInfo"); + _NS_PRIVATE_DEF_SEL(userName, + "userName"); + _NS_PRIVATE_DEF_SEL(UTF8String, + "UTF8String"); + _NS_PRIVATE_DEF_SEL(valueWithBytes_objCType_, + "valueWithBytes:objCType:"); + _NS_PRIVATE_DEF_SEL(valueWithPointer_, + "valueWithPointer:"); + _NS_PRIVATE_DEF_SEL(wait, + "wait"); + _NS_PRIVATE_DEF_SEL(waitUntilDate_, + "waitUntilDate:"); + } // Class +} // Private +} // MTL + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSProcessInfo.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSProcessInfo.hpp new file mode 100644 index 00000000..09c212d5 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSProcessInfo.hpp @@ -0,0 +1,386 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSProcessInfo.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSNotification.hpp" +#include "NSObject.hpp" +#include "NSPrivate.hpp" +#include "NSTypes.hpp" + +#include + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +_NS_CONST(NotificationName, ProcessInfoThermalStateDidChangeNotification); +_NS_CONST(NotificationName, ProcessInfoPowerStateDidChangeNotification); +_NS_CONST(NotificationName, ProcessInfoPerformanceProfileDidChangeNotification); + +_NS_ENUM(NS::Integer, ProcessInfoThermalState) { + ProcessInfoThermalStateNominal = 0, + ProcessInfoThermalStateFair = 1, + ProcessInfoThermalStateSerious = 2, + ProcessInfoThermalStateCritical = 3 +}; + +_NS_OPTIONS(std::uint64_t, ActivityOptions) { + ActivityIdleDisplaySleepDisabled = (1ULL << 40), + ActivityIdleSystemSleepDisabled = (1ULL << 20), + ActivitySuddenTerminationDisabled = (1ULL << 14), + ActivityAutomaticTerminationDisabled = (1ULL << 15), + ActivityUserInitiated = (0x00FFFFFFULL | ActivityIdleSystemSleepDisabled), + ActivityUserInitiatedAllowingIdleSystemSleep = (ActivityUserInitiated & ~ActivityIdleSystemSleepDisabled), + ActivityBackground = 0x000000FFULL, + ActivityLatencyCritical = 0xFF00000000ULL, +}; + +typedef NS::Integer DeviceCertification; +_NS_CONST(DeviceCertification, DeviceCertificationiPhonePerformanceGaming); + +typedef NS::Integer ProcessPerformanceProfile; +_NS_CONST(ProcessPerformanceProfile, ProcessPerformanceProfileDefault); +_NS_CONST(ProcessPerformanceProfile, ProcessPerformanceProfileSustained); + +class ProcessInfo : public Referencing +{ +public: + static ProcessInfo* processInfo(); + + class Array* arguments() const; + class Dictionary* environment() const; + class String* hostName() const; + class String* processName() const; + void setProcessName(const String* pString); + int processIdentifier() const; + class String* globallyUniqueString() const; + + class String* userName() const; + class String* fullUserName() const; + + UInteger operatingSystem() const; + OperatingSystemVersion operatingSystemVersion() const; + class String* operatingSystemVersionString() const; + bool isOperatingSystemAtLeastVersion(OperatingSystemVersion version) const; + + UInteger processorCount() const; + UInteger activeProcessorCount() const; + unsigned long long physicalMemory() const; + TimeInterval systemUptime() const; + + void disableSuddenTermination(); + void enableSuddenTermination(); + + void disableAutomaticTermination(const class String* pReason); + void enableAutomaticTermination(const class String* pReason); + bool automaticTerminationSupportEnabled() const; + void setAutomaticTerminationSupportEnabled(bool enabled); + + class Object* beginActivity(ActivityOptions options, const class String* pReason); + void endActivity(class Object* pActivity); + void performActivity(ActivityOptions options, const class String* pReason, void (^block)(void)); + void performActivity(ActivityOptions options, const class String* pReason, const std::function& func); + void performExpiringActivity(const class String* pReason, void (^block)(bool expired)); + void performExpiringActivity(const class String* pReason, const std::function& func); + + ProcessInfoThermalState thermalState() const; + bool isLowPowerModeEnabled() const; + + bool isiOSAppOnMac() const; + bool isMacCatalystApp() const; + + bool isDeviceCertified(DeviceCertification performanceTier) const; + bool hasPerformanceProfile(ProcessPerformanceProfile performanceProfile) const; + +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_PRIVATE_DEF_CONST(NS::NotificationName, ProcessInfoThermalStateDidChangeNotification); +_NS_PRIVATE_DEF_CONST(NS::NotificationName, ProcessInfoPowerStateDidChangeNotification); + +// The linker searches for these symbols in the Metal framework, be sure to link it in as well: +_NS_PRIVATE_DEF_CONST(NS::NotificationName, ProcessInfoPerformanceProfileDidChangeNotification); +_NS_PRIVATE_DEF_CONST(NS::DeviceCertification, DeviceCertificationiPhonePerformanceGaming); +_NS_PRIVATE_DEF_CONST(NS::ProcessPerformanceProfile, ProcessPerformanceProfileDefault); +_NS_PRIVATE_DEF_CONST(NS::ProcessPerformanceProfile, ProcessPerformanceProfileSustained); + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::ProcessInfo* NS::ProcessInfo::processInfo() +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSProcessInfo), _NS_PRIVATE_SEL(processInfo)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Array* NS::ProcessInfo::arguments() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(arguments)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::ProcessInfo::environment() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(environment)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::ProcessInfo::hostName() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(hostName)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::ProcessInfo::processName() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(processName)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::ProcessInfo::setProcessName(const String* pString) +{ + Object::sendMessage(this, _NS_PRIVATE_SEL(setProcessName_), pString); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE int NS::ProcessInfo::processIdentifier() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(processIdentifier)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::ProcessInfo::globallyUniqueString() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(globallyUniqueString)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::ProcessInfo::userName() const +{ + return Object::sendMessageSafe(this, _NS_PRIVATE_SEL(userName)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::ProcessInfo::fullUserName() const +{ + return Object::sendMessageSafe(this, _NS_PRIVATE_SEL(fullUserName)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::ProcessInfo::operatingSystem() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(operatingSystem)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::OperatingSystemVersion NS::ProcessInfo::operatingSystemVersion() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(operatingSystemVersion)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::ProcessInfo::operatingSystemVersionString() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(operatingSystemVersionString)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::ProcessInfo::isOperatingSystemAtLeastVersion(OperatingSystemVersion version) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(isOperatingSystemAtLeastVersion_), version); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::ProcessInfo::processorCount() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(processorCount)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::ProcessInfo::activeProcessorCount() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(activeProcessorCount)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE unsigned long long NS::ProcessInfo::physicalMemory() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(physicalMemory)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::TimeInterval NS::ProcessInfo::systemUptime() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(systemUptime)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::ProcessInfo::disableSuddenTermination() +{ + Object::sendMessageSafe(this, _NS_PRIVATE_SEL(disableSuddenTermination)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::ProcessInfo::enableSuddenTermination() +{ + Object::sendMessageSafe(this, _NS_PRIVATE_SEL(enableSuddenTermination)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::ProcessInfo::disableAutomaticTermination(const String* pReason) +{ + Object::sendMessageSafe(this, _NS_PRIVATE_SEL(disableAutomaticTermination_), pReason); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::ProcessInfo::enableAutomaticTermination(const String* pReason) +{ + Object::sendMessageSafe(this, _NS_PRIVATE_SEL(enableAutomaticTermination_), pReason); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::ProcessInfo::automaticTerminationSupportEnabled() const +{ + return Object::sendMessageSafe(this, _NS_PRIVATE_SEL(automaticTerminationSupportEnabled)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::ProcessInfo::setAutomaticTerminationSupportEnabled(bool enabled) +{ + Object::sendMessageSafe(this, _NS_PRIVATE_SEL(setAutomaticTerminationSupportEnabled_), enabled); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Object* NS::ProcessInfo::beginActivity(ActivityOptions options, const String* pReason) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(beginActivityWithOptions_reason_), options, pReason); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::ProcessInfo::endActivity(Object* pActivity) +{ + Object::sendMessage(this, _NS_PRIVATE_SEL(endActivity_), pActivity); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::ProcessInfo::performActivity(ActivityOptions options, const String* pReason, void (^block)(void)) +{ + Object::sendMessage(this, _NS_PRIVATE_SEL(performActivityWithOptions_reason_usingBlock_), options, pReason, block); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::ProcessInfo::performActivity(ActivityOptions options, const String* pReason, const std::function& function) +{ + __block std::function blockFunction = function; + + performActivity(options, pReason, ^() { blockFunction(); }); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::ProcessInfo::performExpiringActivity(const String* pReason, void (^block)(bool expired)) +{ + Object::sendMessageSafe(this, _NS_PRIVATE_SEL(performExpiringActivityWithReason_usingBlock_), pReason, block); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::ProcessInfo::performExpiringActivity(const String* pReason, const std::function& function) +{ + __block std::function blockFunction = function; + + performExpiringActivity(pReason, ^(bool expired) { blockFunction(expired); }); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::ProcessInfoThermalState NS::ProcessInfo::thermalState() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(thermalState)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::ProcessInfo::isLowPowerModeEnabled() const +{ + return Object::sendMessageSafe(this, _NS_PRIVATE_SEL(isLowPowerModeEnabled)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::ProcessInfo::isiOSAppOnMac() const +{ + return Object::sendMessageSafe(this, _NS_PRIVATE_SEL(isiOSAppOnMac)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::ProcessInfo::isMacCatalystApp() const +{ + return Object::sendMessageSafe(this, _NS_PRIVATE_SEL(isMacCatalystApp)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::ProcessInfo::isDeviceCertified(DeviceCertification performanceTier) const +{ + return Object::sendMessageSafe(this, _NS_PRIVATE_SEL(isDeviceCertified_), performanceTier); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::ProcessInfo::hasPerformanceProfile(ProcessPerformanceProfile performanceProfile) const +{ + return Object::sendMessageSafe(this, _NS_PRIVATE_SEL(hasPerformanceProfile_), performanceProfile); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSRange.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSRange.hpp new file mode 100644 index 00000000..8500271d --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSRange.hpp @@ -0,0 +1,83 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSRange.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +struct Range +{ + static Range Make(UInteger loc, UInteger len); + + Range(UInteger loc, UInteger len); + + bool Equal(const Range& range) const; + bool LocationInRange(UInteger loc) const; + UInteger Max() const; + + UInteger location; + UInteger length; +} _NS_PACKED; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Range::Range(UInteger loc, UInteger len) + : location(loc) + , length(len) +{ +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Range NS::Range::Make(UInteger loc, UInteger len) +{ + return Range(loc, len); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Range::Equal(const Range& range) const +{ + return (location == range.location) && (length == range.length); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Range::LocationInRange(UInteger loc) const +{ + return (!(loc < location)) && ((loc - location) < length); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::Range::Max() const +{ + return location + length; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSSet.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSSet.hpp new file mode 100644 index 00000000..382b6714 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSSet.hpp @@ -0,0 +1,87 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSSet.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSObject.hpp" +#include "NSEnumerator.hpp" + +/*****Immutable Set*******/ + +namespace NS +{ + class Set : public NS::Copying + { + public: + UInteger count() const; + Enumerator* objectEnumerator() const; + + static Set* alloc(); + + Set* init(); + Set* init(const Object* const* pObjects, UInteger count); + Set* init(const class Coder* pCoder); + + }; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::Set::count() const +{ + return NS::Object::sendMessage(this, _NS_PRIVATE_SEL(count)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Enumerator* NS::Set::objectEnumerator() const +{ + return NS::Object::sendMessage*>(this, _NS_PRIVATE_SEL(objectEnumerator)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Set* NS::Set::alloc() +{ + return NS::Object::alloc(_NS_PRIVATE_CLS(NSSet)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Set* NS::Set::init() +{ + return NS::Object::init(); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Set* NS::Set::init(const Object* const* pObjects, NS::UInteger count) +{ + return NS::Object::sendMessage(this, _NS_PRIVATE_SEL(initWithObjects_count_), pObjects, count); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Set* NS::Set::init(const class Coder* pCoder) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithCoder_), pCoder); +} diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSSharedPtr.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSSharedPtr.hpp new file mode 100644 index 00000000..f1cf68e4 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSSharedPtr.hpp @@ -0,0 +1,310 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSSharedPtr.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include +#include "NSDefines.hpp" + +namespace NS +{ +template +class SharedPtr +{ +public: + /** + * Create a new null pointer. + */ + SharedPtr(); + + /** + * Destroy this SharedPtr, decreasing the reference count. + */ + ~SharedPtr(); + + /** + * Create a new null pointer. + */ + SharedPtr(std::nullptr_t) noexcept; + + /** + * SharedPtr copy constructor. + */ + SharedPtr(const SharedPtr<_Class>& other) noexcept; + + /** + * Construction from another pointee type. + */ + template + SharedPtr(const SharedPtr<_OtherClass>& other, typename std::enable_if_t> * = nullptr) noexcept; + + /** + * SharedPtr move constructor. + */ + SharedPtr(SharedPtr<_Class>&& other) noexcept; + + /** + * Move from another pointee type. + */ + template + SharedPtr(SharedPtr<_OtherClass>&& other, typename std::enable_if_t> * = nullptr) noexcept; + + /** + * Copy assignment operator. + * Copying increases reference count. Only releases previous pointee if objects are different. + */ + SharedPtr& operator=(const SharedPtr<_Class>& other); + + /** + * Copy-assignment from different pointee. + * Copying increases reference count. Only releases previous pointee if objects are different. + */ + template + typename std::enable_if_t, SharedPtr &> + operator=(const SharedPtr<_OtherClass>& other); + + /** + * Move assignment operator. + * Move without affecting reference counts, unless pointees are equal. Moved-from object is reset to nullptr. + */ + SharedPtr& operator=(SharedPtr<_Class>&& other); + + /** + * Move-asignment from different pointee. + * Move without affecting reference counts, unless pointees are equal. Moved-from object is reset to nullptr. + */ + template + typename std::enable_if_t, SharedPtr &> + operator=(SharedPtr<_OtherClass>&& other); + + /** + * Access raw pointee. + * @warning Avoid wrapping the returned value again, as it may lead double frees unless this object becomes detached. + */ + _Class* get() const; + + /** + * Call operations directly on the pointee. + */ + _Class* operator->() const; + + /** + * Implicit cast to bool. + */ + explicit operator bool() const; + + /** + * Reset this SharedPtr to null, decreasing the reference count. + */ + void reset(); + + /** + * Detach the SharedPtr from the pointee, without decreasing the reference count. + */ + void detach(); + + template + friend SharedPtr<_OtherClass> RetainPtr(_OtherClass* ptr); + + template + friend SharedPtr<_OtherClass> TransferPtr(_OtherClass* ptr); + +private: + _Class* m_pObject; +}; + +/** + * Create a SharedPtr by retaining an existing raw pointer. + * Increases the reference count of the passed-in object. + * If the passed-in object was in an AutoreleasePool, it will be removed from it. + */ +template +_NS_INLINE NS::SharedPtr<_Class> RetainPtr(_Class* pObject) +{ + NS::SharedPtr<_Class> ret; + ret.m_pObject = pObject->retain(); + return ret; +} + +/* + * Create a SharedPtr by transfering the ownership of an existing raw pointer to SharedPtr. + * Does not increase the reference count of the passed-in pointer, it is assumed to be >= 1. + * This method does not remove objects from an AutoreleasePool. +*/ +template +_NS_INLINE NS::SharedPtr<_Class> TransferPtr(_Class* pObject) +{ + NS::SharedPtr<_Class> ret; + ret.m_pObject = pObject; + return ret; +} + +} + +template +_NS_INLINE NS::SharedPtr<_Class>::SharedPtr() + : m_pObject(nullptr) +{ +} + +template +_NS_INLINE NS::SharedPtr<_Class>::~SharedPtr<_Class>() __attribute__((no_sanitize("undefined"))) +{ + m_pObject->release(); +} + +template +_NS_INLINE NS::SharedPtr<_Class>::SharedPtr(std::nullptr_t) noexcept + : m_pObject(nullptr) +{ +} + +template +_NS_INLINE NS::SharedPtr<_Class>::SharedPtr(const SharedPtr<_Class>& other) noexcept + : m_pObject(other.m_pObject->retain()) +{ +} + +template +template +_NS_INLINE NS::SharedPtr<_Class>::SharedPtr(const SharedPtr<_OtherClass>& other, typename std::enable_if_t> *) noexcept + : m_pObject(reinterpret_cast<_Class*>(other.get()->retain())) +{ +} + +template +_NS_INLINE NS::SharedPtr<_Class>::SharedPtr(SharedPtr<_Class>&& other) noexcept + : m_pObject(other.m_pObject) +{ + other.m_pObject = nullptr; +} + +template +template +_NS_INLINE NS::SharedPtr<_Class>::SharedPtr(SharedPtr<_OtherClass>&& other, typename std::enable_if_t> *) noexcept + : m_pObject(reinterpret_cast<_Class*>(other.get())) +{ + other.detach(); +} + +template +_NS_INLINE _Class* NS::SharedPtr<_Class>::get() const +{ + return m_pObject; +} + +template +_NS_INLINE _Class* NS::SharedPtr<_Class>::operator->() const +{ + return m_pObject; +} + +template +_NS_INLINE NS::SharedPtr<_Class>::operator bool() const +{ + return nullptr != m_pObject; +} + +template +_NS_INLINE void NS::SharedPtr<_Class>::reset() __attribute__((no_sanitize("undefined"))) +{ + m_pObject->release(); + m_pObject = nullptr; +} + +template +_NS_INLINE void NS::SharedPtr<_Class>::detach() +{ + m_pObject = nullptr; +} + +template +_NS_INLINE NS::SharedPtr<_Class>& NS::SharedPtr<_Class>::operator=(const SharedPtr<_Class>& other) __attribute__((no_sanitize("undefined"))) +{ + _Class* pOldObject = m_pObject; + + m_pObject = other.m_pObject->retain(); + + pOldObject->release(); + + return *this; +} + +template +template +typename std::enable_if_t, NS::SharedPtr<_Class> &> +_NS_INLINE NS::SharedPtr<_Class>::operator=(const SharedPtr<_OtherClass>& other) __attribute__((no_sanitize("undefined"))) +{ + _Class* pOldObject = m_pObject; + + m_pObject = reinterpret_cast<_Class*>(other.get()->retain()); + + pOldObject->release(); + + return *this; +} + +template +_NS_INLINE NS::SharedPtr<_Class>& NS::SharedPtr<_Class>::operator=(SharedPtr<_Class>&& other) __attribute__((no_sanitize("undefined"))) +{ + if (m_pObject != other.m_pObject) + { + m_pObject->release(); + m_pObject = other.m_pObject; + } + else + { + m_pObject = other.m_pObject; + other.m_pObject->release(); + } + other.m_pObject = nullptr; + return *this; +} + +template +template +typename std::enable_if_t, NS::SharedPtr<_Class> &> +_NS_INLINE NS::SharedPtr<_Class>::operator=(SharedPtr<_OtherClass>&& other) __attribute__((no_sanitize("undefined"))) +{ + if (m_pObject != other.get()) + { + m_pObject->release(); + m_pObject = reinterpret_cast<_Class*>(other.get()); + other.detach(); + } + else + { + m_pObject = other.get(); + other.reset(); + } + return *this; +} + +template +_NS_INLINE bool operator==(const NS::SharedPtr<_ClassLhs>& lhs, const NS::SharedPtr<_ClassRhs>& rhs) +{ + return lhs.get() == rhs.get(); +} + +template +_NS_INLINE bool operator!=(const NS::SharedPtr<_ClassLhs>& lhs, const NS::SharedPtr<_ClassRhs>& rhs) +{ + return lhs.get() != rhs.get(); +} diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSString.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSString.hpp new file mode 100644 index 00000000..c48e0689 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSString.hpp @@ -0,0 +1,255 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSString.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSObjCRuntime.hpp" +#include "NSObject.hpp" +#include "NSPrivate.hpp" +#include "NSRange.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +_NS_ENUM(NS::UInteger, StringEncoding) { + ASCIIStringEncoding = 1, + NEXTSTEPStringEncoding = 2, + JapaneseEUCStringEncoding = 3, + UTF8StringEncoding = 4, + ISOLatin1StringEncoding = 5, + SymbolStringEncoding = 6, + NonLossyASCIIStringEncoding = 7, + ShiftJISStringEncoding = 8, + ISOLatin2StringEncoding = 9, + UnicodeStringEncoding = 10, + WindowsCP1251StringEncoding = 11, + WindowsCP1252StringEncoding = 12, + WindowsCP1253StringEncoding = 13, + WindowsCP1254StringEncoding = 14, + WindowsCP1250StringEncoding = 15, + ISO2022JPStringEncoding = 21, + MacOSRomanStringEncoding = 30, + + UTF16StringEncoding = UnicodeStringEncoding, + + UTF16BigEndianStringEncoding = 0x90000100, + UTF16LittleEndianStringEncoding = 0x94000100, + + UTF32StringEncoding = 0x8c000100, + UTF32BigEndianStringEncoding = 0x98000100, + UTF32LittleEndianStringEncoding = 0x9c000100 +}; + +_NS_OPTIONS(NS::UInteger, StringCompareOptions) { + CaseInsensitiveSearch = 1, + LiteralSearch = 2, + BackwardsSearch = 4, + AnchoredSearch = 8, + NumericSearch = 64, + DiacriticInsensitiveSearch = 128, + WidthInsensitiveSearch = 256, + ForcedOrderingSearch = 512, + RegularExpressionSearch = 1024 +}; + +using unichar = unsigned short; + +class String : public Copying +{ +public: + static String* string(); + static String* string(const String* pString); + static String* string(const char* pString, StringEncoding encoding); + + static String* alloc(); + String* init(); + String* init(const String* pString); + String* init(const char* pString, StringEncoding encoding); + String* init(void* pBytes, UInteger len, StringEncoding encoding, bool freeBuffer); + + unichar character(UInteger index) const; + UInteger length() const; + + const char* cString(StringEncoding encoding) const; + const char* utf8String() const; + UInteger maximumLengthOfBytes(StringEncoding encoding) const; + UInteger lengthOfBytes(StringEncoding encoding) const; + + bool isEqualToString(const String* pString) const; + Range rangeOfString(const String* pString, StringCompareOptions options) const; + + const char* fileSystemRepresentation() const; + + String* stringByAppendingString(const String* pString) const; + ComparisonResult caseInsensitiveCompare(const String* pString) const; +}; + +/// Create an NS::String* from a string literal. +#define MTLSTR(literal) (NS::String*)__builtin___CFStringMakeConstantString("" literal "") + +template +[[deprecated("please use MTLSTR(str)")]] constexpr const String* MakeConstantString(const char (&str)[_StringLen]) +{ + return reinterpret_cast(__CFStringMakeConstantString(str)); +} + +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::String::string() +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSString), _NS_PRIVATE_SEL(string)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::String::string(const String* pString) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSString), _NS_PRIVATE_SEL(stringWithString_), pString); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::String::string(const char* pString, StringEncoding encoding) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSString), _NS_PRIVATE_SEL(stringWithCString_encoding_), pString, encoding); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::String::alloc() +{ + return Object::alloc(_NS_PRIVATE_CLS(NSString)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::String::init() +{ + return Object::init(); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::String::init(const String* pString) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithString_), pString); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::String::init(const char* pString, StringEncoding encoding) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithCString_encoding_), pString, encoding); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::String::init(void* pBytes, UInteger len, StringEncoding encoding, bool freeBuffer) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithBytesNoCopy_length_encoding_freeWhenDone_), pBytes, len, encoding, freeBuffer); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::unichar NS::String::character(UInteger index) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(characterAtIndex_), index); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::String::length() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(length)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE const char* NS::String::cString(StringEncoding encoding) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(cStringUsingEncoding_), encoding); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE const char* NS::String::utf8String() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(UTF8String)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::String::maximumLengthOfBytes(StringEncoding encoding) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(maximumLengthOfBytesUsingEncoding_), encoding); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::String::lengthOfBytes(StringEncoding encoding) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(lengthOfBytesUsingEncoding_), encoding); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::String::isEqualToString(const NS::String* pString) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(isEqualToString_), pString); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Range NS::String::rangeOfString(const NS::String* pString, NS::StringCompareOptions options) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(rangeOfString_options_), pString, options); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE const char* NS::String::fileSystemRepresentation() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(fileSystemRepresentation)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::String::stringByAppendingString(const String* pString) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(stringByAppendingString_), pString); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::ComparisonResult NS::String::caseInsensitiveCompare(const String* pString) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(caseInsensitiveCompare_), pString); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSTypes.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSTypes.hpp new file mode 100644 index 00000000..e6b723e5 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSTypes.hpp @@ -0,0 +1,51 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSTypes.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" + +#include +#include + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +using TimeInterval = double; + +using Integer = std::intptr_t; +using UInteger = std::uintptr_t; + +const Integer IntegerMax = INTPTR_MAX; +const Integer IntegerMin = INTPTR_MIN; +const UInteger UIntegerMax = UINTPTR_MAX; + +struct OperatingSystemVersion +{ + Integer majorVersion; + Integer minorVersion; + Integer patchVersion; +} _NS_PACKED; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/Foundation/NSURL.hpp b/Source/Cxxmlx/metal-cpp/Foundation/NSURL.hpp new file mode 100644 index 00000000..d90e5d70 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Foundation/NSURL.hpp @@ -0,0 +1,90 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSURL.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSObject.hpp" +#include "NSPrivate.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +class URL : public Copying +{ +public: + static URL* fileURLWithPath(const class String* pPath); + + static URL* alloc(); + URL* init(); + URL* init(const class String* pString); + URL* initFileURLWithPath(const class String* pPath); + + const char* fileSystemRepresentation() const; +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::URL::fileURLWithPath(const String* pPath) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSURL), _NS_PRIVATE_SEL(fileURLWithPath_), pPath); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::URL::alloc() +{ + return Object::alloc(_NS_PRIVATE_CLS(NSURL)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::URL::init() +{ + return Object::init(); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::URL::init(const String* pString) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithString_), pString); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::URL::initFileURLWithPath(const String* pPath) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initFileURLWithPath_), pPath); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE const char* NS::URL::fileSystemRepresentation() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(fileSystemRepresentation)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cmlx/metal-cpp/LICENSE.txt b/Source/Cxxmlx/metal-cpp/LICENSE.txt similarity index 100% rename from Source/Cmlx/metal-cpp/LICENSE.txt rename to Source/Cxxmlx/metal-cpp/LICENSE.txt diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4AccelerationStructure.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4AccelerationStructure.hpp new file mode 100644 index 00000000..11540150 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4AccelerationStructure.hpp @@ -0,0 +1,1395 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4AccelerationStructure.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLAccelerationStructure.hpp" +#include "MTLAccelerationStructureTypes.hpp" +#include "MTLArgument.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLStageInputOutputDescriptor.hpp" + +namespace MTL4 +{ +class AccelerationStructureBoundingBoxGeometryDescriptor; +class AccelerationStructureCurveGeometryDescriptor; +class AccelerationStructureDescriptor; +class AccelerationStructureGeometryDescriptor; +class AccelerationStructureMotionBoundingBoxGeometryDescriptor; +class AccelerationStructureMotionCurveGeometryDescriptor; +class AccelerationStructureMotionTriangleGeometryDescriptor; +class AccelerationStructureTriangleGeometryDescriptor; +class IndirectInstanceAccelerationStructureDescriptor; +class InstanceAccelerationStructureDescriptor; +class PrimitiveAccelerationStructureDescriptor; + +class AccelerationStructureDescriptor : public NS::Copying +{ +public: + static AccelerationStructureDescriptor* alloc(); + + AccelerationStructureDescriptor* init(); +}; +class AccelerationStructureGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureGeometryDescriptor* alloc(); + + bool allowDuplicateIntersectionFunctionInvocation() const; + + AccelerationStructureGeometryDescriptor* init(); + + NS::UInteger intersectionFunctionTableOffset() const; + + NS::String* label() const; + + bool opaque() const; + + BufferRange primitiveDataBuffer() const; + + NS::UInteger primitiveDataElementSize() const; + + NS::UInteger primitiveDataStride() const; + + void setAllowDuplicateIntersectionFunctionInvocation(bool allowDuplicateIntersectionFunctionInvocation); + + void setIntersectionFunctionTableOffset(NS::UInteger intersectionFunctionTableOffset); + + void setLabel(const NS::String* label); + + void setOpaque(bool opaque); + + void setPrimitiveDataBuffer(const MTL4::BufferRange primitiveDataBuffer); + + void setPrimitiveDataElementSize(NS::UInteger primitiveDataElementSize); + + void setPrimitiveDataStride(NS::UInteger primitiveDataStride); +}; +class PrimitiveAccelerationStructureDescriptor : public NS::Copying +{ +public: + static PrimitiveAccelerationStructureDescriptor* alloc(); + + NS::Array* geometryDescriptors() const; + + PrimitiveAccelerationStructureDescriptor* init(); + + MTL::MotionBorderMode motionEndBorderMode() const; + + float motionEndTime() const; + + NS::UInteger motionKeyframeCount() const; + + MTL::MotionBorderMode motionStartBorderMode() const; + + float motionStartTime() const; + + void setGeometryDescriptors(const NS::Array* geometryDescriptors); + + void setMotionEndBorderMode(MTL::MotionBorderMode motionEndBorderMode); + + void setMotionEndTime(float motionEndTime); + + void setMotionKeyframeCount(NS::UInteger motionKeyframeCount); + + void setMotionStartBorderMode(MTL::MotionBorderMode motionStartBorderMode); + + void setMotionStartTime(float motionStartTime); +}; +class AccelerationStructureTriangleGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureTriangleGeometryDescriptor* alloc(); + + BufferRange indexBuffer() const; + + MTL::IndexType indexType() const; + + AccelerationStructureTriangleGeometryDescriptor* init(); + + void setIndexBuffer(const MTL4::BufferRange indexBuffer); + + void setIndexType(MTL::IndexType indexType); + + void setTransformationMatrixBuffer(const MTL4::BufferRange transformationMatrixBuffer); + + void setTransformationMatrixLayout(MTL::MatrixLayout transformationMatrixLayout); + + void setTriangleCount(NS::UInteger triangleCount); + + void setVertexBuffer(const MTL4::BufferRange vertexBuffer); + + void setVertexFormat(MTL::AttributeFormat vertexFormat); + + void setVertexStride(NS::UInteger vertexStride); + + BufferRange transformationMatrixBuffer() const; + + MTL::MatrixLayout transformationMatrixLayout() const; + + NS::UInteger triangleCount() const; + + BufferRange vertexBuffer() const; + + MTL::AttributeFormat vertexFormat() const; + + NS::UInteger vertexStride() const; +}; +class AccelerationStructureBoundingBoxGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureBoundingBoxGeometryDescriptor* alloc(); + + BufferRange boundingBoxBuffer() const; + + NS::UInteger boundingBoxCount() const; + + NS::UInteger boundingBoxStride() const; + + AccelerationStructureBoundingBoxGeometryDescriptor* init(); + + void setBoundingBoxBuffer(const MTL4::BufferRange boundingBoxBuffer); + + void setBoundingBoxCount(NS::UInteger boundingBoxCount); + + void setBoundingBoxStride(NS::UInteger boundingBoxStride); +}; +class AccelerationStructureMotionTriangleGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureMotionTriangleGeometryDescriptor* alloc(); + + BufferRange indexBuffer() const; + + MTL::IndexType indexType() const; + + AccelerationStructureMotionTriangleGeometryDescriptor* init(); + + void setIndexBuffer(const MTL4::BufferRange indexBuffer); + + void setIndexType(MTL::IndexType indexType); + + void setTransformationMatrixBuffer(const MTL4::BufferRange transformationMatrixBuffer); + + void setTransformationMatrixLayout(MTL::MatrixLayout transformationMatrixLayout); + + void setTriangleCount(NS::UInteger triangleCount); + + void setVertexBuffers(const MTL4::BufferRange vertexBuffers); + + void setVertexFormat(MTL::AttributeFormat vertexFormat); + + void setVertexStride(NS::UInteger vertexStride); + + BufferRange transformationMatrixBuffer() const; + + MTL::MatrixLayout transformationMatrixLayout() const; + + NS::UInteger triangleCount() const; + + BufferRange vertexBuffers() const; + + MTL::AttributeFormat vertexFormat() const; + + NS::UInteger vertexStride() const; +}; +class AccelerationStructureMotionBoundingBoxGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureMotionBoundingBoxGeometryDescriptor* alloc(); + + BufferRange boundingBoxBuffers() const; + + NS::UInteger boundingBoxCount() const; + + NS::UInteger boundingBoxStride() const; + + AccelerationStructureMotionBoundingBoxGeometryDescriptor* init(); + + void setBoundingBoxBuffers(const MTL4::BufferRange boundingBoxBuffers); + + void setBoundingBoxCount(NS::UInteger boundingBoxCount); + + void setBoundingBoxStride(NS::UInteger boundingBoxStride); +}; +class AccelerationStructureCurveGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureCurveGeometryDescriptor* alloc(); + + BufferRange controlPointBuffer() const; + + NS::UInteger controlPointCount() const; + + MTL::AttributeFormat controlPointFormat() const; + + NS::UInteger controlPointStride() const; + + MTL::CurveBasis curveBasis() const; + + MTL::CurveEndCaps curveEndCaps() const; + + MTL::CurveType curveType() const; + + BufferRange indexBuffer() const; + + MTL::IndexType indexType() const; + + AccelerationStructureCurveGeometryDescriptor* init(); + + BufferRange radiusBuffer() const; + + MTL::AttributeFormat radiusFormat() const; + + NS::UInteger radiusStride() const; + + NS::UInteger segmentControlPointCount() const; + + NS::UInteger segmentCount() const; + + void setControlPointBuffer(const MTL4::BufferRange controlPointBuffer); + + void setControlPointCount(NS::UInteger controlPointCount); + + void setControlPointFormat(MTL::AttributeFormat controlPointFormat); + + void setControlPointStride(NS::UInteger controlPointStride); + + void setCurveBasis(MTL::CurveBasis curveBasis); + + void setCurveEndCaps(MTL::CurveEndCaps curveEndCaps); + + void setCurveType(MTL::CurveType curveType); + + void setIndexBuffer(const MTL4::BufferRange indexBuffer); + + void setIndexType(MTL::IndexType indexType); + + void setRadiusBuffer(const MTL4::BufferRange radiusBuffer); + + void setRadiusFormat(MTL::AttributeFormat radiusFormat); + + void setRadiusStride(NS::UInteger radiusStride); + + void setSegmentControlPointCount(NS::UInteger segmentControlPointCount); + + void setSegmentCount(NS::UInteger segmentCount); +}; +class AccelerationStructureMotionCurveGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureMotionCurveGeometryDescriptor* alloc(); + + BufferRange controlPointBuffers() const; + + NS::UInteger controlPointCount() const; + + MTL::AttributeFormat controlPointFormat() const; + + NS::UInteger controlPointStride() const; + + MTL::CurveBasis curveBasis() const; + + MTL::CurveEndCaps curveEndCaps() const; + + MTL::CurveType curveType() const; + + BufferRange indexBuffer() const; + + MTL::IndexType indexType() const; + + AccelerationStructureMotionCurveGeometryDescriptor* init(); + + BufferRange radiusBuffers() const; + + MTL::AttributeFormat radiusFormat() const; + + NS::UInteger radiusStride() const; + + NS::UInteger segmentControlPointCount() const; + + NS::UInteger segmentCount() const; + + void setControlPointBuffers(const MTL4::BufferRange controlPointBuffers); + + void setControlPointCount(NS::UInteger controlPointCount); + + void setControlPointFormat(MTL::AttributeFormat controlPointFormat); + + void setControlPointStride(NS::UInteger controlPointStride); + + void setCurveBasis(MTL::CurveBasis curveBasis); + + void setCurveEndCaps(MTL::CurveEndCaps curveEndCaps); + + void setCurveType(MTL::CurveType curveType); + + void setIndexBuffer(const MTL4::BufferRange indexBuffer); + + void setIndexType(MTL::IndexType indexType); + + void setRadiusBuffers(const MTL4::BufferRange radiusBuffers); + + void setRadiusFormat(MTL::AttributeFormat radiusFormat); + + void setRadiusStride(NS::UInteger radiusStride); + + void setSegmentControlPointCount(NS::UInteger segmentControlPointCount); + + void setSegmentCount(NS::UInteger segmentCount); +}; +class InstanceAccelerationStructureDescriptor : public NS::Copying +{ +public: + static InstanceAccelerationStructureDescriptor* alloc(); + + InstanceAccelerationStructureDescriptor* init(); + + NS::UInteger instanceCount() const; + + BufferRange instanceDescriptorBuffer() const; + + NS::UInteger instanceDescriptorStride() const; + + MTL::AccelerationStructureInstanceDescriptorType instanceDescriptorType() const; + + MTL::MatrixLayout instanceTransformationMatrixLayout() const; + + BufferRange motionTransformBuffer() const; + + NS::UInteger motionTransformCount() const; + + NS::UInteger motionTransformStride() const; + + MTL::TransformType motionTransformType() const; + + void setInstanceCount(NS::UInteger instanceCount); + + void setInstanceDescriptorBuffer(const MTL4::BufferRange instanceDescriptorBuffer); + + void setInstanceDescriptorStride(NS::UInteger instanceDescriptorStride); + + void setInstanceDescriptorType(MTL::AccelerationStructureInstanceDescriptorType instanceDescriptorType); + + void setInstanceTransformationMatrixLayout(MTL::MatrixLayout instanceTransformationMatrixLayout); + + void setMotionTransformBuffer(const MTL4::BufferRange motionTransformBuffer); + + void setMotionTransformCount(NS::UInteger motionTransformCount); + + void setMotionTransformStride(NS::UInteger motionTransformStride); + + void setMotionTransformType(MTL::TransformType motionTransformType); +}; +class IndirectInstanceAccelerationStructureDescriptor : public NS::Copying +{ +public: + static IndirectInstanceAccelerationStructureDescriptor* alloc(); + + IndirectInstanceAccelerationStructureDescriptor* init(); + + BufferRange instanceCountBuffer() const; + + BufferRange instanceDescriptorBuffer() const; + + NS::UInteger instanceDescriptorStride() const; + + MTL::AccelerationStructureInstanceDescriptorType instanceDescriptorType() const; + + MTL::MatrixLayout instanceTransformationMatrixLayout() const; + + NS::UInteger maxInstanceCount() const; + + NS::UInteger maxMotionTransformCount() const; + + BufferRange motionTransformBuffer() const; + + BufferRange motionTransformCountBuffer() const; + + NS::UInteger motionTransformStride() const; + + MTL::TransformType motionTransformType() const; + + void setInstanceCountBuffer(const MTL4::BufferRange instanceCountBuffer); + + void setInstanceDescriptorBuffer(const MTL4::BufferRange instanceDescriptorBuffer); + + void setInstanceDescriptorStride(NS::UInteger instanceDescriptorStride); + + void setInstanceDescriptorType(MTL::AccelerationStructureInstanceDescriptorType instanceDescriptorType); + + void setInstanceTransformationMatrixLayout(MTL::MatrixLayout instanceTransformationMatrixLayout); + + void setMaxInstanceCount(NS::UInteger maxInstanceCount); + + void setMaxMotionTransformCount(NS::UInteger maxMotionTransformCount); + + void setMotionTransformBuffer(const MTL4::BufferRange motionTransformBuffer); + + void setMotionTransformCountBuffer(const MTL4::BufferRange motionTransformCountBuffer); + + void setMotionTransformStride(NS::UInteger motionTransformStride); + + void setMotionTransformType(MTL::TransformType motionTransformType); +}; + +} +_MTL_INLINE MTL4::AccelerationStructureDescriptor* MTL4::AccelerationStructureDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4AccelerationStructureDescriptor)); +} + +_MTL_INLINE MTL4::AccelerationStructureDescriptor* MTL4::AccelerationStructureDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL4::AccelerationStructureGeometryDescriptor* MTL4::AccelerationStructureGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4AccelerationStructureGeometryDescriptor)); +} + +_MTL_INLINE bool MTL4::AccelerationStructureGeometryDescriptor::allowDuplicateIntersectionFunctionInvocation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allowDuplicateIntersectionFunctionInvocation)); +} + +_MTL_INLINE MTL4::AccelerationStructureGeometryDescriptor* MTL4::AccelerationStructureGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureGeometryDescriptor::intersectionFunctionTableOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(intersectionFunctionTableOffset)); +} + +_MTL_INLINE NS::String* MTL4::AccelerationStructureGeometryDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE bool MTL4::AccelerationStructureGeometryDescriptor::opaque() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(opaque)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureGeometryDescriptor::primitiveDataBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(primitiveDataBuffer)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureGeometryDescriptor::primitiveDataElementSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(primitiveDataElementSize)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureGeometryDescriptor::primitiveDataStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(primitiveDataStride)); +} + +_MTL_INLINE void MTL4::AccelerationStructureGeometryDescriptor::setAllowDuplicateIntersectionFunctionInvocation(bool allowDuplicateIntersectionFunctionInvocation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAllowDuplicateIntersectionFunctionInvocation_), allowDuplicateIntersectionFunctionInvocation); +} + +_MTL_INLINE void MTL4::AccelerationStructureGeometryDescriptor::setIntersectionFunctionTableOffset(NS::UInteger intersectionFunctionTableOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIntersectionFunctionTableOffset_), intersectionFunctionTableOffset); +} + +_MTL_INLINE void MTL4::AccelerationStructureGeometryDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL4::AccelerationStructureGeometryDescriptor::setOpaque(bool opaque) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOpaque_), opaque); +} + +_MTL_INLINE void MTL4::AccelerationStructureGeometryDescriptor::setPrimitiveDataBuffer(const MTL4::BufferRange primitiveDataBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPrimitiveDataBuffer_), primitiveDataBuffer); +} + +_MTL_INLINE void MTL4::AccelerationStructureGeometryDescriptor::setPrimitiveDataElementSize(NS::UInteger primitiveDataElementSize) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPrimitiveDataElementSize_), primitiveDataElementSize); +} + +_MTL_INLINE void MTL4::AccelerationStructureGeometryDescriptor::setPrimitiveDataStride(NS::UInteger primitiveDataStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPrimitiveDataStride_), primitiveDataStride); +} + +_MTL_INLINE MTL4::PrimitiveAccelerationStructureDescriptor* MTL4::PrimitiveAccelerationStructureDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4PrimitiveAccelerationStructureDescriptor)); +} + +_MTL_INLINE NS::Array* MTL4::PrimitiveAccelerationStructureDescriptor::geometryDescriptors() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(geometryDescriptors)); +} + +_MTL_INLINE MTL4::PrimitiveAccelerationStructureDescriptor* MTL4::PrimitiveAccelerationStructureDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::MotionBorderMode MTL4::PrimitiveAccelerationStructureDescriptor::motionEndBorderMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionEndBorderMode)); +} + +_MTL_INLINE float MTL4::PrimitiveAccelerationStructureDescriptor::motionEndTime() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionEndTime)); +} + +_MTL_INLINE NS::UInteger MTL4::PrimitiveAccelerationStructureDescriptor::motionKeyframeCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionKeyframeCount)); +} + +_MTL_INLINE MTL::MotionBorderMode MTL4::PrimitiveAccelerationStructureDescriptor::motionStartBorderMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionStartBorderMode)); +} + +_MTL_INLINE float MTL4::PrimitiveAccelerationStructureDescriptor::motionStartTime() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionStartTime)); +} + +_MTL_INLINE void MTL4::PrimitiveAccelerationStructureDescriptor::setGeometryDescriptors(const NS::Array* geometryDescriptors) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setGeometryDescriptors_), geometryDescriptors); +} + +_MTL_INLINE void MTL4::PrimitiveAccelerationStructureDescriptor::setMotionEndBorderMode(MTL::MotionBorderMode motionEndBorderMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionEndBorderMode_), motionEndBorderMode); +} + +_MTL_INLINE void MTL4::PrimitiveAccelerationStructureDescriptor::setMotionEndTime(float motionEndTime) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionEndTime_), motionEndTime); +} + +_MTL_INLINE void MTL4::PrimitiveAccelerationStructureDescriptor::setMotionKeyframeCount(NS::UInteger motionKeyframeCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionKeyframeCount_), motionKeyframeCount); +} + +_MTL_INLINE void MTL4::PrimitiveAccelerationStructureDescriptor::setMotionStartBorderMode(MTL::MotionBorderMode motionStartBorderMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionStartBorderMode_), motionStartBorderMode); +} + +_MTL_INLINE void MTL4::PrimitiveAccelerationStructureDescriptor::setMotionStartTime(float motionStartTime) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionStartTime_), motionStartTime); +} + +_MTL_INLINE MTL4::AccelerationStructureTriangleGeometryDescriptor* MTL4::AccelerationStructureTriangleGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4AccelerationStructureTriangleGeometryDescriptor)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureTriangleGeometryDescriptor::indexBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBuffer)); +} + +_MTL_INLINE MTL::IndexType MTL4::AccelerationStructureTriangleGeometryDescriptor::indexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE MTL4::AccelerationStructureTriangleGeometryDescriptor* MTL4::AccelerationStructureTriangleGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL4::AccelerationStructureTriangleGeometryDescriptor::setIndexBuffer(const MTL4::BufferRange indexBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBuffer_), indexBuffer); +} + +_MTL_INLINE void MTL4::AccelerationStructureTriangleGeometryDescriptor::setIndexType(MTL::IndexType indexType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexType_), indexType); +} + +_MTL_INLINE void MTL4::AccelerationStructureTriangleGeometryDescriptor::setTransformationMatrixBuffer(const MTL4::BufferRange transformationMatrixBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixBuffer_), transformationMatrixBuffer); +} + +_MTL_INLINE void MTL4::AccelerationStructureTriangleGeometryDescriptor::setTransformationMatrixLayout(MTL::MatrixLayout transformationMatrixLayout) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixLayout_), transformationMatrixLayout); +} + +_MTL_INLINE void MTL4::AccelerationStructureTriangleGeometryDescriptor::setTriangleCount(NS::UInteger triangleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTriangleCount_), triangleCount); +} + +_MTL_INLINE void MTL4::AccelerationStructureTriangleGeometryDescriptor::setVertexBuffer(const MTL4::BufferRange vertexBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBuffer_), vertexBuffer); +} + +_MTL_INLINE void MTL4::AccelerationStructureTriangleGeometryDescriptor::setVertexFormat(MTL::AttributeFormat vertexFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexFormat_), vertexFormat); +} + +_MTL_INLINE void MTL4::AccelerationStructureTriangleGeometryDescriptor::setVertexStride(NS::UInteger vertexStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexStride_), vertexStride); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureTriangleGeometryDescriptor::transformationMatrixBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixBuffer)); +} + +_MTL_INLINE MTL::MatrixLayout MTL4::AccelerationStructureTriangleGeometryDescriptor::transformationMatrixLayout() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixLayout)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureTriangleGeometryDescriptor::triangleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(triangleCount)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureTriangleGeometryDescriptor::vertexBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexBuffer)); +} + +_MTL_INLINE MTL::AttributeFormat MTL4::AccelerationStructureTriangleGeometryDescriptor::vertexFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexFormat)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureTriangleGeometryDescriptor::vertexStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexStride)); +} + +_MTL_INLINE MTL4::AccelerationStructureBoundingBoxGeometryDescriptor* MTL4::AccelerationStructureBoundingBoxGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4AccelerationStructureBoundingBoxGeometryDescriptor)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureBoundingBoxGeometryDescriptor::boundingBoxBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxBuffer)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureBoundingBoxGeometryDescriptor::boundingBoxCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxCount)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureBoundingBoxGeometryDescriptor::boundingBoxStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxStride)); +} + +_MTL_INLINE MTL4::AccelerationStructureBoundingBoxGeometryDescriptor* MTL4::AccelerationStructureBoundingBoxGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL4::AccelerationStructureBoundingBoxGeometryDescriptor::setBoundingBoxBuffer(const MTL4::BufferRange boundingBoxBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxBuffer_), boundingBoxBuffer); +} + +_MTL_INLINE void MTL4::AccelerationStructureBoundingBoxGeometryDescriptor::setBoundingBoxCount(NS::UInteger boundingBoxCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxCount_), boundingBoxCount); +} + +_MTL_INLINE void MTL4::AccelerationStructureBoundingBoxGeometryDescriptor::setBoundingBoxStride(NS::UInteger boundingBoxStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxStride_), boundingBoxStride); +} + +_MTL_INLINE MTL4::AccelerationStructureMotionTriangleGeometryDescriptor* MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4AccelerationStructureMotionTriangleGeometryDescriptor)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::indexBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBuffer)); +} + +_MTL_INLINE MTL::IndexType MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::indexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE MTL4::AccelerationStructureMotionTriangleGeometryDescriptor* MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::setIndexBuffer(const MTL4::BufferRange indexBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBuffer_), indexBuffer); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::setIndexType(MTL::IndexType indexType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexType_), indexType); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::setTransformationMatrixBuffer(const MTL4::BufferRange transformationMatrixBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixBuffer_), transformationMatrixBuffer); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::setTransformationMatrixLayout(MTL::MatrixLayout transformationMatrixLayout) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixLayout_), transformationMatrixLayout); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::setTriangleCount(NS::UInteger triangleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTriangleCount_), triangleCount); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::setVertexBuffers(const MTL4::BufferRange vertexBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBuffers_), vertexBuffers); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::setVertexFormat(MTL::AttributeFormat vertexFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexFormat_), vertexFormat); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::setVertexStride(NS::UInteger vertexStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexStride_), vertexStride); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::transformationMatrixBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixBuffer)); +} + +_MTL_INLINE MTL::MatrixLayout MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::transformationMatrixLayout() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixLayout)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::triangleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(triangleCount)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::vertexBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexBuffers)); +} + +_MTL_INLINE MTL::AttributeFormat MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::vertexFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexFormat)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::vertexStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexStride)); +} + +_MTL_INLINE MTL4::AccelerationStructureMotionBoundingBoxGeometryDescriptor* MTL4::AccelerationStructureMotionBoundingBoxGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4AccelerationStructureMotionBoundingBoxGeometryDescriptor)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureMotionBoundingBoxGeometryDescriptor::boundingBoxBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxBuffers)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureMotionBoundingBoxGeometryDescriptor::boundingBoxCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxCount)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureMotionBoundingBoxGeometryDescriptor::boundingBoxStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxStride)); +} + +_MTL_INLINE MTL4::AccelerationStructureMotionBoundingBoxGeometryDescriptor* MTL4::AccelerationStructureMotionBoundingBoxGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionBoundingBoxGeometryDescriptor::setBoundingBoxBuffers(const MTL4::BufferRange boundingBoxBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxBuffers_), boundingBoxBuffers); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionBoundingBoxGeometryDescriptor::setBoundingBoxCount(NS::UInteger boundingBoxCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxCount_), boundingBoxCount); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionBoundingBoxGeometryDescriptor::setBoundingBoxStride(NS::UInteger boundingBoxStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxStride_), boundingBoxStride); +} + +_MTL_INLINE MTL4::AccelerationStructureCurveGeometryDescriptor* MTL4::AccelerationStructureCurveGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4AccelerationStructureCurveGeometryDescriptor)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureCurveGeometryDescriptor::controlPointBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointBuffer)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureCurveGeometryDescriptor::controlPointCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointCount)); +} + +_MTL_INLINE MTL::AttributeFormat MTL4::AccelerationStructureCurveGeometryDescriptor::controlPointFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointFormat)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureCurveGeometryDescriptor::controlPointStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointStride)); +} + +_MTL_INLINE MTL::CurveBasis MTL4::AccelerationStructureCurveGeometryDescriptor::curveBasis() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveBasis)); +} + +_MTL_INLINE MTL::CurveEndCaps MTL4::AccelerationStructureCurveGeometryDescriptor::curveEndCaps() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveEndCaps)); +} + +_MTL_INLINE MTL::CurveType MTL4::AccelerationStructureCurveGeometryDescriptor::curveType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveType)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureCurveGeometryDescriptor::indexBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBuffer)); +} + +_MTL_INLINE MTL::IndexType MTL4::AccelerationStructureCurveGeometryDescriptor::indexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE MTL4::AccelerationStructureCurveGeometryDescriptor* MTL4::AccelerationStructureCurveGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureCurveGeometryDescriptor::radiusBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusBuffer)); +} + +_MTL_INLINE MTL::AttributeFormat MTL4::AccelerationStructureCurveGeometryDescriptor::radiusFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusFormat)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureCurveGeometryDescriptor::radiusStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusStride)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureCurveGeometryDescriptor::segmentControlPointCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(segmentControlPointCount)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureCurveGeometryDescriptor::segmentCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(segmentCount)); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setControlPointBuffer(const MTL4::BufferRange controlPointBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointBuffer_), controlPointBuffer); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setControlPointCount(NS::UInteger controlPointCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointCount_), controlPointCount); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setControlPointFormat(MTL::AttributeFormat controlPointFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointFormat_), controlPointFormat); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setControlPointStride(NS::UInteger controlPointStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointStride_), controlPointStride); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setCurveBasis(MTL::CurveBasis curveBasis) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveBasis_), curveBasis); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setCurveEndCaps(MTL::CurveEndCaps curveEndCaps) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveEndCaps_), curveEndCaps); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setCurveType(MTL::CurveType curveType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveType_), curveType); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setIndexBuffer(const MTL4::BufferRange indexBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBuffer_), indexBuffer); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setIndexType(MTL::IndexType indexType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexType_), indexType); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setRadiusBuffer(const MTL4::BufferRange radiusBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusBuffer_), radiusBuffer); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setRadiusFormat(MTL::AttributeFormat radiusFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusFormat_), radiusFormat); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setRadiusStride(NS::UInteger radiusStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusStride_), radiusStride); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setSegmentControlPointCount(NS::UInteger segmentControlPointCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSegmentControlPointCount_), segmentControlPointCount); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setSegmentCount(NS::UInteger segmentCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSegmentCount_), segmentCount); +} + +_MTL_INLINE MTL4::AccelerationStructureMotionCurveGeometryDescriptor* MTL4::AccelerationStructureMotionCurveGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4AccelerationStructureMotionCurveGeometryDescriptor)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureMotionCurveGeometryDescriptor::controlPointBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointBuffers)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureMotionCurveGeometryDescriptor::controlPointCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointCount)); +} + +_MTL_INLINE MTL::AttributeFormat MTL4::AccelerationStructureMotionCurveGeometryDescriptor::controlPointFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointFormat)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureMotionCurveGeometryDescriptor::controlPointStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointStride)); +} + +_MTL_INLINE MTL::CurveBasis MTL4::AccelerationStructureMotionCurveGeometryDescriptor::curveBasis() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveBasis)); +} + +_MTL_INLINE MTL::CurveEndCaps MTL4::AccelerationStructureMotionCurveGeometryDescriptor::curveEndCaps() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveEndCaps)); +} + +_MTL_INLINE MTL::CurveType MTL4::AccelerationStructureMotionCurveGeometryDescriptor::curveType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveType)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureMotionCurveGeometryDescriptor::indexBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBuffer)); +} + +_MTL_INLINE MTL::IndexType MTL4::AccelerationStructureMotionCurveGeometryDescriptor::indexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE MTL4::AccelerationStructureMotionCurveGeometryDescriptor* MTL4::AccelerationStructureMotionCurveGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureMotionCurveGeometryDescriptor::radiusBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusBuffers)); +} + +_MTL_INLINE MTL::AttributeFormat MTL4::AccelerationStructureMotionCurveGeometryDescriptor::radiusFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusFormat)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureMotionCurveGeometryDescriptor::radiusStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusStride)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureMotionCurveGeometryDescriptor::segmentControlPointCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(segmentControlPointCount)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureMotionCurveGeometryDescriptor::segmentCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(segmentCount)); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setControlPointBuffers(const MTL4::BufferRange controlPointBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointBuffers_), controlPointBuffers); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setControlPointCount(NS::UInteger controlPointCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointCount_), controlPointCount); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setControlPointFormat(MTL::AttributeFormat controlPointFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointFormat_), controlPointFormat); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setControlPointStride(NS::UInteger controlPointStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointStride_), controlPointStride); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setCurveBasis(MTL::CurveBasis curveBasis) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveBasis_), curveBasis); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setCurveEndCaps(MTL::CurveEndCaps curveEndCaps) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveEndCaps_), curveEndCaps); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setCurveType(MTL::CurveType curveType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveType_), curveType); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setIndexBuffer(const MTL4::BufferRange indexBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBuffer_), indexBuffer); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setIndexType(MTL::IndexType indexType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexType_), indexType); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setRadiusBuffers(const MTL4::BufferRange radiusBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusBuffers_), radiusBuffers); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setRadiusFormat(MTL::AttributeFormat radiusFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusFormat_), radiusFormat); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setRadiusStride(NS::UInteger radiusStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusStride_), radiusStride); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setSegmentControlPointCount(NS::UInteger segmentControlPointCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSegmentControlPointCount_), segmentControlPointCount); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setSegmentCount(NS::UInteger segmentCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSegmentCount_), segmentCount); +} + +_MTL_INLINE MTL4::InstanceAccelerationStructureDescriptor* MTL4::InstanceAccelerationStructureDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4InstanceAccelerationStructureDescriptor)); +} + +_MTL_INLINE MTL4::InstanceAccelerationStructureDescriptor* MTL4::InstanceAccelerationStructureDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL4::InstanceAccelerationStructureDescriptor::instanceCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceCount)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::InstanceAccelerationStructureDescriptor::instanceDescriptorBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorBuffer)); +} + +_MTL_INLINE NS::UInteger MTL4::InstanceAccelerationStructureDescriptor::instanceDescriptorStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorStride)); +} + +_MTL_INLINE MTL::AccelerationStructureInstanceDescriptorType MTL4::InstanceAccelerationStructureDescriptor::instanceDescriptorType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorType)); +} + +_MTL_INLINE MTL::MatrixLayout MTL4::InstanceAccelerationStructureDescriptor::instanceTransformationMatrixLayout() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceTransformationMatrixLayout)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::InstanceAccelerationStructureDescriptor::motionTransformBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformBuffer)); +} + +_MTL_INLINE NS::UInteger MTL4::InstanceAccelerationStructureDescriptor::motionTransformCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformCount)); +} + +_MTL_INLINE NS::UInteger MTL4::InstanceAccelerationStructureDescriptor::motionTransformStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformStride)); +} + +_MTL_INLINE MTL::TransformType MTL4::InstanceAccelerationStructureDescriptor::motionTransformType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformType)); +} + +_MTL_INLINE void MTL4::InstanceAccelerationStructureDescriptor::setInstanceCount(NS::UInteger instanceCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceCount_), instanceCount); +} + +_MTL_INLINE void MTL4::InstanceAccelerationStructureDescriptor::setInstanceDescriptorBuffer(const MTL4::BufferRange instanceDescriptorBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorBuffer_), instanceDescriptorBuffer); +} + +_MTL_INLINE void MTL4::InstanceAccelerationStructureDescriptor::setInstanceDescriptorStride(NS::UInteger instanceDescriptorStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorStride_), instanceDescriptorStride); +} + +_MTL_INLINE void MTL4::InstanceAccelerationStructureDescriptor::setInstanceDescriptorType(MTL::AccelerationStructureInstanceDescriptorType instanceDescriptorType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorType_), instanceDescriptorType); +} + +_MTL_INLINE void MTL4::InstanceAccelerationStructureDescriptor::setInstanceTransformationMatrixLayout(MTL::MatrixLayout instanceTransformationMatrixLayout) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceTransformationMatrixLayout_), instanceTransformationMatrixLayout); +} + +_MTL_INLINE void MTL4::InstanceAccelerationStructureDescriptor::setMotionTransformBuffer(const MTL4::BufferRange motionTransformBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformBuffer_), motionTransformBuffer); +} + +_MTL_INLINE void MTL4::InstanceAccelerationStructureDescriptor::setMotionTransformCount(NS::UInteger motionTransformCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformCount_), motionTransformCount); +} + +_MTL_INLINE void MTL4::InstanceAccelerationStructureDescriptor::setMotionTransformStride(NS::UInteger motionTransformStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformStride_), motionTransformStride); +} + +_MTL_INLINE void MTL4::InstanceAccelerationStructureDescriptor::setMotionTransformType(MTL::TransformType motionTransformType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformType_), motionTransformType); +} + +_MTL_INLINE MTL4::IndirectInstanceAccelerationStructureDescriptor* MTL4::IndirectInstanceAccelerationStructureDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4IndirectInstanceAccelerationStructureDescriptor)); +} + +_MTL_INLINE MTL4::IndirectInstanceAccelerationStructureDescriptor* MTL4::IndirectInstanceAccelerationStructureDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL4::BufferRange MTL4::IndirectInstanceAccelerationStructureDescriptor::instanceCountBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceCountBuffer)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::IndirectInstanceAccelerationStructureDescriptor::instanceDescriptorBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorBuffer)); +} + +_MTL_INLINE NS::UInteger MTL4::IndirectInstanceAccelerationStructureDescriptor::instanceDescriptorStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorStride)); +} + +_MTL_INLINE MTL::AccelerationStructureInstanceDescriptorType MTL4::IndirectInstanceAccelerationStructureDescriptor::instanceDescriptorType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorType)); +} + +_MTL_INLINE MTL::MatrixLayout MTL4::IndirectInstanceAccelerationStructureDescriptor::instanceTransformationMatrixLayout() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceTransformationMatrixLayout)); +} + +_MTL_INLINE NS::UInteger MTL4::IndirectInstanceAccelerationStructureDescriptor::maxInstanceCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxInstanceCount)); +} + +_MTL_INLINE NS::UInteger MTL4::IndirectInstanceAccelerationStructureDescriptor::maxMotionTransformCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxMotionTransformCount)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::IndirectInstanceAccelerationStructureDescriptor::motionTransformBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformBuffer)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::IndirectInstanceAccelerationStructureDescriptor::motionTransformCountBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformCountBuffer)); +} + +_MTL_INLINE NS::UInteger MTL4::IndirectInstanceAccelerationStructureDescriptor::motionTransformStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformStride)); +} + +_MTL_INLINE MTL::TransformType MTL4::IndirectInstanceAccelerationStructureDescriptor::motionTransformType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformType)); +} + +_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setInstanceCountBuffer(const MTL4::BufferRange instanceCountBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceCountBuffer_), instanceCountBuffer); +} + +_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setInstanceDescriptorBuffer(const MTL4::BufferRange instanceDescriptorBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorBuffer_), instanceDescriptorBuffer); +} + +_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setInstanceDescriptorStride(NS::UInteger instanceDescriptorStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorStride_), instanceDescriptorStride); +} + +_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setInstanceDescriptorType(MTL::AccelerationStructureInstanceDescriptorType instanceDescriptorType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorType_), instanceDescriptorType); +} + +_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setInstanceTransformationMatrixLayout(MTL::MatrixLayout instanceTransformationMatrixLayout) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceTransformationMatrixLayout_), instanceTransformationMatrixLayout); +} + +_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setMaxInstanceCount(NS::UInteger maxInstanceCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxInstanceCount_), maxInstanceCount); +} + +_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setMaxMotionTransformCount(NS::UInteger maxMotionTransformCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxMotionTransformCount_), maxMotionTransformCount); +} + +_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformBuffer(const MTL4::BufferRange motionTransformBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformBuffer_), motionTransformBuffer); +} + +_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformCountBuffer(const MTL4::BufferRange motionTransformCountBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformCountBuffer_), motionTransformCountBuffer); +} + +_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformStride(NS::UInteger motionTransformStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformStride_), motionTransformStride); +} + +_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformType(MTL::TransformType motionTransformType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformType_), motionTransformType); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4Archive.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4Archive.hpp new file mode 100644 index 00000000..c83ef638 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4Archive.hpp @@ -0,0 +1,93 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4Archive.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class ComputePipelineState; +class RenderPipelineState; +} + +namespace MTL4 +{ +class BinaryFunction; +class BinaryFunctionDescriptor; +class ComputePipelineDescriptor; +class PipelineDescriptor; +class PipelineStageDynamicLinkingDescriptor; +class RenderPipelineDynamicLinkingDescriptor; + +class Archive : public NS::Referencing +{ +public: + NS::String* label() const; + + BinaryFunction* newBinaryFunction(const MTL4::BinaryFunctionDescriptor* descriptor, NS::Error** error); + + MTL::ComputePipelineState* newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, NS::Error** error); + MTL::ComputePipelineState* newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, const MTL4::PipelineStageDynamicLinkingDescriptor* dynamicLinkingDescriptor, NS::Error** error); + + MTL::RenderPipelineState* newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, NS::Error** error); + MTL::RenderPipelineState* newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, const MTL4::RenderPipelineDynamicLinkingDescriptor* dynamicLinkingDescriptor, NS::Error** error); + + void setLabel(const NS::String* label); +}; + +} +_MTL_INLINE NS::String* MTL4::Archive::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL4::BinaryFunction* MTL4::Archive::newBinaryFunction(const MTL4::BinaryFunctionDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newBinaryFunctionWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::ComputePipelineState* MTL4::Archive::newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::ComputePipelineState* MTL4::Archive::newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, const MTL4::PipelineStageDynamicLinkingDescriptor* dynamicLinkingDescriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithDescriptor_dynamicLinkingDescriptor_error_), descriptor, dynamicLinkingDescriptor, error); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL4::Archive::newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL4::Archive::newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, const MTL4::RenderPipelineDynamicLinkingDescriptor* dynamicLinkingDescriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithDescriptor_dynamicLinkingDescriptor_error_), descriptor, dynamicLinkingDescriptor, error); +} + +_MTL_INLINE void MTL4::Archive::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4ArgumentTable.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4ArgumentTable.hpp new file mode 100644 index 00000000..7788ed94 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4ArgumentTable.hpp @@ -0,0 +1,187 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4ArgumentTable.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLGPUAddress.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class Device; +} + +namespace MTL4 +{ +class ArgumentTableDescriptor : public NS::Copying +{ +public: + static ArgumentTableDescriptor* alloc(); + + ArgumentTableDescriptor* init(); + bool initializeBindings() const; + + NS::String* label() const; + + NS::UInteger maxBufferBindCount() const; + + NS::UInteger maxSamplerStateBindCount() const; + + NS::UInteger maxTextureBindCount() const; + + void setInitializeBindings(bool initializeBindings); + + void setLabel(const NS::String* label); + + void setMaxBufferBindCount(NS::UInteger maxBufferBindCount); + + void setMaxSamplerStateBindCount(NS::UInteger maxSamplerStateBindCount); + + void setMaxTextureBindCount(NS::UInteger maxTextureBindCount); + + void setSupportAttributeStrides(bool supportAttributeStrides); + bool supportAttributeStrides() const; +}; +class ArgumentTable : public NS::Referencing +{ +public: + MTL::Device* device() const; + + NS::String* label() const; + + void setAddress(MTL::GPUAddress gpuAddress, NS::UInteger bindingIndex); + void setAddress(MTL::GPUAddress gpuAddress, NS::UInteger stride, NS::UInteger bindingIndex); + + void setResource(MTL::ResourceID resourceID, NS::UInteger bindingIndex); + + void setSamplerState(MTL::ResourceID resourceID, NS::UInteger bindingIndex); + + void setTexture(MTL::ResourceID resourceID, NS::UInteger bindingIndex); +}; + +} +_MTL_INLINE MTL4::ArgumentTableDescriptor* MTL4::ArgumentTableDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4ArgumentTableDescriptor)); +} + +_MTL_INLINE MTL4::ArgumentTableDescriptor* MTL4::ArgumentTableDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE bool MTL4::ArgumentTableDescriptor::initializeBindings() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(initializeBindings)); +} + +_MTL_INLINE NS::String* MTL4::ArgumentTableDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::UInteger MTL4::ArgumentTableDescriptor::maxBufferBindCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxBufferBindCount)); +} + +_MTL_INLINE NS::UInteger MTL4::ArgumentTableDescriptor::maxSamplerStateBindCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxSamplerStateBindCount)); +} + +_MTL_INLINE NS::UInteger MTL4::ArgumentTableDescriptor::maxTextureBindCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTextureBindCount)); +} + +_MTL_INLINE void MTL4::ArgumentTableDescriptor::setInitializeBindings(bool initializeBindings) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInitializeBindings_), initializeBindings); +} + +_MTL_INLINE void MTL4::ArgumentTableDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL4::ArgumentTableDescriptor::setMaxBufferBindCount(NS::UInteger maxBufferBindCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxBufferBindCount_), maxBufferBindCount); +} + +_MTL_INLINE void MTL4::ArgumentTableDescriptor::setMaxSamplerStateBindCount(NS::UInteger maxSamplerStateBindCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxSamplerStateBindCount_), maxSamplerStateBindCount); +} + +_MTL_INLINE void MTL4::ArgumentTableDescriptor::setMaxTextureBindCount(NS::UInteger maxTextureBindCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTextureBindCount_), maxTextureBindCount); +} + +_MTL_INLINE void MTL4::ArgumentTableDescriptor::setSupportAttributeStrides(bool supportAttributeStrides) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportAttributeStrides_), supportAttributeStrides); +} + +_MTL_INLINE bool MTL4::ArgumentTableDescriptor::supportAttributeStrides() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportAttributeStrides)); +} + +_MTL_INLINE MTL::Device* MTL4::ArgumentTable::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL4::ArgumentTable::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL4::ArgumentTable::setAddress(MTL::GPUAddress gpuAddress, NS::UInteger bindingIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAddress_atIndex_), gpuAddress, bindingIndex); +} + +_MTL_INLINE void MTL4::ArgumentTable::setAddress(MTL::GPUAddress gpuAddress, NS::UInteger stride, NS::UInteger bindingIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAddress_attributeStride_atIndex_), gpuAddress, stride, bindingIndex); +} + +_MTL_INLINE void MTL4::ArgumentTable::setResource(MTL::ResourceID resourceID, NS::UInteger bindingIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setResource_atBufferIndex_), resourceID, bindingIndex); +} + +_MTL_INLINE void MTL4::ArgumentTable::setSamplerState(MTL::ResourceID resourceID, NS::UInteger bindingIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSamplerState_atIndex_), resourceID, bindingIndex); +} + +_MTL_INLINE void MTL4::ArgumentTable::setTexture(MTL::ResourceID resourceID, NS::UInteger bindingIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTexture_atIndex_), resourceID, bindingIndex); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4BinaryFunction.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4BinaryFunction.hpp new file mode 100644 index 00000000..30d90a6b --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4BinaryFunction.hpp @@ -0,0 +1,50 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4BinaryFunction.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLLibrary.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ + +class BinaryFunction : public NS::Referencing +{ +public: + MTL::FunctionType functionType() const; + + NS::String* name() const; +}; + +} + +_MTL_INLINE MTL::FunctionType MTL4::BinaryFunction::functionType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionType)); +} + +_MTL_INLINE NS::String* MTL4::BinaryFunction::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4BinaryFunctionDescriptor.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4BinaryFunctionDescriptor.hpp new file mode 100644 index 00000000..ce173ce0 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4BinaryFunctionDescriptor.hpp @@ -0,0 +1,97 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4BinaryFunctionDescriptor.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class BinaryFunctionDescriptor; +class FunctionDescriptor; + +_MTL_OPTIONS(NS::UInteger, BinaryFunctionOptions) { + BinaryFunctionOptionNone = 0, + BinaryFunctionOptionPipelineIndependent = 1 << 1, +}; + +class BinaryFunctionDescriptor : public NS::Copying +{ +public: + static BinaryFunctionDescriptor* alloc(); + + FunctionDescriptor* functionDescriptor() const; + + BinaryFunctionDescriptor* init(); + + NS::String* name() const; + + BinaryFunctionOptions options() const; + + void setFunctionDescriptor(const MTL4::FunctionDescriptor* functionDescriptor); + + void setName(const NS::String* name); + + void setOptions(MTL4::BinaryFunctionOptions options); +}; + +} +_MTL_INLINE MTL4::BinaryFunctionDescriptor* MTL4::BinaryFunctionDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4BinaryFunctionDescriptor)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::BinaryFunctionDescriptor::functionDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionDescriptor)); +} + +_MTL_INLINE MTL4::BinaryFunctionDescriptor* MTL4::BinaryFunctionDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL4::BinaryFunctionDescriptor::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE MTL4::BinaryFunctionOptions MTL4::BinaryFunctionDescriptor::options() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(options)); +} + +_MTL_INLINE void MTL4::BinaryFunctionDescriptor::setFunctionDescriptor(const MTL4::FunctionDescriptor* functionDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctionDescriptor_), functionDescriptor); +} + +_MTL_INLINE void MTL4::BinaryFunctionDescriptor::setName(const NS::String* name) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setName_), name); +} + +_MTL_INLINE void MTL4::BinaryFunctionDescriptor::setOptions(MTL4::BinaryFunctionOptions options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOptions_), options); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4CommandAllocator.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4CommandAllocator.hpp new file mode 100644 index 00000000..a36b0508 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4CommandAllocator.hpp @@ -0,0 +1,100 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4CommandAllocator.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include + +namespace MTL +{ +class Device; +} + +namespace MTL4 +{ + +class CommandAllocatorDescriptor : public NS::Copying +{ +public: + static CommandAllocatorDescriptor* alloc(); + + CommandAllocatorDescriptor* init(); + + NS::String* label() const; + void setLabel(const NS::String* label); +}; + +class CommandAllocator : public NS::Referencing +{ +public: + uint64_t allocatedSize(); + + MTL::Device* device() const; + + NS::String* label() const; + + void reset(); +}; + +} + +_MTL_INLINE MTL4::CommandAllocatorDescriptor* MTL4::CommandAllocatorDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4CommandAllocatorDescriptor)); +} + +_MTL_INLINE MTL4::CommandAllocatorDescriptor* MTL4::CommandAllocatorDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL4::CommandAllocatorDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL4::CommandAllocatorDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE uint64_t MTL4::CommandAllocator::allocatedSize() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allocatedSize)); +} + +_MTL_INLINE MTL::Device* MTL4::CommandAllocator::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL4::CommandAllocator::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL4::CommandAllocator::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4CommandBuffer.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4CommandBuffer.hpp new file mode 100644 index 00000000..a69cc941 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4CommandBuffer.hpp @@ -0,0 +1,193 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4CommandBuffer.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4RenderCommandEncoder.hpp" +#include "MTLAccelerationStructureTypes.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class CommandAllocator; +class CommandBufferOptions; +class ComputeCommandEncoder; +class CounterHeap; +class MachineLearningCommandEncoder; +class RenderCommandEncoder; +class RenderPassDescriptor; +} + +namespace MTL +{ +class Device; +class Fence; +class LogState; +class ResidencySet; +} + +namespace MTL4 +{ +class CommandBufferOptions : public NS::Copying +{ +public: + static CommandBufferOptions* alloc(); + + CommandBufferOptions* init(); + + MTL::LogState* logState() const; + void setLogState(const MTL::LogState* logState); +}; +class CommandBuffer : public NS::Referencing +{ +public: + void beginCommandBuffer(const MTL4::CommandAllocator* allocator); + void beginCommandBuffer(const MTL4::CommandAllocator* allocator, const MTL4::CommandBufferOptions* options); + + ComputeCommandEncoder* computeCommandEncoder(); + + MTL::Device* device() const; + + void endCommandBuffer(); + + NS::String* label() const; + + MachineLearningCommandEncoder* machineLearningCommandEncoder(); + + void popDebugGroup(); + + void pushDebugGroup(const NS::String* string); + + RenderCommandEncoder* renderCommandEncoder(const MTL4::RenderPassDescriptor* descriptor); + RenderCommandEncoder* renderCommandEncoder(const MTL4::RenderPassDescriptor* descriptor, MTL4::RenderEncoderOptions options); + + void resolveCounterHeap(const MTL4::CounterHeap* counterHeap, NS::Range range, const MTL4::BufferRange bufferRange, const MTL::Fence* fenceToWait, const MTL::Fence* fenceToUpdate); + + void setLabel(const NS::String* label); + + void useResidencySet(const MTL::ResidencySet* residencySet); + void useResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count); + + void writeTimestampIntoHeap(const MTL4::CounterHeap* counterHeap, NS::UInteger index); +}; + +} +_MTL_INLINE MTL4::CommandBufferOptions* MTL4::CommandBufferOptions::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4CommandBufferOptions)); +} + +_MTL_INLINE MTL4::CommandBufferOptions* MTL4::CommandBufferOptions::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::LogState* MTL4::CommandBufferOptions::logState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(logState)); +} + +_MTL_INLINE void MTL4::CommandBufferOptions::setLogState(const MTL::LogState* logState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLogState_), logState); +} + +_MTL_INLINE void MTL4::CommandBuffer::beginCommandBuffer(const MTL4::CommandAllocator* allocator) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(beginCommandBufferWithAllocator_), allocator); +} + +_MTL_INLINE void MTL4::CommandBuffer::beginCommandBuffer(const MTL4::CommandAllocator* allocator, const MTL4::CommandBufferOptions* options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(beginCommandBufferWithAllocator_options_), allocator, options); +} + +_MTL_INLINE MTL4::ComputeCommandEncoder* MTL4::CommandBuffer::computeCommandEncoder() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(computeCommandEncoder)); +} + +_MTL_INLINE MTL::Device* MTL4::CommandBuffer::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE void MTL4::CommandBuffer::endCommandBuffer() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(endCommandBuffer)); +} + +_MTL_INLINE NS::String* MTL4::CommandBuffer::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL4::MachineLearningCommandEncoder* MTL4::CommandBuffer::machineLearningCommandEncoder() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(machineLearningCommandEncoder)); +} + +_MTL_INLINE void MTL4::CommandBuffer::popDebugGroup() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(popDebugGroup)); +} + +_MTL_INLINE void MTL4::CommandBuffer::pushDebugGroup(const NS::String* string) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(pushDebugGroup_), string); +} + +_MTL_INLINE MTL4::RenderCommandEncoder* MTL4::CommandBuffer::renderCommandEncoder(const MTL4::RenderPassDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(renderCommandEncoderWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL4::RenderCommandEncoder* MTL4::CommandBuffer::renderCommandEncoder(const MTL4::RenderPassDescriptor* descriptor, MTL4::RenderEncoderOptions options) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(renderCommandEncoderWithDescriptor_options_), descriptor, options); +} + +_MTL_INLINE void MTL4::CommandBuffer::resolveCounterHeap(const MTL4::CounterHeap* counterHeap, NS::Range range, const MTL4::BufferRange bufferRange, const MTL::Fence* fenceToWait, const MTL::Fence* fenceToUpdate) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(resolveCounterHeap_withRange_intoBuffer_waitFence_updateFence_), counterHeap, range, bufferRange, fenceToWait, fenceToUpdate); +} + +_MTL_INLINE void MTL4::CommandBuffer::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL4::CommandBuffer::useResidencySet(const MTL::ResidencySet* residencySet) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResidencySet_), residencySet); +} + +_MTL_INLINE void MTL4::CommandBuffer::useResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResidencySets_count_), residencySets, count); +} + +_MTL_INLINE void MTL4::CommandBuffer::writeTimestampIntoHeap(const MTL4::CounterHeap* counterHeap, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(writeTimestampIntoHeap_atIndex_), counterHeap, index); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4CommandEncoder.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4CommandEncoder.hpp new file mode 100644 index 00000000..2336021e --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4CommandEncoder.hpp @@ -0,0 +1,134 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4CommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLCommandEncoder.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class CommandBuffer; +} + +namespace MTL +{ +class Fence; +} + +namespace MTL4 +{ +_MTL_OPTIONS(NS::UInteger, VisibilityOptions) { + VisibilityOptionNone = 0, + VisibilityOptionDevice = 1, + VisibilityOptionResourceAlias = 1 << 1, +}; + +class CommandEncoder : public NS::Referencing +{ +public: + void barrierAfterEncoderStages(MTL::Stages afterEncoderStages, MTL::Stages beforeEncoderStages, MTL4::VisibilityOptions visibilityOptions); + + void barrierAfterQueueStages(MTL::Stages afterQueueStages, MTL::Stages beforeStages, MTL4::VisibilityOptions visibilityOptions); + + void barrierAfterStages(MTL::Stages afterStages, MTL::Stages beforeQueueStages, MTL4::VisibilityOptions visibilityOptions); + + CommandBuffer* commandBuffer() const; + + void endEncoding(); + + void insertDebugSignpost(const NS::String* string); + + NS::String* label() const; + + void popDebugGroup(); + + void pushDebugGroup(const NS::String* string); + + void setLabel(const NS::String* label); + + void updateFence(const MTL::Fence* fence, MTL::Stages afterEncoderStages); + + void waitForFence(const MTL::Fence* fence, MTL::Stages beforeEncoderStages); +}; + +} +_MTL_INLINE void MTL4::CommandEncoder::barrierAfterEncoderStages(MTL::Stages afterEncoderStages, MTL::Stages beforeEncoderStages, MTL4::VisibilityOptions visibilityOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(barrierAfterEncoderStages_beforeEncoderStages_visibilityOptions_), afterEncoderStages, beforeEncoderStages, visibilityOptions); +} + +_MTL_INLINE void MTL4::CommandEncoder::barrierAfterQueueStages(MTL::Stages afterQueueStages, MTL::Stages beforeStages, MTL4::VisibilityOptions visibilityOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(barrierAfterQueueStages_beforeStages_visibilityOptions_), afterQueueStages, beforeStages, visibilityOptions); +} + +_MTL_INLINE void MTL4::CommandEncoder::barrierAfterStages(MTL::Stages afterStages, MTL::Stages beforeQueueStages, MTL4::VisibilityOptions visibilityOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(barrierAfterStages_beforeQueueStages_visibilityOptions_), afterStages, beforeQueueStages, visibilityOptions); +} + +_MTL_INLINE MTL4::CommandBuffer* MTL4::CommandEncoder::commandBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(commandBuffer)); +} + +_MTL_INLINE void MTL4::CommandEncoder::endEncoding() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(endEncoding)); +} + +_MTL_INLINE void MTL4::CommandEncoder::insertDebugSignpost(const NS::String* string) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(insertDebugSignpost_), string); +} + +_MTL_INLINE NS::String* MTL4::CommandEncoder::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL4::CommandEncoder::popDebugGroup() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(popDebugGroup)); +} + +_MTL_INLINE void MTL4::CommandEncoder::pushDebugGroup(const NS::String* string) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(pushDebugGroup_), string); +} + +_MTL_INLINE void MTL4::CommandEncoder::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL4::CommandEncoder::updateFence(const MTL::Fence* fence, MTL::Stages afterEncoderStages) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateFence_afterEncoderStages_), fence, afterEncoderStages); +} + +_MTL_INLINE void MTL4::CommandEncoder::waitForFence(const MTL::Fence* fence, MTL::Stages beforeEncoderStages) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitForFence_beforeEncoderStages_), fence, beforeEncoderStages); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4CommandQueue.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4CommandQueue.hpp new file mode 100644 index 00000000..cbd21c7a --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4CommandQueue.hpp @@ -0,0 +1,283 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4CommandQueue.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4CommitFeedback.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLResourceStateCommandEncoder.hpp" +#include "MTLTypes.hpp" +#include +#include + +namespace MTL +{ +class Buffer; +class Device; +class Drawable; +class Event; +class Heap; +class ResidencySet; +class Texture; +} + +namespace MTL4 +{ +class CommandBuffer; +class CommandQueueDescriptor; +class CommitOptions; +struct CopySparseBufferMappingOperation; +struct CopySparseTextureMappingOperation; +struct UpdateSparseBufferMappingOperation; +struct UpdateSparseTextureMappingOperation; +_MTL_ENUM(NS::Integer, CommandQueueError) { + CommandQueueErrorNone = 0, + CommandQueueErrorTimeout = 1, + CommandQueueErrorNotPermitted = 2, + CommandQueueErrorOutOfMemory = 3, + CommandQueueErrorDeviceRemoved = 4, + CommandQueueErrorAccessRevoked = 5, + CommandQueueErrorInternal = 6, +}; + +struct UpdateSparseTextureMappingOperation +{ + MTL::SparseTextureMappingMode mode; + MTL::Region textureRegion; + NS::UInteger textureLevel; + NS::UInteger textureSlice; + NS::UInteger heapOffset; +} _MTL_PACKED; + +struct CopySparseTextureMappingOperation +{ + MTL::Region sourceRegion; + NS::UInteger sourceLevel; + NS::UInteger sourceSlice; + MTL::Origin destinationOrigin; + NS::UInteger destinationLevel; + NS::UInteger destinationSlice; +} _MTL_PACKED; + +struct UpdateSparseBufferMappingOperation +{ + MTL::SparseTextureMappingMode mode; + NS::Range bufferRange; + NS::UInteger heapOffset; +} _MTL_PACKED; + +struct CopySparseBufferMappingOperation +{ + NS::Range sourceRange; + NS::UInteger destinationOffset; +} _MTL_PACKED; + +class CommitOptions : public NS::Referencing +{ +public: + void addFeedbackHandler(const MTL4::CommitFeedbackHandler block); + void addFeedbackHandler(const MTL4::CommitFeedbackHandlerFunction& function); + + static CommitOptions* alloc(); + + CommitOptions* init(); +}; +class CommandQueueDescriptor : public NS::Copying +{ +public: + static CommandQueueDescriptor* alloc(); + + dispatch_queue_t feedbackQueue() const; + + CommandQueueDescriptor* init(); + + NS::String* label() const; + + void setFeedbackQueue(const dispatch_queue_t feedbackQueue); + + void setLabel(const NS::String* label); +}; +class CommandQueue : public NS::Referencing +{ +public: + void addResidencySet(const MTL::ResidencySet* residencySet); + void addResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count); + + void commit(const MTL4::CommandBuffer* const commandBuffers[], NS::UInteger count); + void commit(const MTL4::CommandBuffer* const commandBuffers[], NS::UInteger count, const MTL4::CommitOptions* options); + + void copyBufferMappingsFromBuffer(const MTL::Buffer* sourceBuffer, const MTL::Buffer* destinationBuffer, const MTL4::CopySparseBufferMappingOperation* operations, NS::UInteger count); + + void copyTextureMappingsFromTexture(const MTL::Texture* sourceTexture, const MTL::Texture* destinationTexture, const MTL4::CopySparseTextureMappingOperation* operations, NS::UInteger count); + + MTL::Device* device() const; + + NS::String* label() const; + + void removeResidencySet(const MTL::ResidencySet* residencySet); + void removeResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count); + + void signalDrawable(const MTL::Drawable* drawable); + + void signalEvent(const MTL::Event* event, uint64_t value); + + void updateBufferMappings(const MTL::Buffer* buffer, const MTL::Heap* heap, const MTL4::UpdateSparseBufferMappingOperation* operations, NS::UInteger count); + + void updateTextureMappings(const MTL::Texture* texture, const MTL::Heap* heap, const MTL4::UpdateSparseTextureMappingOperation* operations, NS::UInteger count); + + void wait(const MTL::Event* event, uint64_t value); + void wait(const MTL::Drawable* drawable); +}; + +} + +_MTL_INLINE void MTL4::CommitOptions::addFeedbackHandler(const MTL4::CommitFeedbackHandler block) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addFeedbackHandler_), block); +} + +_MTL_INLINE void MTL4::CommitOptions::addFeedbackHandler(const MTL4::CommitFeedbackHandlerFunction& function) +{ + __block MTL4::CommitFeedbackHandlerFunction blockFunction = function; + addFeedbackHandler(^(MTL4::CommitFeedback* pFeedback) { blockFunction(pFeedback); }); +} + +_MTL_INLINE MTL4::CommitOptions* MTL4::CommitOptions::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4CommitOptions)); +} + +_MTL_INLINE MTL4::CommitOptions* MTL4::CommitOptions::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL4::CommandQueueDescriptor* MTL4::CommandQueueDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4CommandQueueDescriptor)); +} + +_MTL_INLINE dispatch_queue_t MTL4::CommandQueueDescriptor::feedbackQueue() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(feedbackQueue)); +} + +_MTL_INLINE MTL4::CommandQueueDescriptor* MTL4::CommandQueueDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL4::CommandQueueDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL4::CommandQueueDescriptor::setFeedbackQueue(const dispatch_queue_t feedbackQueue) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFeedbackQueue_), feedbackQueue); +} + +_MTL_INLINE void MTL4::CommandQueueDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL4::CommandQueue::addResidencySet(const MTL::ResidencySet* residencySet) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addResidencySet_), residencySet); +} + +_MTL_INLINE void MTL4::CommandQueue::addResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addResidencySets_count_), residencySets, count); +} + +_MTL_INLINE void MTL4::CommandQueue::commit(const MTL4::CommandBuffer* const commandBuffers[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(commit_count_), commandBuffers, count); +} + +_MTL_INLINE void MTL4::CommandQueue::commit(const MTL4::CommandBuffer* const commandBuffers[], NS::UInteger count, const MTL4::CommitOptions* options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(commit_count_options_), commandBuffers, count, options); +} + +_MTL_INLINE void MTL4::CommandQueue::copyBufferMappingsFromBuffer(const MTL::Buffer* sourceBuffer, const MTL::Buffer* destinationBuffer, const MTL4::CopySparseBufferMappingOperation* operations, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyBufferMappingsFromBuffer_toBuffer_operations_count_), sourceBuffer, destinationBuffer, operations, count); +} + +_MTL_INLINE void MTL4::CommandQueue::copyTextureMappingsFromTexture(const MTL::Texture* sourceTexture, const MTL::Texture* destinationTexture, const MTL4::CopySparseTextureMappingOperation* operations, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyTextureMappingsFromTexture_toTexture_operations_count_), sourceTexture, destinationTexture, operations, count); +} + +_MTL_INLINE MTL::Device* MTL4::CommandQueue::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL4::CommandQueue::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL4::CommandQueue::removeResidencySet(const MTL::ResidencySet* residencySet) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(removeResidencySet_), residencySet); +} + +_MTL_INLINE void MTL4::CommandQueue::removeResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(removeResidencySets_count_), residencySets, count); +} + +_MTL_INLINE void MTL4::CommandQueue::signalDrawable(const MTL::Drawable* drawable) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(signalDrawable_), drawable); +} + +_MTL_INLINE void MTL4::CommandQueue::signalEvent(const MTL::Event* event, uint64_t value) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(signalEvent_value_), event, value); +} + +_MTL_INLINE void MTL4::CommandQueue::updateBufferMappings(const MTL::Buffer* buffer, const MTL::Heap* heap, const MTL4::UpdateSparseBufferMappingOperation* operations, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateBufferMappings_heap_operations_count_), buffer, heap, operations, count); +} + +_MTL_INLINE void MTL4::CommandQueue::updateTextureMappings(const MTL::Texture* texture, const MTL::Heap* heap, const MTL4::UpdateSparseTextureMappingOperation* operations, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateTextureMappings_heap_operations_count_), texture, heap, operations, count); +} + +_MTL_INLINE void MTL4::CommandQueue::wait(const MTL::Event* event, uint64_t value) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitForEvent_value_), event, value); +} + +_MTL_INLINE void MTL4::CommandQueue::wait(const MTL::Drawable* drawable) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitForDrawable_), drawable); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4CommitFeedback.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4CommitFeedback.hpp new file mode 100644 index 00000000..6b8181f7 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4CommitFeedback.hpp @@ -0,0 +1,62 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4CommitFeedback.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include + +#include + +namespace MTL4 +{ +class CommitFeedback; + +using CommitFeedbackHandler = void (^)(MTL4::CommitFeedback*); +using CommitFeedbackHandlerFunction = std::function; + +class CommitFeedback : public NS::Referencing +{ +public: + CFTimeInterval GPUEndTime() const; + + CFTimeInterval GPUStartTime() const; + + NS::Error* error() const; +}; + +} +_MTL_INLINE CFTimeInterval MTL4::CommitFeedback::GPUEndTime() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(GPUEndTime)); +} + +_MTL_INLINE CFTimeInterval MTL4::CommitFeedback::GPUStartTime() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(GPUStartTime)); +} + +_MTL_INLINE NS::Error* MTL4::CommitFeedback::error() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(error)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4Compiler.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4Compiler.hpp new file mode 100644 index 00000000..94249b29 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4Compiler.hpp @@ -0,0 +1,345 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4Compiler.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLDevice.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +#include + +namespace MTL4 +{ +class BinaryFunction; +class BinaryFunctionDescriptor; +class CompilerDescriptor; +class CompilerTask; +class CompilerTaskOptions; +class ComputePipelineDescriptor; +class LibraryDescriptor; +class MachineLearningPipelineDescriptor; +class MachineLearningPipelineState; +class PipelineDataSetSerializer; +class PipelineDescriptor; +class PipelineStageDynamicLinkingDescriptor; +class RenderPipelineDynamicLinkingDescriptor; +} + +namespace MTL +{ +class ComputePipelineState; +class Device; +class DynamicLibrary; +class Library; +class RenderPipelineState; + +using NewDynamicLibraryCompletionHandler = void (^)(MTL::DynamicLibrary*, NS::Error*); +using NewDynamicLibraryCompletionHandlerFunction = std::function; +} + +namespace MTL4 +{ +using NewComputePipelineStateCompletionHandler = void (^)(MTL::ComputePipelineState*, NS::Error*); +using NewComputePipelineStateCompletionHandlerFunction = std::function; +using NewRenderPipelineStateCompletionHandler = void (^)(MTL::RenderPipelineState*, NS::Error*); +using NewRenderPipelineStateCompletionHandlerFunction = std::function; +using NewBinaryFunctionCompletionHandler = void (^)(MTL4::BinaryFunction*, NS::Error*); +using NewBinaryFunctionCompletionHandlerFunction = std::function; +using NewMachineLearningPipelineStateCompletionHandler = void (^)(MTL4::MachineLearningPipelineState*, NS::Error*); +using NewMachineLearningPipelineStateCompletionHandlerFunction = std::function; + +class CompilerDescriptor : public NS::Copying +{ +public: + static CompilerDescriptor* alloc(); + + CompilerDescriptor* init(); + + NS::String* label() const; + + PipelineDataSetSerializer* pipelineDataSetSerializer() const; + + void setLabel(const NS::String* label); + + void setPipelineDataSetSerializer(const MTL4::PipelineDataSetSerializer* pipelineDataSetSerializer); +}; +class CompilerTaskOptions : public NS::Copying +{ +public: + static CompilerTaskOptions* alloc(); + + CompilerTaskOptions* init(); + + NS::Array* lookupArchives() const; + void setLookupArchives(const NS::Array* lookupArchives); +}; +class Compiler : public NS::Referencing +{ +public: + MTL::Device* device() const; + + NS::String* label() const; + + BinaryFunction* newBinaryFunction(const MTL4::BinaryFunctionDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, NS::Error** error); + CompilerTask* newBinaryFunction(const MTL4::BinaryFunctionDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, const MTL4::NewBinaryFunctionCompletionHandler completionHandler); + + MTL::ComputePipelineState* newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, NS::Error** error); + MTL::ComputePipelineState* newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, const MTL4::PipelineStageDynamicLinkingDescriptor* dynamicLinkingDescriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, NS::Error** error); + CompilerTask* newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, const MTL::NewComputePipelineStateCompletionHandler completionHandler); + CompilerTask* newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, const MTL4::PipelineStageDynamicLinkingDescriptor* dynamicLinkingDescriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, const MTL::NewComputePipelineStateCompletionHandler completionHandler); + CompilerTask* newComputePipelineState(const MTL4::ComputePipelineDescriptor* pDescriptor, const MTL4::CompilerTaskOptions* options, const MTL4::NewComputePipelineStateCompletionHandlerFunction& function); + + MTL::DynamicLibrary* newDynamicLibrary(const MTL::Library* library, NS::Error** error); + MTL::DynamicLibrary* newDynamicLibrary(const NS::URL* url, NS::Error** error); + CompilerTask* newDynamicLibrary(const MTL::Library* library, const MTL::NewDynamicLibraryCompletionHandler completionHandler); + CompilerTask* newDynamicLibrary(const NS::URL* url, const MTL::NewDynamicLibraryCompletionHandler completionHandler); + CompilerTask* newDynamicLibrary(const MTL::Library* pLibrary, const MTL::NewDynamicLibraryCompletionHandlerFunction& function); + CompilerTask* newDynamicLibrary(const NS::URL* pURL, const MTL::NewDynamicLibraryCompletionHandlerFunction& function); + + MTL::Library* newLibrary(const MTL4::LibraryDescriptor* descriptor, NS::Error** error); + CompilerTask* newLibrary(const MTL4::LibraryDescriptor* descriptor, const MTL::NewLibraryCompletionHandler completionHandler); + CompilerTask* newLibrary(const MTL4::LibraryDescriptor* pDescriptor, const MTL::NewLibraryCompletionHandlerFunction& function); + + MachineLearningPipelineState* newMachineLearningPipelineState(const MTL4::MachineLearningPipelineDescriptor* descriptor, NS::Error** error); + CompilerTask* newMachineLearningPipelineState(const MTL4::MachineLearningPipelineDescriptor* descriptor, const MTL4::NewMachineLearningPipelineStateCompletionHandler completionHandler); + CompilerTask* newMachineLearningPipelineState(const MTL4::MachineLearningPipelineDescriptor* pDescriptor, const MTL4::NewMachineLearningPipelineStateCompletionHandlerFunction& function); + + MTL::RenderPipelineState* newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, NS::Error** error); + MTL::RenderPipelineState* newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, const MTL4::RenderPipelineDynamicLinkingDescriptor* dynamicLinkingDescriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, NS::Error** error); + CompilerTask* newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, const MTL::NewRenderPipelineStateCompletionHandler completionHandler); + CompilerTask* newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, const MTL4::RenderPipelineDynamicLinkingDescriptor* dynamicLinkingDescriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, const MTL::NewRenderPipelineStateCompletionHandler completionHandler); + CompilerTask* newRenderPipelineState(const MTL4::PipelineDescriptor* pDescriptor, const MTL4::CompilerTaskOptions* options, const MTL4::NewRenderPipelineStateCompletionHandlerFunction& function); + MTL::RenderPipelineState* newRenderPipelineStateBySpecialization(const MTL4::PipelineDescriptor* descriptor, const MTL::RenderPipelineState* pipeline, NS::Error** error); + CompilerTask* newRenderPipelineStateBySpecialization(const MTL4::PipelineDescriptor* descriptor, const MTL::RenderPipelineState* pipeline, const MTL::NewRenderPipelineStateCompletionHandler completionHandler); + CompilerTask* newRenderPipelineStateBySpecialization(const MTL4::PipelineDescriptor* pDescriptor, const MTL::RenderPipelineState* pPipeline, const MTL4::NewRenderPipelineStateCompletionHandlerFunction& function); + + PipelineDataSetSerializer* pipelineDataSetSerializer() const; +}; + +} +_MTL_INLINE MTL4::CompilerDescriptor* MTL4::CompilerDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4CompilerDescriptor)); +} + +_MTL_INLINE MTL4::CompilerDescriptor* MTL4::CompilerDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL4::CompilerDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL4::PipelineDataSetSerializer* MTL4::CompilerDescriptor::pipelineDataSetSerializer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(pipelineDataSetSerializer)); +} + +_MTL_INLINE void MTL4::CompilerDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL4::CompilerDescriptor::setPipelineDataSetSerializer(const MTL4::PipelineDataSetSerializer* pipelineDataSetSerializer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPipelineDataSetSerializer_), pipelineDataSetSerializer); +} + +_MTL_INLINE MTL4::CompilerTaskOptions* MTL4::CompilerTaskOptions::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4CompilerTaskOptions)); +} + +_MTL_INLINE MTL4::CompilerTaskOptions* MTL4::CompilerTaskOptions::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::Array* MTL4::CompilerTaskOptions::lookupArchives() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(lookupArchives)); +} + +_MTL_INLINE void MTL4::CompilerTaskOptions::setLookupArchives(const NS::Array* lookupArchives) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLookupArchives_), lookupArchives); +} + +_MTL_INLINE MTL::Device* MTL4::Compiler::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL4::Compiler::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL4::BinaryFunction* MTL4::Compiler::newBinaryFunction(const MTL4::BinaryFunctionDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newBinaryFunctionWithDescriptor_compilerTaskOptions_error_), descriptor, compilerTaskOptions, error); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newBinaryFunction(const MTL4::BinaryFunctionDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, const MTL4::NewBinaryFunctionCompletionHandler completionHandler) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newBinaryFunctionWithDescriptor_compilerTaskOptions_completionHandler_), descriptor, compilerTaskOptions, completionHandler); +} + +_MTL_INLINE MTL::ComputePipelineState* MTL4::Compiler::newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithDescriptor_compilerTaskOptions_error_), descriptor, compilerTaskOptions, error); +} + +_MTL_INLINE MTL::ComputePipelineState* MTL4::Compiler::newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, const MTL4::PipelineStageDynamicLinkingDescriptor* dynamicLinkingDescriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithDescriptor_dynamicLinkingDescriptor_compilerTaskOptions_error_), descriptor, dynamicLinkingDescriptor, compilerTaskOptions, error); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, const MTL::NewComputePipelineStateCompletionHandler completionHandler) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithDescriptor_compilerTaskOptions_completionHandler_), descriptor, compilerTaskOptions, completionHandler); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, const MTL4::PipelineStageDynamicLinkingDescriptor* dynamicLinkingDescriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, const MTL::NewComputePipelineStateCompletionHandler completionHandler) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithDescriptor_dynamicLinkingDescriptor_compilerTaskOptions_completionHandler_), descriptor, dynamicLinkingDescriptor, compilerTaskOptions, completionHandler); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newComputePipelineState(const MTL4::ComputePipelineDescriptor* pDescriptor, const MTL4::CompilerTaskOptions* options, const MTL4::NewComputePipelineStateCompletionHandlerFunction& function) +{ + __block MTL4::NewComputePipelineStateCompletionHandlerFunction blockFunction = function; + return newComputePipelineState(pDescriptor, options, ^(MTL::ComputePipelineState* pPipeline, NS::Error* pError) { blockFunction(pPipeline, pError); }); +} + +_MTL_INLINE MTL::DynamicLibrary* MTL4::Compiler::newDynamicLibrary(const MTL::Library* library, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newDynamicLibrary_error_), library, error); +} + +_MTL_INLINE MTL::DynamicLibrary* MTL4::Compiler::newDynamicLibrary(const NS::URL* url, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newDynamicLibraryWithURL_error_), url, error); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newDynamicLibrary(const MTL::Library* library, const MTL::NewDynamicLibraryCompletionHandler completionHandler) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newDynamicLibrary_completionHandler_), library, completionHandler); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newDynamicLibrary(const NS::URL* url, const MTL::NewDynamicLibraryCompletionHandler completionHandler) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newDynamicLibraryWithURL_completionHandler_), url, completionHandler); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newDynamicLibrary(const MTL::Library* pLibrary, const MTL::NewDynamicLibraryCompletionHandlerFunction& function) +{ + __block MTL::NewDynamicLibraryCompletionHandlerFunction blockFunction = function; + return newDynamicLibrary(pLibrary, ^(MTL::DynamicLibrary* pLibraryRef, NS::Error* pError) { blockFunction(pLibraryRef, pError); }); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newDynamicLibrary(const NS::URL* pURL, const MTL::NewDynamicLibraryCompletionHandlerFunction& function) +{ + __block MTL::NewDynamicLibraryCompletionHandlerFunction blockFunction = function; + return newDynamicLibrary(pURL, ^(MTL::DynamicLibrary* pLibrary, NS::Error* pError) { blockFunction(pLibrary, pError); }); +} + +_MTL_INLINE MTL::Library* MTL4::Compiler::newLibrary(const MTL4::LibraryDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newLibraryWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newLibrary(const MTL4::LibraryDescriptor* descriptor, const MTL::NewLibraryCompletionHandler completionHandler) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newLibraryWithDescriptor_completionHandler_), descriptor, completionHandler); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newLibrary(const MTL4::LibraryDescriptor* pDescriptor, const MTL::NewLibraryCompletionHandlerFunction& function) +{ + __block MTL::NewLibraryCompletionHandlerFunction blockFunction = function; + return newLibrary(pDescriptor, ^(MTL::Library* pLibrary, NS::Error* pError) { blockFunction(pLibrary, pError); }); +} + +_MTL_INLINE MTL4::MachineLearningPipelineState* MTL4::Compiler::newMachineLearningPipelineState(const MTL4::MachineLearningPipelineDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newMachineLearningPipelineStateWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newMachineLearningPipelineState(const MTL4::MachineLearningPipelineDescriptor* descriptor, const MTL4::NewMachineLearningPipelineStateCompletionHandler completionHandler) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newMachineLearningPipelineStateWithDescriptor_completionHandler_), descriptor, completionHandler); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newMachineLearningPipelineState(const MTL4::MachineLearningPipelineDescriptor* pDescriptor, const MTL4::NewMachineLearningPipelineStateCompletionHandlerFunction& function) +{ + __block MTL4::NewMachineLearningPipelineStateCompletionHandlerFunction blockFunction = function; + return newMachineLearningPipelineState(pDescriptor, ^(MTL4::MachineLearningPipelineState* pPipeline, NS::Error* pError) { blockFunction(pPipeline, pError); }); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL4::Compiler::newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithDescriptor_compilerTaskOptions_error_), descriptor, compilerTaskOptions, error); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL4::Compiler::newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, const MTL4::RenderPipelineDynamicLinkingDescriptor* dynamicLinkingDescriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithDescriptor_dynamicLinkingDescriptor_compilerTaskOptions_error_), descriptor, dynamicLinkingDescriptor, compilerTaskOptions, error); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, const MTL::NewRenderPipelineStateCompletionHandler completionHandler) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithDescriptor_compilerTaskOptions_completionHandler_), descriptor, compilerTaskOptions, completionHandler); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, const MTL4::RenderPipelineDynamicLinkingDescriptor* dynamicLinkingDescriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, const MTL::NewRenderPipelineStateCompletionHandler completionHandler) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithDescriptor_dynamicLinkingDescriptor_compilerTaskOptions_completionHandler_), descriptor, dynamicLinkingDescriptor, compilerTaskOptions, completionHandler); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newRenderPipelineState(const MTL4::PipelineDescriptor* pDescriptor, const MTL4::CompilerTaskOptions* options, const MTL4::NewRenderPipelineStateCompletionHandlerFunction& function) +{ + __block MTL4::NewRenderPipelineStateCompletionHandlerFunction blockFunction = function; + return newRenderPipelineState(pDescriptor, options, ^(MTL::RenderPipelineState* pPipeline, NS::Error* pError) { blockFunction(pPipeline, pError); }); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL4::Compiler::newRenderPipelineStateBySpecialization(const MTL4::PipelineDescriptor* descriptor, const MTL::RenderPipelineState* pipeline, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateBySpecializationWithDescriptor_pipeline_error_), descriptor, pipeline, error); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newRenderPipelineStateBySpecialization(const MTL4::PipelineDescriptor* descriptor, const MTL::RenderPipelineState* pipeline, const MTL::NewRenderPipelineStateCompletionHandler completionHandler) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateBySpecializationWithDescriptor_pipeline_completionHandler_), descriptor, pipeline, completionHandler); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newRenderPipelineStateBySpecialization(const MTL4::PipelineDescriptor* pDescriptor, const MTL::RenderPipelineState* pPipeline, const MTL4::NewRenderPipelineStateCompletionHandlerFunction& function) +{ + __block MTL4::NewRenderPipelineStateCompletionHandlerFunction blockFunction = function; + return newRenderPipelineStateBySpecialization(pDescriptor, pPipeline, ^(MTL::RenderPipelineState* pPipelineRef, NS::Error* pError) { blockFunction(pPipelineRef, pError); }); +} + +_MTL_INLINE MTL4::PipelineDataSetSerializer* MTL4::Compiler::pipelineDataSetSerializer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(pipelineDataSetSerializer)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4CompilerTask.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4CompilerTask.hpp new file mode 100644 index 00000000..a1ee9cdf --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4CompilerTask.hpp @@ -0,0 +1,63 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4CompilerTask.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class Compiler; +_MTL_ENUM(NS::Integer, CompilerTaskStatus) { + CompilerTaskStatusNone = 0, + CompilerTaskStatusScheduled = 1, + CompilerTaskStatusCompiling = 2, + CompilerTaskStatusFinished = 3, +}; + +class CompilerTask : public NS::Referencing +{ +public: + Compiler* compiler() const; + + CompilerTaskStatus status() const; + + void waitUntilCompleted(); +}; + +} + +_MTL_INLINE MTL4::Compiler* MTL4::CompilerTask::compiler() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(compiler)); +} + +_MTL_INLINE MTL4::CompilerTaskStatus MTL4::CompilerTask::status() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(status)); +} + +_MTL_INLINE void MTL4::CompilerTask::waitUntilCompleted() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitUntilCompleted)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4ComputeCommandEncoder.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4ComputeCommandEncoder.hpp new file mode 100644 index 00000000..7ef19da2 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4ComputeCommandEncoder.hpp @@ -0,0 +1,300 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4ComputeCommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4CommandEncoder.hpp" +#include "MTL4Counters.hpp" +#include "MTLAccelerationStructure.hpp" +#include "MTLAccelerationStructureTypes.hpp" +#include "MTLBlitCommandEncoder.hpp" +#include "MTLCommandEncoder.hpp" +#include "MTLDefines.hpp" +#include "MTLGPUAddress.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL4 +{ +class AccelerationStructureDescriptor; +class ArgumentTable; +class CounterHeap; +} + +namespace MTL +{ +class AccelerationStructure; +class Buffer; +class ComputePipelineState; +class IndirectCommandBuffer; +class Tensor; +class TensorExtents; +class Texture; +} + +namespace MTL4 +{ +class ComputeCommandEncoder : public NS::Referencing +{ +public: + void buildAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, const MTL4::AccelerationStructureDescriptor* descriptor, const MTL4::BufferRange scratchBuffer); + + void copyAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructure* destinationAccelerationStructure); + + void copyAndCompactAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructure* destinationAccelerationStructure); + + void copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger size); + void copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, NS::UInteger sourceBytesPerRow, NS::UInteger sourceBytesPerImage, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin); + void copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, NS::UInteger sourceBytesPerRow, NS::UInteger sourceBytesPerImage, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin, MTL::BlitOption options); + + void copyFromTensor(const MTL::Tensor* sourceTensor, const MTL::TensorExtents* sourceOrigin, const MTL::TensorExtents* sourceDimensions, const MTL::Tensor* destinationTensor, const MTL::TensorExtents* destinationOrigin, const MTL::TensorExtents* destinationDimensions); + + void copyFromTexture(const MTL::Texture* sourceTexture, const MTL::Texture* destinationTexture); + void copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, NS::UInteger sliceCount, NS::UInteger levelCount); + void copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin); + void copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger destinationBytesPerRow, NS::UInteger destinationBytesPerImage); + void copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger destinationBytesPerRow, NS::UInteger destinationBytesPerImage, MTL::BlitOption options); + + void copyIndirectCommandBuffer(const MTL::IndirectCommandBuffer* source, NS::Range sourceRange, const MTL::IndirectCommandBuffer* destination, NS::UInteger destinationIndex); + + void dispatchThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerThreadgroup); + void dispatchThreadgroups(MTL::GPUAddress indirectBuffer, MTL::Size threadsPerThreadgroup); + + void dispatchThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerThreadgroup); + void dispatchThreads(MTL::GPUAddress indirectBuffer); + + void executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range executionRange); + void executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandbuffer, MTL::GPUAddress indirectRangeBuffer); + + void fillBuffer(const MTL::Buffer* buffer, NS::Range range, uint8_t value); + + void generateMipmaps(const MTL::Texture* texture); + + void optimizeContentsForCPUAccess(const MTL::Texture* texture); + void optimizeContentsForCPUAccess(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level); + + void optimizeContentsForGPUAccess(const MTL::Texture* texture); + void optimizeContentsForGPUAccess(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level); + + void optimizeIndirectCommandBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range range); + + void refitAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL4::AccelerationStructureDescriptor* descriptor, const MTL::AccelerationStructure* destinationAccelerationStructure, const MTL4::BufferRange scratchBuffer); + void refitAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL4::AccelerationStructureDescriptor* descriptor, const MTL::AccelerationStructure* destinationAccelerationStructure, const MTL4::BufferRange scratchBuffer, MTL::AccelerationStructureRefitOptions options); + + void resetCommandsInBuffer(const MTL::IndirectCommandBuffer* buffer, NS::Range range); + + void setArgumentTable(const MTL4::ArgumentTable* argumentTable); + + void setComputePipelineState(const MTL::ComputePipelineState* state); + + void setImageblockWidth(NS::UInteger width, NS::UInteger height); + + void setThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index); + + MTL::Stages stages(); + + void writeCompactedAccelerationStructureSize(const MTL::AccelerationStructure* accelerationStructure, const MTL4::BufferRange buffer); + + void writeTimestamp(MTL4::TimestampGranularity granularity, const MTL4::CounterHeap* counterHeap, NS::UInteger index); +}; + +} +_MTL_INLINE void MTL4::ComputeCommandEncoder::buildAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, const MTL4::AccelerationStructureDescriptor* descriptor, const MTL4::BufferRange scratchBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(buildAccelerationStructure_descriptor_scratchBuffer_), accelerationStructure, descriptor, scratchBuffer); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructure* destinationAccelerationStructure) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyAccelerationStructure_toAccelerationStructure_), sourceAccelerationStructure, destinationAccelerationStructure); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyAndCompactAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructure* destinationAccelerationStructure) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyAndCompactAccelerationStructure_toAccelerationStructure_), sourceAccelerationStructure, destinationAccelerationStructure); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger size) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_), sourceBuffer, sourceOffset, destinationBuffer, destinationOffset, size); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, NS::UInteger sourceBytesPerRow, NS::UInteger sourceBytesPerImage, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromBuffer_sourceOffset_sourceBytesPerRow_sourceBytesPerImage_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_), sourceBuffer, sourceOffset, sourceBytesPerRow, sourceBytesPerImage, sourceSize, destinationTexture, destinationSlice, destinationLevel, destinationOrigin); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, NS::UInteger sourceBytesPerRow, NS::UInteger sourceBytesPerImage, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin, MTL::BlitOption options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromBuffer_sourceOffset_sourceBytesPerRow_sourceBytesPerImage_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_options_), sourceBuffer, sourceOffset, sourceBytesPerRow, sourceBytesPerImage, sourceSize, destinationTexture, destinationSlice, destinationLevel, destinationOrigin, options); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyFromTensor(const MTL::Tensor* sourceTensor, const MTL::TensorExtents* sourceOrigin, const MTL::TensorExtents* sourceDimensions, const MTL::Tensor* destinationTensor, const MTL::TensorExtents* destinationOrigin, const MTL::TensorExtents* destinationDimensions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTensor_sourceOrigin_sourceDimensions_toTensor_destinationOrigin_destinationDimensions_), sourceTensor, sourceOrigin, sourceDimensions, destinationTensor, destinationOrigin, destinationDimensions); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyFromTexture(const MTL::Texture* sourceTexture, const MTL::Texture* destinationTexture) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTexture_toTexture_), sourceTexture, destinationTexture); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, NS::UInteger sliceCount, NS::UInteger levelCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTexture_sourceSlice_sourceLevel_toTexture_destinationSlice_destinationLevel_sliceCount_levelCount_), sourceTexture, sourceSlice, sourceLevel, destinationTexture, destinationSlice, destinationLevel, sliceCount, levelCount); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_), sourceTexture, sourceSlice, sourceLevel, sourceOrigin, sourceSize, destinationTexture, destinationSlice, destinationLevel, destinationOrigin); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger destinationBytesPerRow, NS::UInteger destinationBytesPerImage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toBuffer_destinationOffset_destinationBytesPerRow_destinationBytesPerImage_), sourceTexture, sourceSlice, sourceLevel, sourceOrigin, sourceSize, destinationBuffer, destinationOffset, destinationBytesPerRow, destinationBytesPerImage); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger destinationBytesPerRow, NS::UInteger destinationBytesPerImage, MTL::BlitOption options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toBuffer_destinationOffset_destinationBytesPerRow_destinationBytesPerImage_options_), sourceTexture, sourceSlice, sourceLevel, sourceOrigin, sourceSize, destinationBuffer, destinationOffset, destinationBytesPerRow, destinationBytesPerImage, options); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyIndirectCommandBuffer(const MTL::IndirectCommandBuffer* source, NS::Range sourceRange, const MTL::IndirectCommandBuffer* destination, NS::UInteger destinationIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyIndirectCommandBuffer_sourceRange_destination_destinationIndex_), source, sourceRange, destination, destinationIndex); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::dispatchThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchThreadgroups_threadsPerThreadgroup_), threadgroupsPerGrid, threadsPerThreadgroup); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::dispatchThreadgroups(MTL::GPUAddress indirectBuffer, MTL::Size threadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchThreadgroupsWithIndirectBuffer_threadsPerThreadgroup_), indirectBuffer, threadsPerThreadgroup); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::dispatchThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchThreads_threadsPerThreadgroup_), threadsPerGrid, threadsPerThreadgroup); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::dispatchThreads(MTL::GPUAddress indirectBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchThreadsWithIndirectBuffer_), indirectBuffer); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range executionRange) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(executeCommandsInBuffer_withRange_), indirectCommandBuffer, executionRange); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandbuffer, MTL::GPUAddress indirectRangeBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(executeCommandsInBuffer_indirectBuffer_), indirectCommandbuffer, indirectRangeBuffer); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::fillBuffer(const MTL::Buffer* buffer, NS::Range range, uint8_t value) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(fillBuffer_range_value_), buffer, range, value); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::generateMipmaps(const MTL::Texture* texture) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(generateMipmapsForTexture_), texture); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::optimizeContentsForCPUAccess(const MTL::Texture* texture) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizeContentsForCPUAccess_), texture); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::optimizeContentsForCPUAccess(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizeContentsForCPUAccess_slice_level_), texture, slice, level); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::optimizeContentsForGPUAccess(const MTL::Texture* texture) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizeContentsForGPUAccess_), texture); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::optimizeContentsForGPUAccess(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizeContentsForGPUAccess_slice_level_), texture, slice, level); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::optimizeIndirectCommandBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizeIndirectCommandBuffer_withRange_), indirectCommandBuffer, range); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::refitAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL4::AccelerationStructureDescriptor* descriptor, const MTL::AccelerationStructure* destinationAccelerationStructure, const MTL4::BufferRange scratchBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(refitAccelerationStructure_descriptor_destination_scratchBuffer_), sourceAccelerationStructure, descriptor, destinationAccelerationStructure, scratchBuffer); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::refitAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL4::AccelerationStructureDescriptor* descriptor, const MTL::AccelerationStructure* destinationAccelerationStructure, const MTL4::BufferRange scratchBuffer, MTL::AccelerationStructureRefitOptions options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(refitAccelerationStructure_descriptor_destination_scratchBuffer_options_), sourceAccelerationStructure, descriptor, destinationAccelerationStructure, scratchBuffer, options); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::resetCommandsInBuffer(const MTL::IndirectCommandBuffer* buffer, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(resetCommandsInBuffer_withRange_), buffer, range); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::setArgumentTable(const MTL4::ArgumentTable* argumentTable) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setArgumentTable_), argumentTable); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::setComputePipelineState(const MTL::ComputePipelineState* state) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setComputePipelineState_), state); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::setImageblockWidth(NS::UInteger width, NS::UInteger height) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setImageblockWidth_height_), width, height); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::setThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadgroupMemoryLength_atIndex_), length, index); +} + +_MTL_INLINE MTL::Stages MTL4::ComputeCommandEncoder::stages() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stages)); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::writeCompactedAccelerationStructureSize(const MTL::AccelerationStructure* accelerationStructure, const MTL4::BufferRange buffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(writeCompactedAccelerationStructureSize_toBuffer_), accelerationStructure, buffer); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::writeTimestamp(MTL4::TimestampGranularity granularity, const MTL4::CounterHeap* counterHeap, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(writeTimestampWithGranularity_intoHeap_atIndex_), granularity, counterHeap, index); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4ComputePipeline.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4ComputePipeline.hpp new file mode 100644 index 00000000..a808431a --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4ComputePipeline.hpp @@ -0,0 +1,158 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4ComputePipeline.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4PipelineState.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" + +namespace MTL4 +{ +class ComputePipelineDescriptor; +class FunctionDescriptor; +class StaticLinkingDescriptor; + +class ComputePipelineDescriptor : public NS::Copying +{ +public: + static ComputePipelineDescriptor* alloc(); + + FunctionDescriptor* computeFunctionDescriptor() const; + + ComputePipelineDescriptor* init(); + + NS::UInteger maxTotalThreadsPerThreadgroup() const; + + MTL::Size requiredThreadsPerThreadgroup() const; + + void reset(); + + void setComputeFunctionDescriptor(const MTL4::FunctionDescriptor* computeFunctionDescriptor); + + void setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup); + + void setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup); + + void setStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* staticLinkingDescriptor); + + void setSupportBinaryLinking(bool supportBinaryLinking); + + void setSupportIndirectCommandBuffers(MTL4::IndirectCommandBufferSupportState supportIndirectCommandBuffers); + + void setThreadGroupSizeIsMultipleOfThreadExecutionWidth(bool threadGroupSizeIsMultipleOfThreadExecutionWidth); + + StaticLinkingDescriptor* staticLinkingDescriptor() const; + + bool supportBinaryLinking() const; + + IndirectCommandBufferSupportState supportIndirectCommandBuffers() const; + + bool threadGroupSizeIsMultipleOfThreadExecutionWidth() const; +}; + +} +_MTL_INLINE MTL4::ComputePipelineDescriptor* MTL4::ComputePipelineDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4ComputePipelineDescriptor)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::ComputePipelineDescriptor::computeFunctionDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(computeFunctionDescriptor)); +} + +_MTL_INLINE MTL4::ComputePipelineDescriptor* MTL4::ComputePipelineDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL4::ComputePipelineDescriptor::maxTotalThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerThreadgroup)); +} + +_MTL_INLINE MTL::Size MTL4::ComputePipelineDescriptor::requiredThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerThreadgroup)); +} + +_MTL_INLINE void MTL4::ComputePipelineDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL4::ComputePipelineDescriptor::setComputeFunctionDescriptor(const MTL4::FunctionDescriptor* computeFunctionDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setComputeFunctionDescriptor_), computeFunctionDescriptor); +} + +_MTL_INLINE void MTL4::ComputePipelineDescriptor::setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadsPerThreadgroup_), maxTotalThreadsPerThreadgroup); +} + +_MTL_INLINE void MTL4::ComputePipelineDescriptor::setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRequiredThreadsPerThreadgroup_), requiredThreadsPerThreadgroup); +} + +_MTL_INLINE void MTL4::ComputePipelineDescriptor::setStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* staticLinkingDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStaticLinkingDescriptor_), staticLinkingDescriptor); +} + +_MTL_INLINE void MTL4::ComputePipelineDescriptor::setSupportBinaryLinking(bool supportBinaryLinking) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportBinaryLinking_), supportBinaryLinking); +} + +_MTL_INLINE void MTL4::ComputePipelineDescriptor::setSupportIndirectCommandBuffers(MTL4::IndirectCommandBufferSupportState supportIndirectCommandBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportIndirectCommandBuffers_), supportIndirectCommandBuffers); +} + +_MTL_INLINE void MTL4::ComputePipelineDescriptor::setThreadGroupSizeIsMultipleOfThreadExecutionWidth(bool threadGroupSizeIsMultipleOfThreadExecutionWidth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadGroupSizeIsMultipleOfThreadExecutionWidth_), threadGroupSizeIsMultipleOfThreadExecutionWidth); +} + +_MTL_INLINE MTL4::StaticLinkingDescriptor* MTL4::ComputePipelineDescriptor::staticLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(staticLinkingDescriptor)); +} + +_MTL_INLINE bool MTL4::ComputePipelineDescriptor::supportBinaryLinking() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportBinaryLinking)); +} + +_MTL_INLINE MTL4::IndirectCommandBufferSupportState MTL4::ComputePipelineDescriptor::supportIndirectCommandBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(supportIndirectCommandBuffers)); +} + +_MTL_INLINE bool MTL4::ComputePipelineDescriptor::threadGroupSizeIsMultipleOfThreadExecutionWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadGroupSizeIsMultipleOfThreadExecutionWidth)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4Counters.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4Counters.hpp new file mode 100644 index 00000000..b507b766 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4Counters.hpp @@ -0,0 +1,138 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4Counters.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include + +#include + +namespace MTL4 +{ +class CounterHeapDescriptor; +_MTL_ENUM(NS::Integer, CounterHeapType) { + CounterHeapTypeInvalid, + CounterHeapTypeTimestamp, +}; + +_MTL_ENUM(NS::Integer, TimestampGranularity) { + TimestampGranularityRelaxed = 0, + TimestampGranularityPrecise = 1, +}; + +struct TimestampHeapEntry +{ + uint64_t timestamp; +} _MTL_PACKED; + +class CounterHeapDescriptor : public NS::Copying +{ +public: + static CounterHeapDescriptor* alloc(); + + NS::UInteger count() const; + + CounterHeapDescriptor* init(); + + void setCount(NS::UInteger count); + + void setType(MTL4::CounterHeapType type); + CounterHeapType type() const; +}; +class CounterHeap : public NS::Referencing +{ +public: + NS::UInteger count() const; + void invalidateCounterRange(NS::Range range); + + NS::String* label() const; + + NS::Data* resolveCounterRange(NS::Range range); + + void setLabel(const NS::String* label); + + CounterHeapType type() const; +}; + +} + +_MTL_INLINE MTL4::CounterHeapDescriptor* MTL4::CounterHeapDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4CounterHeapDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL4::CounterHeapDescriptor::count() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(count)); +} + +_MTL_INLINE MTL4::CounterHeapDescriptor* MTL4::CounterHeapDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL4::CounterHeapDescriptor::setCount(NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCount_), count); +} + +_MTL_INLINE void MTL4::CounterHeapDescriptor::setType(MTL4::CounterHeapType type) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setType_), type); +} + +_MTL_INLINE MTL4::CounterHeapType MTL4::CounterHeapDescriptor::type() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(type)); +} + +_MTL_INLINE NS::UInteger MTL4::CounterHeap::count() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(count)); +} + +_MTL_INLINE void MTL4::CounterHeap::invalidateCounterRange(NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(invalidateCounterRange_), range); +} + +_MTL_INLINE NS::String* MTL4::CounterHeap::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::Data* MTL4::CounterHeap::resolveCounterRange(NS::Range range) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resolveCounterRange_), range); +} + +_MTL_INLINE void MTL4::CounterHeap::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE MTL4::CounterHeapType MTL4::CounterHeap::type() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(type)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4FunctionDescriptor.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4FunctionDescriptor.hpp new file mode 100644 index 00000000..9049677e --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4FunctionDescriptor.hpp @@ -0,0 +1,49 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal//MTL4FunctionDescriptor.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class FunctionDescriptor; + +class FunctionDescriptor : public NS::Copying +{ +public: + static FunctionDescriptor* alloc(); + + FunctionDescriptor* init(); +}; + +} +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::FunctionDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4FunctionDescriptor)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::FunctionDescriptor::init() +{ + return NS::Object::init(); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4LibraryDescriptor.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4LibraryDescriptor.hpp new file mode 100644 index 00000000..bc491b69 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4LibraryDescriptor.hpp @@ -0,0 +1,98 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4LibraryDescriptor.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class LibraryDescriptor; +} + +namespace MTL +{ +class CompileOptions; +} + +namespace MTL4 +{ +class LibraryDescriptor : public NS::Copying +{ +public: + static LibraryDescriptor* alloc(); + + LibraryDescriptor* init(); + + NS::String* name() const; + + MTL::CompileOptions* options() const; + + void setName(const NS::String* name); + + void setOptions(const MTL::CompileOptions* options); + + void setSource(const NS::String* source); + NS::String* source() const; +}; + +} +_MTL_INLINE MTL4::LibraryDescriptor* MTL4::LibraryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4LibraryDescriptor)); +} + +_MTL_INLINE MTL4::LibraryDescriptor* MTL4::LibraryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL4::LibraryDescriptor::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE MTL::CompileOptions* MTL4::LibraryDescriptor::options() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(options)); +} + +_MTL_INLINE void MTL4::LibraryDescriptor::setName(const NS::String* name) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setName_), name); +} + +_MTL_INLINE void MTL4::LibraryDescriptor::setOptions(const MTL::CompileOptions* options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOptions_), options); +} + +_MTL_INLINE void MTL4::LibraryDescriptor::setSource(const NS::String* source) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSource_), source); +} + +_MTL_INLINE NS::String* MTL4::LibraryDescriptor::source() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(source)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4LibraryFunctionDescriptor.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4LibraryFunctionDescriptor.hpp new file mode 100644 index 00000000..1dec4bf2 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4LibraryFunctionDescriptor.hpp @@ -0,0 +1,86 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4LibraryFunctionDescriptor.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4FunctionDescriptor.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class LibraryFunctionDescriptor; +} + +namespace MTL +{ +class Library; +} + +namespace MTL4 +{ +class LibraryFunctionDescriptor : public NS::Copying +{ +public: + static LibraryFunctionDescriptor* alloc(); + + LibraryFunctionDescriptor* init(); + + MTL::Library* library() const; + + NS::String* name() const; + + void setLibrary(const MTL::Library* library); + + void setName(const NS::String* name); +}; + +} +_MTL_INLINE MTL4::LibraryFunctionDescriptor* MTL4::LibraryFunctionDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4LibraryFunctionDescriptor)); +} + +_MTL_INLINE MTL4::LibraryFunctionDescriptor* MTL4::LibraryFunctionDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::Library* MTL4::LibraryFunctionDescriptor::library() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(library)); +} + +_MTL_INLINE NS::String* MTL4::LibraryFunctionDescriptor::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE void MTL4::LibraryFunctionDescriptor::setLibrary(const MTL::Library* library) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLibrary_), library); +} + +_MTL_INLINE void MTL4::LibraryFunctionDescriptor::setName(const NS::String* name) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setName_), name); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4LinkingDescriptor.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4LinkingDescriptor.hpp new file mode 100644 index 00000000..ef5900b1 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4LinkingDescriptor.hpp @@ -0,0 +1,204 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4LinkingDescriptor.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class PipelineStageDynamicLinkingDescriptor; +class RenderPipelineDynamicLinkingDescriptor; +class StaticLinkingDescriptor; + +class StaticLinkingDescriptor : public NS::Copying +{ +public: + static StaticLinkingDescriptor* alloc(); + + NS::Array* functionDescriptors() const; + + NS::Dictionary* groups() const; + + StaticLinkingDescriptor* init(); + + NS::Array* privateFunctionDescriptors() const; + + void setFunctionDescriptors(const NS::Array* functionDescriptors); + + void setGroups(const NS::Dictionary* groups); + + void setPrivateFunctionDescriptors(const NS::Array* privateFunctionDescriptors); +}; +class PipelineStageDynamicLinkingDescriptor : public NS::Copying +{ +public: + static PipelineStageDynamicLinkingDescriptor* alloc(); + + NS::Array* binaryLinkedFunctions() const; + + PipelineStageDynamicLinkingDescriptor* init(); + + NS::UInteger maxCallStackDepth() const; + + NS::Array* preloadedLibraries() const; + + void setBinaryLinkedFunctions(const NS::Array* binaryLinkedFunctions); + + void setMaxCallStackDepth(NS::UInteger maxCallStackDepth); + + void setPreloadedLibraries(const NS::Array* preloadedLibraries); +}; +class RenderPipelineDynamicLinkingDescriptor : public NS::Copying +{ +public: + static RenderPipelineDynamicLinkingDescriptor* alloc(); + + PipelineStageDynamicLinkingDescriptor* fragmentLinkingDescriptor() const; + + RenderPipelineDynamicLinkingDescriptor* init(); + + PipelineStageDynamicLinkingDescriptor* meshLinkingDescriptor() const; + + PipelineStageDynamicLinkingDescriptor* objectLinkingDescriptor() const; + + PipelineStageDynamicLinkingDescriptor* tileLinkingDescriptor() const; + + PipelineStageDynamicLinkingDescriptor* vertexLinkingDescriptor() const; +}; + +} +_MTL_INLINE MTL4::StaticLinkingDescriptor* MTL4::StaticLinkingDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4StaticLinkingDescriptor)); +} + +_MTL_INLINE NS::Array* MTL4::StaticLinkingDescriptor::functionDescriptors() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionDescriptors)); +} + +_MTL_INLINE NS::Dictionary* MTL4::StaticLinkingDescriptor::groups() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(groups)); +} + +_MTL_INLINE MTL4::StaticLinkingDescriptor* MTL4::StaticLinkingDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::Array* MTL4::StaticLinkingDescriptor::privateFunctionDescriptors() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(privateFunctionDescriptors)); +} + +_MTL_INLINE void MTL4::StaticLinkingDescriptor::setFunctionDescriptors(const NS::Array* functionDescriptors) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctionDescriptors_), functionDescriptors); +} + +_MTL_INLINE void MTL4::StaticLinkingDescriptor::setGroups(const NS::Dictionary* groups) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setGroups_), groups); +} + +_MTL_INLINE void MTL4::StaticLinkingDescriptor::setPrivateFunctionDescriptors(const NS::Array* privateFunctionDescriptors) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPrivateFunctionDescriptors_), privateFunctionDescriptors); +} + +_MTL_INLINE MTL4::PipelineStageDynamicLinkingDescriptor* MTL4::PipelineStageDynamicLinkingDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4PipelineStageDynamicLinkingDescriptor)); +} + +_MTL_INLINE NS::Array* MTL4::PipelineStageDynamicLinkingDescriptor::binaryLinkedFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(binaryLinkedFunctions)); +} + +_MTL_INLINE MTL4::PipelineStageDynamicLinkingDescriptor* MTL4::PipelineStageDynamicLinkingDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL4::PipelineStageDynamicLinkingDescriptor::maxCallStackDepth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxCallStackDepth)); +} + +_MTL_INLINE NS::Array* MTL4::PipelineStageDynamicLinkingDescriptor::preloadedLibraries() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(preloadedLibraries)); +} + +_MTL_INLINE void MTL4::PipelineStageDynamicLinkingDescriptor::setBinaryLinkedFunctions(const NS::Array* binaryLinkedFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBinaryLinkedFunctions_), binaryLinkedFunctions); +} + +_MTL_INLINE void MTL4::PipelineStageDynamicLinkingDescriptor::setMaxCallStackDepth(NS::UInteger maxCallStackDepth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxCallStackDepth_), maxCallStackDepth); +} + +_MTL_INLINE void MTL4::PipelineStageDynamicLinkingDescriptor::setPreloadedLibraries(const NS::Array* preloadedLibraries) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPreloadedLibraries_), preloadedLibraries); +} + +_MTL_INLINE MTL4::RenderPipelineDynamicLinkingDescriptor* MTL4::RenderPipelineDynamicLinkingDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4RenderPipelineDynamicLinkingDescriptor)); +} + +_MTL_INLINE MTL4::PipelineStageDynamicLinkingDescriptor* MTL4::RenderPipelineDynamicLinkingDescriptor::fragmentLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentLinkingDescriptor)); +} + +_MTL_INLINE MTL4::RenderPipelineDynamicLinkingDescriptor* MTL4::RenderPipelineDynamicLinkingDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL4::PipelineStageDynamicLinkingDescriptor* MTL4::RenderPipelineDynamicLinkingDescriptor::meshLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshLinkingDescriptor)); +} + +_MTL_INLINE MTL4::PipelineStageDynamicLinkingDescriptor* MTL4::RenderPipelineDynamicLinkingDescriptor::objectLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectLinkingDescriptor)); +} + +_MTL_INLINE MTL4::PipelineStageDynamicLinkingDescriptor* MTL4::RenderPipelineDynamicLinkingDescriptor::tileLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileLinkingDescriptor)); +} + +_MTL_INLINE MTL4::PipelineStageDynamicLinkingDescriptor* MTL4::RenderPipelineDynamicLinkingDescriptor::vertexLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexLinkingDescriptor)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4MachineLearningCommandEncoder.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4MachineLearningCommandEncoder.hpp new file mode 100644 index 00000000..4d3cff66 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4MachineLearningCommandEncoder.hpp @@ -0,0 +1,66 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4MachineLearningCommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4CommandEncoder.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class ArgumentTable; +class MachineLearningPipelineState; +} + +namespace MTL +{ +class Heap; +} + +namespace MTL4 +{ +class MachineLearningCommandEncoder : public NS::Referencing +{ +public: + void dispatchNetwork(const MTL::Heap* heap); + + void setArgumentTable(const MTL4::ArgumentTable* argumentTable); + + void setPipelineState(const MTL4::MachineLearningPipelineState* pipelineState); +}; + +} +_MTL_INLINE void MTL4::MachineLearningCommandEncoder::dispatchNetwork(const MTL::Heap* heap) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchNetworkWithIntermediatesHeap_), heap); +} + +_MTL_INLINE void MTL4::MachineLearningCommandEncoder::setArgumentTable(const MTL4::ArgumentTable* argumentTable) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setArgumentTable_), argumentTable); +} + +_MTL_INLINE void MTL4::MachineLearningCommandEncoder::setPipelineState(const MTL4::MachineLearningPipelineState* pipelineState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPipelineState_), pipelineState); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4MachineLearningPipeline.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4MachineLearningPipeline.hpp new file mode 100644 index 00000000..713569f9 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4MachineLearningPipeline.hpp @@ -0,0 +1,172 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4MachineLearningPipeline.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4PipelineState.hpp" +#include "MTLAllocation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class FunctionDescriptor; +class MachineLearningPipelineDescriptor; +class MachineLearningPipelineReflection; +} + +namespace MTL +{ +class Device; +class TensorExtents; +} + +namespace MTL4 +{ +class MachineLearningPipelineDescriptor : public NS::Copying +{ +public: + static MachineLearningPipelineDescriptor* alloc(); + + MachineLearningPipelineDescriptor* init(); + + MTL::TensorExtents* inputDimensionsAtBufferIndex(NS::Integer bufferIndex); + + NS::String* label() const; + + FunctionDescriptor* machineLearningFunctionDescriptor() const; + + void reset(); + + void setInputDimensions(const MTL::TensorExtents* dimensions, NS::Integer bufferIndex); + void setInputDimensions(const NS::Array* dimensions, NS::Range range); + + void setLabel(const NS::String* label); + + void setMachineLearningFunctionDescriptor(const MTL4::FunctionDescriptor* machineLearningFunctionDescriptor); +}; +class MachineLearningPipelineReflection : public NS::Referencing +{ +public: + static MachineLearningPipelineReflection* alloc(); + + NS::Array* bindings() const; + + MachineLearningPipelineReflection* init(); +}; +class MachineLearningPipelineState : public NS::Referencing +{ +public: + MTL::Device* device() const; + + NS::UInteger intermediatesHeapSize() const; + + NS::String* label() const; + + MachineLearningPipelineReflection* reflection() const; +}; + +} +_MTL_INLINE MTL4::MachineLearningPipelineDescriptor* MTL4::MachineLearningPipelineDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4MachineLearningPipelineDescriptor)); +} + +_MTL_INLINE MTL4::MachineLearningPipelineDescriptor* MTL4::MachineLearningPipelineDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::TensorExtents* MTL4::MachineLearningPipelineDescriptor::inputDimensionsAtBufferIndex(NS::Integer bufferIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inputDimensionsAtBufferIndex_), bufferIndex); +} + +_MTL_INLINE NS::String* MTL4::MachineLearningPipelineDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::MachineLearningPipelineDescriptor::machineLearningFunctionDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(machineLearningFunctionDescriptor)); +} + +_MTL_INLINE void MTL4::MachineLearningPipelineDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL4::MachineLearningPipelineDescriptor::setInputDimensions(const MTL::TensorExtents* dimensions, NS::Integer bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInputDimensions_atBufferIndex_), dimensions, bufferIndex); +} + +_MTL_INLINE void MTL4::MachineLearningPipelineDescriptor::setInputDimensions(const NS::Array* dimensions, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInputDimensions_withRange_), dimensions, range); +} + +_MTL_INLINE void MTL4::MachineLearningPipelineDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL4::MachineLearningPipelineDescriptor::setMachineLearningFunctionDescriptor(const MTL4::FunctionDescriptor* machineLearningFunctionDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMachineLearningFunctionDescriptor_), machineLearningFunctionDescriptor); +} + +_MTL_INLINE MTL4::MachineLearningPipelineReflection* MTL4::MachineLearningPipelineReflection::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4MachineLearningPipelineReflection)); +} + +_MTL_INLINE NS::Array* MTL4::MachineLearningPipelineReflection::bindings() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bindings)); +} + +_MTL_INLINE MTL4::MachineLearningPipelineReflection* MTL4::MachineLearningPipelineReflection::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::Device* MTL4::MachineLearningPipelineState::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::UInteger MTL4::MachineLearningPipelineState::intermediatesHeapSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(intermediatesHeapSize)); +} + +_MTL_INLINE NS::String* MTL4::MachineLearningPipelineState::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL4::MachineLearningPipelineReflection* MTL4::MachineLearningPipelineState::reflection() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(reflection)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4MeshRenderPipeline.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4MeshRenderPipeline.hpp new file mode 100644 index 00000000..f66dffe2 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4MeshRenderPipeline.hpp @@ -0,0 +1,413 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4MeshRenderPipeline.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4PipelineState.hpp" +#include "MTL4RenderPipeline.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" + +namespace MTL4 +{ +class FunctionDescriptor; +class MeshRenderPipelineDescriptor; +class RenderPipelineColorAttachmentDescriptorArray; +class StaticLinkingDescriptor; + +class MeshRenderPipelineDescriptor : public NS::Copying +{ +public: + static MeshRenderPipelineDescriptor* alloc(); + + AlphaToCoverageState alphaToCoverageState() const; + + AlphaToOneState alphaToOneState() const; + + LogicalToPhysicalColorAttachmentMappingState colorAttachmentMappingState() const; + + RenderPipelineColorAttachmentDescriptorArray* colorAttachments() const; + + FunctionDescriptor* fragmentFunctionDescriptor() const; + + StaticLinkingDescriptor* fragmentStaticLinkingDescriptor() const; + + MeshRenderPipelineDescriptor* init(); + + bool isRasterizationEnabled() const; + + NS::UInteger maxTotalThreadgroupsPerMeshGrid() const; + + NS::UInteger maxTotalThreadsPerMeshThreadgroup() const; + + NS::UInteger maxTotalThreadsPerObjectThreadgroup() const; + + NS::UInteger maxVertexAmplificationCount() const; + + FunctionDescriptor* meshFunctionDescriptor() const; + + StaticLinkingDescriptor* meshStaticLinkingDescriptor() const; + + bool meshThreadgroupSizeIsMultipleOfThreadExecutionWidth() const; + + FunctionDescriptor* objectFunctionDescriptor() const; + + StaticLinkingDescriptor* objectStaticLinkingDescriptor() const; + + bool objectThreadgroupSizeIsMultipleOfThreadExecutionWidth() const; + + NS::UInteger payloadMemoryLength() const; + + NS::UInteger rasterSampleCount() const; + + [[deprecated("please use isRasterizationEnabled instead")]] + bool rasterizationEnabled() const; + + MTL::Size requiredThreadsPerMeshThreadgroup() const; + + MTL::Size requiredThreadsPerObjectThreadgroup() const; + + void reset(); + + void setAlphaToCoverageState(MTL4::AlphaToCoverageState alphaToCoverageState); + + void setAlphaToOneState(MTL4::AlphaToOneState alphaToOneState); + + void setColorAttachmentMappingState(MTL4::LogicalToPhysicalColorAttachmentMappingState colorAttachmentMappingState); + + void setFragmentFunctionDescriptor(const MTL4::FunctionDescriptor* fragmentFunctionDescriptor); + + void setFragmentStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* fragmentStaticLinkingDescriptor); + + void setMaxTotalThreadgroupsPerMeshGrid(NS::UInteger maxTotalThreadgroupsPerMeshGrid); + + void setMaxTotalThreadsPerMeshThreadgroup(NS::UInteger maxTotalThreadsPerMeshThreadgroup); + + void setMaxTotalThreadsPerObjectThreadgroup(NS::UInteger maxTotalThreadsPerObjectThreadgroup); + + void setMaxVertexAmplificationCount(NS::UInteger maxVertexAmplificationCount); + + void setMeshFunctionDescriptor(const MTL4::FunctionDescriptor* meshFunctionDescriptor); + + void setMeshStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* meshStaticLinkingDescriptor); + + void setMeshThreadgroupSizeIsMultipleOfThreadExecutionWidth(bool meshThreadgroupSizeIsMultipleOfThreadExecutionWidth); + + void setObjectFunctionDescriptor(const MTL4::FunctionDescriptor* objectFunctionDescriptor); + + void setObjectStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* objectStaticLinkingDescriptor); + + void setObjectThreadgroupSizeIsMultipleOfThreadExecutionWidth(bool objectThreadgroupSizeIsMultipleOfThreadExecutionWidth); + + void setPayloadMemoryLength(NS::UInteger payloadMemoryLength); + + void setRasterSampleCount(NS::UInteger rasterSampleCount); + + void setRasterizationEnabled(bool rasterizationEnabled); + + void setRequiredThreadsPerMeshThreadgroup(MTL::Size requiredThreadsPerMeshThreadgroup); + + void setRequiredThreadsPerObjectThreadgroup(MTL::Size requiredThreadsPerObjectThreadgroup); + + void setSupportFragmentBinaryLinking(bool supportFragmentBinaryLinking); + + void setSupportIndirectCommandBuffers(MTL4::IndirectCommandBufferSupportState supportIndirectCommandBuffers); + + void setSupportMeshBinaryLinking(bool supportMeshBinaryLinking); + + void setSupportObjectBinaryLinking(bool supportObjectBinaryLinking); + + bool supportFragmentBinaryLinking() const; + + IndirectCommandBufferSupportState supportIndirectCommandBuffers() const; + + bool supportMeshBinaryLinking() const; + + bool supportObjectBinaryLinking() const; +}; + +} +_MTL_INLINE MTL4::MeshRenderPipelineDescriptor* MTL4::MeshRenderPipelineDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4MeshRenderPipelineDescriptor)); +} + +_MTL_INLINE MTL4::AlphaToCoverageState MTL4::MeshRenderPipelineDescriptor::alphaToCoverageState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(alphaToCoverageState)); +} + +_MTL_INLINE MTL4::AlphaToOneState MTL4::MeshRenderPipelineDescriptor::alphaToOneState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(alphaToOneState)); +} + +_MTL_INLINE MTL4::LogicalToPhysicalColorAttachmentMappingState MTL4::MeshRenderPipelineDescriptor::colorAttachmentMappingState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(colorAttachmentMappingState)); +} + +_MTL_INLINE MTL4::RenderPipelineColorAttachmentDescriptorArray* MTL4::MeshRenderPipelineDescriptor::colorAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(colorAttachments)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::MeshRenderPipelineDescriptor::fragmentFunctionDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentFunctionDescriptor)); +} + +_MTL_INLINE MTL4::StaticLinkingDescriptor* MTL4::MeshRenderPipelineDescriptor::fragmentStaticLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentStaticLinkingDescriptor)); +} + +_MTL_INLINE MTL4::MeshRenderPipelineDescriptor* MTL4::MeshRenderPipelineDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE bool MTL4::MeshRenderPipelineDescriptor::isRasterizationEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isRasterizationEnabled)); +} + +_MTL_INLINE NS::UInteger MTL4::MeshRenderPipelineDescriptor::maxTotalThreadgroupsPerMeshGrid() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadgroupsPerMeshGrid)); +} + +_MTL_INLINE NS::UInteger MTL4::MeshRenderPipelineDescriptor::maxTotalThreadsPerMeshThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerMeshThreadgroup)); +} + +_MTL_INLINE NS::UInteger MTL4::MeshRenderPipelineDescriptor::maxTotalThreadsPerObjectThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerObjectThreadgroup)); +} + +_MTL_INLINE NS::UInteger MTL4::MeshRenderPipelineDescriptor::maxVertexAmplificationCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxVertexAmplificationCount)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::MeshRenderPipelineDescriptor::meshFunctionDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshFunctionDescriptor)); +} + +_MTL_INLINE MTL4::StaticLinkingDescriptor* MTL4::MeshRenderPipelineDescriptor::meshStaticLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshStaticLinkingDescriptor)); +} + +_MTL_INLINE bool MTL4::MeshRenderPipelineDescriptor::meshThreadgroupSizeIsMultipleOfThreadExecutionWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshThreadgroupSizeIsMultipleOfThreadExecutionWidth)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::MeshRenderPipelineDescriptor::objectFunctionDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectFunctionDescriptor)); +} + +_MTL_INLINE MTL4::StaticLinkingDescriptor* MTL4::MeshRenderPipelineDescriptor::objectStaticLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectStaticLinkingDescriptor)); +} + +_MTL_INLINE bool MTL4::MeshRenderPipelineDescriptor::objectThreadgroupSizeIsMultipleOfThreadExecutionWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectThreadgroupSizeIsMultipleOfThreadExecutionWidth)); +} + +_MTL_INLINE NS::UInteger MTL4::MeshRenderPipelineDescriptor::payloadMemoryLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(payloadMemoryLength)); +} + +_MTL_INLINE NS::UInteger MTL4::MeshRenderPipelineDescriptor::rasterSampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rasterSampleCount)); +} + +_MTL_INLINE bool MTL4::MeshRenderPipelineDescriptor::rasterizationEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isRasterizationEnabled)); +} + +_MTL_INLINE MTL::Size MTL4::MeshRenderPipelineDescriptor::requiredThreadsPerMeshThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerMeshThreadgroup)); +} + +_MTL_INLINE MTL::Size MTL4::MeshRenderPipelineDescriptor::requiredThreadsPerObjectThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerObjectThreadgroup)); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setAlphaToCoverageState(MTL4::AlphaToCoverageState alphaToCoverageState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAlphaToCoverageState_), alphaToCoverageState); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setAlphaToOneState(MTL4::AlphaToOneState alphaToOneState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAlphaToOneState_), alphaToOneState); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setColorAttachmentMappingState(MTL4::LogicalToPhysicalColorAttachmentMappingState colorAttachmentMappingState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setColorAttachmentMappingState_), colorAttachmentMappingState); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setFragmentFunctionDescriptor(const MTL4::FunctionDescriptor* fragmentFunctionDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentFunctionDescriptor_), fragmentFunctionDescriptor); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setFragmentStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* fragmentStaticLinkingDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentStaticLinkingDescriptor_), fragmentStaticLinkingDescriptor); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setMaxTotalThreadgroupsPerMeshGrid(NS::UInteger maxTotalThreadgroupsPerMeshGrid) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadgroupsPerMeshGrid_), maxTotalThreadgroupsPerMeshGrid); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setMaxTotalThreadsPerMeshThreadgroup(NS::UInteger maxTotalThreadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadsPerMeshThreadgroup_), maxTotalThreadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setMaxTotalThreadsPerObjectThreadgroup(NS::UInteger maxTotalThreadsPerObjectThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadsPerObjectThreadgroup_), maxTotalThreadsPerObjectThreadgroup); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setMaxVertexAmplificationCount(NS::UInteger maxVertexAmplificationCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxVertexAmplificationCount_), maxVertexAmplificationCount); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setMeshFunctionDescriptor(const MTL4::FunctionDescriptor* meshFunctionDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshFunctionDescriptor_), meshFunctionDescriptor); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setMeshStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* meshStaticLinkingDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshStaticLinkingDescriptor_), meshStaticLinkingDescriptor); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setMeshThreadgroupSizeIsMultipleOfThreadExecutionWidth(bool meshThreadgroupSizeIsMultipleOfThreadExecutionWidth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshThreadgroupSizeIsMultipleOfThreadExecutionWidth_), meshThreadgroupSizeIsMultipleOfThreadExecutionWidth); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setObjectFunctionDescriptor(const MTL4::FunctionDescriptor* objectFunctionDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectFunctionDescriptor_), objectFunctionDescriptor); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setObjectStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* objectStaticLinkingDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectStaticLinkingDescriptor_), objectStaticLinkingDescriptor); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setObjectThreadgroupSizeIsMultipleOfThreadExecutionWidth(bool objectThreadgroupSizeIsMultipleOfThreadExecutionWidth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectThreadgroupSizeIsMultipleOfThreadExecutionWidth_), objectThreadgroupSizeIsMultipleOfThreadExecutionWidth); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setPayloadMemoryLength(NS::UInteger payloadMemoryLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPayloadMemoryLength_), payloadMemoryLength); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setRasterSampleCount(NS::UInteger rasterSampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterSampleCount_), rasterSampleCount); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setRasterizationEnabled(bool rasterizationEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterizationEnabled_), rasterizationEnabled); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setRequiredThreadsPerMeshThreadgroup(MTL::Size requiredThreadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRequiredThreadsPerMeshThreadgroup_), requiredThreadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setRequiredThreadsPerObjectThreadgroup(MTL::Size requiredThreadsPerObjectThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRequiredThreadsPerObjectThreadgroup_), requiredThreadsPerObjectThreadgroup); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setSupportFragmentBinaryLinking(bool supportFragmentBinaryLinking) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportFragmentBinaryLinking_), supportFragmentBinaryLinking); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setSupportIndirectCommandBuffers(MTL4::IndirectCommandBufferSupportState supportIndirectCommandBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportIndirectCommandBuffers_), supportIndirectCommandBuffers); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setSupportMeshBinaryLinking(bool supportMeshBinaryLinking) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportMeshBinaryLinking_), supportMeshBinaryLinking); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setSupportObjectBinaryLinking(bool supportObjectBinaryLinking) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportObjectBinaryLinking_), supportObjectBinaryLinking); +} + +_MTL_INLINE bool MTL4::MeshRenderPipelineDescriptor::supportFragmentBinaryLinking() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportFragmentBinaryLinking)); +} + +_MTL_INLINE MTL4::IndirectCommandBufferSupportState MTL4::MeshRenderPipelineDescriptor::supportIndirectCommandBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(supportIndirectCommandBuffers)); +} + +_MTL_INLINE bool MTL4::MeshRenderPipelineDescriptor::supportMeshBinaryLinking() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportMeshBinaryLinking)); +} + +_MTL_INLINE bool MTL4::MeshRenderPipelineDescriptor::supportObjectBinaryLinking() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportObjectBinaryLinking)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4PipelineDataSetSerializer.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4PipelineDataSetSerializer.hpp new file mode 100644 index 00000000..9dbd6103 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4PipelineDataSetSerializer.hpp @@ -0,0 +1,85 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4PipelineDataSetSerializer.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class PipelineDataSetSerializerDescriptor; + +_MTL_OPTIONS(NS::UInteger, PipelineDataSetSerializerConfiguration) { + PipelineDataSetSerializerConfigurationCaptureDescriptors = 1, + PipelineDataSetSerializerConfigurationCaptureBinaries = 1 << 1, +}; + +class PipelineDataSetSerializerDescriptor : public NS::Copying +{ +public: + static PipelineDataSetSerializerDescriptor* alloc(); + + PipelineDataSetSerializerConfiguration configuration() const; + + PipelineDataSetSerializerDescriptor* init(); + + void setConfiguration(MTL4::PipelineDataSetSerializerConfiguration configuration); +}; +class PipelineDataSetSerializer : public NS::Referencing +{ +public: + bool serializeAsArchiveAndFlushToURL(const NS::URL* url, NS::Error** error); + + NS::Data* serializeAsPipelinesScript(NS::Error** error); +}; + +} +_MTL_INLINE MTL4::PipelineDataSetSerializerDescriptor* MTL4::PipelineDataSetSerializerDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4PipelineDataSetSerializerDescriptor)); +} + +_MTL_INLINE MTL4::PipelineDataSetSerializerConfiguration MTL4::PipelineDataSetSerializerDescriptor::configuration() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(configuration)); +} + +_MTL_INLINE MTL4::PipelineDataSetSerializerDescriptor* MTL4::PipelineDataSetSerializerDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL4::PipelineDataSetSerializerDescriptor::setConfiguration(MTL4::PipelineDataSetSerializerConfiguration configuration) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setConfiguration_), configuration); +} + +_MTL_INLINE bool MTL4::PipelineDataSetSerializer::serializeAsArchiveAndFlushToURL(const NS::URL* url, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(serializeAsArchiveAndFlushToURL_error_), url, error); +} + +_MTL_INLINE NS::Data* MTL4::PipelineDataSetSerializer::serializeAsPipelinesScript(NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(serializeAsPipelinesScriptWithError_), error); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4PipelineState.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4PipelineState.hpp new file mode 100644 index 00000000..cecefa8a --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4PipelineState.hpp @@ -0,0 +1,150 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4PipelineState.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPipeline.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class PipelineDescriptor; +class PipelineOptions; +_MTL_ENUM(NS::Integer, AlphaToOneState) { + AlphaToOneStateDisabled = 0, + AlphaToOneStateEnabled = 1, +}; + +_MTL_ENUM(NS::Integer, AlphaToCoverageState) { + AlphaToCoverageStateDisabled = 0, + AlphaToCoverageStateEnabled = 1, +}; + +_MTL_ENUM(NS::Integer, BlendState) { + BlendStateDisabled = 0, + BlendStateEnabled = 1, + BlendStateUnspecialized = 2, +}; + +_MTL_ENUM(NS::Integer, IndirectCommandBufferSupportState) { + IndirectCommandBufferSupportStateDisabled = 0, + IndirectCommandBufferSupportStateEnabled = 1, +}; + +_MTL_OPTIONS(NS::UInteger, ShaderReflection) { + ShaderReflectionNone = 0, + ShaderReflectionBindingInfo = 1, + ShaderReflectionBufferTypeInfo = 1 << 1, +}; + +class PipelineOptions : public NS::Copying +{ +public: + static PipelineOptions* alloc(); + + PipelineOptions* init(); + + void setShaderReflection(MTL4::ShaderReflection shaderReflection); + + void setShaderValidation(MTL::ShaderValidation shaderValidation); + + ShaderReflection shaderReflection() const; + + MTL::ShaderValidation shaderValidation() const; +}; +class PipelineDescriptor : public NS::Copying +{ +public: + static PipelineDescriptor* alloc(); + + PipelineDescriptor* init(); + + NS::String* label() const; + + PipelineOptions* options() const; + + void setLabel(const NS::String* label); + + void setOptions(const MTL4::PipelineOptions* options); +}; + +} +_MTL_INLINE MTL4::PipelineOptions* MTL4::PipelineOptions::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4PipelineOptions)); +} + +_MTL_INLINE MTL4::PipelineOptions* MTL4::PipelineOptions::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL4::PipelineOptions::setShaderReflection(MTL4::ShaderReflection shaderReflection) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setShaderReflection_), shaderReflection); +} + +_MTL_INLINE void MTL4::PipelineOptions::setShaderValidation(MTL::ShaderValidation shaderValidation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setShaderValidation_), shaderValidation); +} + +_MTL_INLINE MTL4::ShaderReflection MTL4::PipelineOptions::shaderReflection() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderReflection)); +} + +_MTL_INLINE MTL::ShaderValidation MTL4::PipelineOptions::shaderValidation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderValidation)); +} + +_MTL_INLINE MTL4::PipelineDescriptor* MTL4::PipelineDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4PipelineDescriptor)); +} + +_MTL_INLINE MTL4::PipelineDescriptor* MTL4::PipelineDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL4::PipelineDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL4::PipelineOptions* MTL4::PipelineDescriptor::options() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(options)); +} + +_MTL_INLINE void MTL4::PipelineDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL4::PipelineDescriptor::setOptions(const MTL4::PipelineOptions* options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOptions_), options); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4RenderCommandEncoder.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4RenderCommandEncoder.hpp new file mode 100644 index 00000000..0dd01f4d --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4RenderCommandEncoder.hpp @@ -0,0 +1,340 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4RenderCommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4CommandEncoder.hpp" +#include "MTL4Counters.hpp" +#include "MTLArgument.hpp" +#include "MTLDefines.hpp" +#include "MTLGPUAddress.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLRenderCommandEncoder.hpp" +#include "MTLRenderPass.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL4 +{ +class ArgumentTable; +class CounterHeap; +} + +namespace MTL +{ +class DepthStencilState; +class IndirectCommandBuffer; +class LogicalToPhysicalColorAttachmentMap; +class RenderPipelineState; +struct ScissorRect; +struct VertexAmplificationViewMapping; +struct Viewport; + +} +namespace MTL4 +{ +_MTL_OPTIONS(NS::UInteger, RenderEncoderOptions) { + RenderEncoderOptionNone = 0, + RenderEncoderOptionSuspending = 1, + RenderEncoderOptionResuming = 1 << 1, +}; + +class RenderCommandEncoder : public NS::Referencing +{ +public: + void dispatchThreadsPerTile(MTL::Size threadsPerTile); + + void drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, MTL::GPUAddress indexBuffer, NS::UInteger indexBufferLength); + void drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, MTL::GPUAddress indexBuffer, NS::UInteger indexBufferLength, NS::UInteger instanceCount); + void drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, MTL::GPUAddress indexBuffer, NS::UInteger indexBufferLength, NS::UInteger instanceCount, NS::Integer baseVertex, NS::UInteger baseInstance); + void drawIndexedPrimitives(MTL::PrimitiveType primitiveType, MTL::IndexType indexType, MTL::GPUAddress indexBuffer, NS::UInteger indexBufferLength, MTL::GPUAddress indirectBuffer); + + void drawMeshThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup); + void drawMeshThreadgroups(MTL::GPUAddress indirectBuffer, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup); + + void drawMeshThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup); + + void drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount); + void drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount, NS::UInteger instanceCount); + void drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount, NS::UInteger instanceCount, NS::UInteger baseInstance); + void drawPrimitives(MTL::PrimitiveType primitiveType, MTL::GPUAddress indirectBuffer); + + void executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range executionRange); + void executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, MTL::GPUAddress indirectRangeBuffer); + + void setArgumentTable(const MTL4::ArgumentTable* argumentTable, MTL::RenderStages stages); + + void setBlendColor(float red, float green, float blue, float alpha); + + void setColorAttachmentMap(const MTL::LogicalToPhysicalColorAttachmentMap* mapping); + + void setColorStoreAction(MTL::StoreAction storeAction, NS::UInteger colorAttachmentIndex); + + void setCullMode(MTL::CullMode cullMode); + + void setDepthBias(float depthBias, float slopeScale, float clamp); + + void setDepthClipMode(MTL::DepthClipMode depthClipMode); + + void setDepthStencilState(const MTL::DepthStencilState* depthStencilState); + + void setDepthStoreAction(MTL::StoreAction storeAction); + + void setDepthTestBounds(float minBound, float maxBound); + + void setFrontFacingWinding(MTL::Winding frontFacingWinding); + + void setObjectThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index); + + void setRenderPipelineState(const MTL::RenderPipelineState* pipelineState); + + void setScissorRect(MTL::ScissorRect rect); + void setScissorRects(const MTL::ScissorRect* scissorRects, NS::UInteger count); + + void setStencilReferenceValue(uint32_t referenceValue); + void setStencilReferenceValues(uint32_t frontReferenceValue, uint32_t backReferenceValue); + + void setStencilStoreAction(MTL::StoreAction storeAction); + + void setThreadgroupMemoryLength(NS::UInteger length, NS::UInteger offset, NS::UInteger index); + + void setTriangleFillMode(MTL::TriangleFillMode fillMode); + + void setVertexAmplificationCount(NS::UInteger count, const MTL::VertexAmplificationViewMapping* viewMappings); + + void setViewport(MTL::Viewport viewport); + void setViewports(const MTL::Viewport* viewports, NS::UInteger count); + + void setVisibilityResultMode(MTL::VisibilityResultMode mode, NS::UInteger offset); + + NS::UInteger tileHeight() const; + + NS::UInteger tileWidth() const; + + void writeTimestamp(MTL4::TimestampGranularity granularity, MTL::RenderStages stage, const MTL4::CounterHeap* counterHeap, NS::UInteger index); +}; + +} +_MTL_INLINE void MTL4::RenderCommandEncoder::dispatchThreadsPerTile(MTL::Size threadsPerTile) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchThreadsPerTile_), threadsPerTile); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, MTL::GPUAddress indexBuffer, NS::UInteger indexBufferLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferLength_), primitiveType, indexCount, indexType, indexBuffer, indexBufferLength); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, MTL::GPUAddress indexBuffer, NS::UInteger indexBufferLength, NS::UInteger instanceCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferLength_instanceCount_), primitiveType, indexCount, indexType, indexBuffer, indexBufferLength, instanceCount); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, MTL::GPUAddress indexBuffer, NS::UInteger indexBufferLength, NS::UInteger instanceCount, NS::Integer baseVertex, NS::UInteger baseInstance) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferLength_instanceCount_baseVertex_baseInstance_), primitiveType, indexCount, indexType, indexBuffer, indexBufferLength, instanceCount, baseVertex, baseInstance); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawIndexedPrimitives(MTL::PrimitiveType primitiveType, MTL::IndexType indexType, MTL::GPUAddress indexBuffer, NS::UInteger indexBufferLength, MTL::GPUAddress indirectBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPrimitives_indexType_indexBuffer_indexBufferLength_indirectBuffer_), primitiveType, indexType, indexBuffer, indexBufferLength, indirectBuffer); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawMeshThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawMeshThreadgroups_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_), threadgroupsPerGrid, threadsPerObjectThreadgroup, threadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawMeshThreadgroups(MTL::GPUAddress indirectBuffer, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawMeshThreadgroupsWithIndirectBuffer_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_), indirectBuffer, threadsPerObjectThreadgroup, threadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawMeshThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawMeshThreads_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_), threadsPerGrid, threadsPerObjectThreadgroup, threadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPrimitives_vertexStart_vertexCount_), primitiveType, vertexStart, vertexCount); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount, NS::UInteger instanceCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPrimitives_vertexStart_vertexCount_instanceCount_), primitiveType, vertexStart, vertexCount, instanceCount); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount, NS::UInteger instanceCount, NS::UInteger baseInstance) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPrimitives_vertexStart_vertexCount_instanceCount_baseInstance_), primitiveType, vertexStart, vertexCount, instanceCount, baseInstance); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawPrimitives(MTL::PrimitiveType primitiveType, MTL::GPUAddress indirectBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPrimitives_indirectBuffer_), primitiveType, indirectBuffer); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range executionRange) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(executeCommandsInBuffer_withRange_), indirectCommandBuffer, executionRange); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, MTL::GPUAddress indirectRangeBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(executeCommandsInBuffer_indirectBuffer_), indirectCommandBuffer, indirectRangeBuffer); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setArgumentTable(const MTL4::ArgumentTable* argumentTable, MTL::RenderStages stages) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setArgumentTable_atStages_), argumentTable, stages); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setBlendColor(float red, float green, float blue, float alpha) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBlendColorRed_green_blue_alpha_), red, green, blue, alpha); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setColorAttachmentMap(const MTL::LogicalToPhysicalColorAttachmentMap* mapping) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setColorAttachmentMap_), mapping); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setColorStoreAction(MTL::StoreAction storeAction, NS::UInteger colorAttachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setColorStoreAction_atIndex_), storeAction, colorAttachmentIndex); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setCullMode(MTL::CullMode cullMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCullMode_), cullMode); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setDepthBias(float depthBias, float slopeScale, float clamp) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthBias_slopeScale_clamp_), depthBias, slopeScale, clamp); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setDepthClipMode(MTL::DepthClipMode depthClipMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthClipMode_), depthClipMode); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setDepthStencilState(const MTL::DepthStencilState* depthStencilState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStencilState_), depthStencilState); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setDepthStoreAction(MTL::StoreAction storeAction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStoreAction_), storeAction); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setDepthTestBounds(float minBound, float maxBound) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthTestMinBound_maxBound_), minBound, maxBound); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setFrontFacingWinding(MTL::Winding frontFacingWinding) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFrontFacingWinding_), frontFacingWinding); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setObjectThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectThreadgroupMemoryLength_atIndex_), length, index); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setRenderPipelineState(const MTL::RenderPipelineState* pipelineState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRenderPipelineState_), pipelineState); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setScissorRect(MTL::ScissorRect rect) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setScissorRect_), rect); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setScissorRects(const MTL::ScissorRect* scissorRects, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setScissorRects_count_), scissorRects, count); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setStencilReferenceValue(uint32_t referenceValue) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilReferenceValue_), referenceValue); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setStencilReferenceValues(uint32_t frontReferenceValue, uint32_t backReferenceValue) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilFrontReferenceValue_backReferenceValue_), frontReferenceValue, backReferenceValue); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setStencilStoreAction(MTL::StoreAction storeAction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilStoreAction_), storeAction); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setThreadgroupMemoryLength(NS::UInteger length, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadgroupMemoryLength_offset_atIndex_), length, offset, index); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setTriangleFillMode(MTL::TriangleFillMode fillMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTriangleFillMode_), fillMode); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setVertexAmplificationCount(NS::UInteger count, const MTL::VertexAmplificationViewMapping* viewMappings) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexAmplificationCount_viewMappings_), count, viewMappings); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setViewport(MTL::Viewport viewport) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setViewport_), viewport); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setViewports(const MTL::Viewport* viewports, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setViewports_count_), viewports, count); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setVisibilityResultMode(MTL::VisibilityResultMode mode, NS::UInteger offset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibilityResultMode_offset_), mode, offset); +} + +_MTL_INLINE NS::UInteger MTL4::RenderCommandEncoder::tileHeight() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileHeight)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderCommandEncoder::tileWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileWidth)); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::writeTimestamp(MTL4::TimestampGranularity granularity, MTL::RenderStages stage, const MTL4::CounterHeap* counterHeap, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(writeTimestampWithGranularity_afterStage_intoHeap_atIndex_), granularity, stage, counterHeap, index); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4RenderPass.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4RenderPass.hpp new file mode 100644 index 00000000..c5aa9ed6 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4RenderPass.hpp @@ -0,0 +1,280 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4RenderPass.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLRenderPass.hpp" + +namespace MTL4 +{ +class RenderPassDescriptor; +} + +namespace MTL +{ +class Buffer; +class RasterizationRateMap; +class RenderPassColorAttachmentDescriptorArray; +class RenderPassDepthAttachmentDescriptor; +class RenderPassStencilAttachmentDescriptor; +struct SamplePosition; +} + +namespace MTL4 +{ +class RenderPassDescriptor : public NS::Copying +{ +public: + static RenderPassDescriptor* alloc(); + + MTL::RenderPassColorAttachmentDescriptorArray* colorAttachments() const; + + NS::UInteger defaultRasterSampleCount() const; + + MTL::RenderPassDepthAttachmentDescriptor* depthAttachment() const; + + NS::UInteger getSamplePositions(MTL::SamplePosition* positions, NS::UInteger count); + + NS::UInteger imageblockSampleLength() const; + + RenderPassDescriptor* init(); + + MTL::RasterizationRateMap* rasterizationRateMap() const; + + NS::UInteger renderTargetArrayLength() const; + + NS::UInteger renderTargetHeight() const; + + NS::UInteger renderTargetWidth() const; + + void setDefaultRasterSampleCount(NS::UInteger defaultRasterSampleCount); + + void setDepthAttachment(const MTL::RenderPassDepthAttachmentDescriptor* depthAttachment); + + void setImageblockSampleLength(NS::UInteger imageblockSampleLength); + + void setRasterizationRateMap(const MTL::RasterizationRateMap* rasterizationRateMap); + + void setRenderTargetArrayLength(NS::UInteger renderTargetArrayLength); + + void setRenderTargetHeight(NS::UInteger renderTargetHeight); + + void setRenderTargetWidth(NS::UInteger renderTargetWidth); + + void setSamplePositions(const MTL::SamplePosition* positions, NS::UInteger count); + + void setStencilAttachment(const MTL::RenderPassStencilAttachmentDescriptor* stencilAttachment); + + void setSupportColorAttachmentMapping(bool supportColorAttachmentMapping); + + void setThreadgroupMemoryLength(NS::UInteger threadgroupMemoryLength); + + void setTileHeight(NS::UInteger tileHeight); + + void setTileWidth(NS::UInteger tileWidth); + + void setVisibilityResultBuffer(const MTL::Buffer* visibilityResultBuffer); + + void setVisibilityResultType(MTL::VisibilityResultType visibilityResultType); + + MTL::RenderPassStencilAttachmentDescriptor* stencilAttachment() const; + + bool supportColorAttachmentMapping() const; + + NS::UInteger threadgroupMemoryLength() const; + + NS::UInteger tileHeight() const; + + NS::UInteger tileWidth() const; + + MTL::Buffer* visibilityResultBuffer() const; + + MTL::VisibilityResultType visibilityResultType() const; +}; + +} +_MTL_INLINE MTL4::RenderPassDescriptor* MTL4::RenderPassDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4RenderPassDescriptor)); +} + +_MTL_INLINE MTL::RenderPassColorAttachmentDescriptorArray* MTL4::RenderPassDescriptor::colorAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(colorAttachments)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPassDescriptor::defaultRasterSampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(defaultRasterSampleCount)); +} + +_MTL_INLINE MTL::RenderPassDepthAttachmentDescriptor* MTL4::RenderPassDescriptor::depthAttachment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(depthAttachment)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPassDescriptor::getSamplePositions(MTL::SamplePosition* positions, NS::UInteger count) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(getSamplePositions_count_), positions, count); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPassDescriptor::imageblockSampleLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(imageblockSampleLength)); +} + +_MTL_INLINE MTL4::RenderPassDescriptor* MTL4::RenderPassDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::RasterizationRateMap* MTL4::RenderPassDescriptor::rasterizationRateMap() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rasterizationRateMap)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPassDescriptor::renderTargetArrayLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(renderTargetArrayLength)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPassDescriptor::renderTargetHeight() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(renderTargetHeight)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPassDescriptor::renderTargetWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(renderTargetWidth)); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setDefaultRasterSampleCount(NS::UInteger defaultRasterSampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDefaultRasterSampleCount_), defaultRasterSampleCount); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setDepthAttachment(const MTL::RenderPassDepthAttachmentDescriptor* depthAttachment) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthAttachment_), depthAttachment); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setImageblockSampleLength(NS::UInteger imageblockSampleLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setImageblockSampleLength_), imageblockSampleLength); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setRasterizationRateMap(const MTL::RasterizationRateMap* rasterizationRateMap) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterizationRateMap_), rasterizationRateMap); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setRenderTargetArrayLength(NS::UInteger renderTargetArrayLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRenderTargetArrayLength_), renderTargetArrayLength); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setRenderTargetHeight(NS::UInteger renderTargetHeight) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRenderTargetHeight_), renderTargetHeight); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setRenderTargetWidth(NS::UInteger renderTargetWidth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRenderTargetWidth_), renderTargetWidth); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setSamplePositions(const MTL::SamplePosition* positions, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSamplePositions_count_), positions, count); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setStencilAttachment(const MTL::RenderPassStencilAttachmentDescriptor* stencilAttachment) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilAttachment_), stencilAttachment); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setSupportColorAttachmentMapping(bool supportColorAttachmentMapping) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportColorAttachmentMapping_), supportColorAttachmentMapping); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setThreadgroupMemoryLength(NS::UInteger threadgroupMemoryLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadgroupMemoryLength_), threadgroupMemoryLength); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setTileHeight(NS::UInteger tileHeight) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileHeight_), tileHeight); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setTileWidth(NS::UInteger tileWidth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileWidth_), tileWidth); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setVisibilityResultBuffer(const MTL::Buffer* visibilityResultBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibilityResultBuffer_), visibilityResultBuffer); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setVisibilityResultType(MTL::VisibilityResultType visibilityResultType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibilityResultType_), visibilityResultType); +} + +_MTL_INLINE MTL::RenderPassStencilAttachmentDescriptor* MTL4::RenderPassDescriptor::stencilAttachment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stencilAttachment)); +} + +_MTL_INLINE bool MTL4::RenderPassDescriptor::supportColorAttachmentMapping() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportColorAttachmentMapping)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPassDescriptor::threadgroupMemoryLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadgroupMemoryLength)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPassDescriptor::tileHeight() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileHeight)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPassDescriptor::tileWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileWidth)); +} + +_MTL_INLINE MTL::Buffer* MTL4::RenderPassDescriptor::visibilityResultBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(visibilityResultBuffer)); +} + +_MTL_INLINE MTL::VisibilityResultType MTL4::RenderPassDescriptor::visibilityResultType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(visibilityResultType)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4RenderPipeline.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4RenderPipeline.hpp new file mode 100644 index 00000000..fc2e5e6f --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4RenderPipeline.hpp @@ -0,0 +1,587 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4RenderPipeline.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4PipelineState.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPixelFormat.hpp" +#include "MTLPrivate.hpp" +#include "MTLRenderPipeline.hpp" + +namespace MTL4 +{ +class FunctionDescriptor; +class RenderPipelineBinaryFunctionsDescriptor; +class RenderPipelineColorAttachmentDescriptor; +class RenderPipelineColorAttachmentDescriptorArray; +class RenderPipelineDescriptor; +class StaticLinkingDescriptor; +} + +namespace MTL +{ +class VertexDescriptor; +} + +namespace MTL4 +{ +_MTL_ENUM(NS::Integer, LogicalToPhysicalColorAttachmentMappingState) { + LogicalToPhysicalColorAttachmentMappingStateIdentity = 0, + LogicalToPhysicalColorAttachmentMappingStateInherited = 1, +}; + +class RenderPipelineColorAttachmentDescriptor : public NS::Copying +{ +public: + static RenderPipelineColorAttachmentDescriptor* alloc(); + + MTL::BlendOperation alphaBlendOperation() const; + + BlendState blendingState() const; + + MTL::BlendFactor destinationAlphaBlendFactor() const; + + MTL::BlendFactor destinationRGBBlendFactor() const; + + RenderPipelineColorAttachmentDescriptor* init(); + + MTL::PixelFormat pixelFormat() const; + + void reset(); + + MTL::BlendOperation rgbBlendOperation() const; + + void setAlphaBlendOperation(MTL::BlendOperation alphaBlendOperation); + + void setBlendingState(MTL4::BlendState blendingState); + + void setDestinationAlphaBlendFactor(MTL::BlendFactor destinationAlphaBlendFactor); + + void setDestinationRGBBlendFactor(MTL::BlendFactor destinationRGBBlendFactor); + + void setPixelFormat(MTL::PixelFormat pixelFormat); + + void setRgbBlendOperation(MTL::BlendOperation rgbBlendOperation); + + void setSourceAlphaBlendFactor(MTL::BlendFactor sourceAlphaBlendFactor); + + void setSourceRGBBlendFactor(MTL::BlendFactor sourceRGBBlendFactor); + + void setWriteMask(MTL::ColorWriteMask writeMask); + + MTL::BlendFactor sourceAlphaBlendFactor() const; + + MTL::BlendFactor sourceRGBBlendFactor() const; + + MTL::ColorWriteMask writeMask() const; +}; + +class RenderPipelineColorAttachmentDescriptorArray : public NS::Copying +{ +public: + static RenderPipelineColorAttachmentDescriptorArray* alloc(); + + RenderPipelineColorAttachmentDescriptorArray* init(); + + RenderPipelineColorAttachmentDescriptor* object(NS::UInteger attachmentIndex); + + void reset(); + + void setObject(const MTL4::RenderPipelineColorAttachmentDescriptor* attachment, NS::UInteger attachmentIndex); +}; + +class RenderPipelineBinaryFunctionsDescriptor : public NS::Copying +{ +public: + static RenderPipelineBinaryFunctionsDescriptor* alloc(); + + NS::Array* fragmentAdditionalBinaryFunctions() const; + + RenderPipelineBinaryFunctionsDescriptor* init(); + + NS::Array* meshAdditionalBinaryFunctions() const; + + NS::Array* objectAdditionalBinaryFunctions() const; + + void reset(); + + void setFragmentAdditionalBinaryFunctions(const NS::Array* fragmentAdditionalBinaryFunctions); + + void setMeshAdditionalBinaryFunctions(const NS::Array* meshAdditionalBinaryFunctions); + + void setObjectAdditionalBinaryFunctions(const NS::Array* objectAdditionalBinaryFunctions); + + void setTileAdditionalBinaryFunctions(const NS::Array* tileAdditionalBinaryFunctions); + + void setVertexAdditionalBinaryFunctions(const NS::Array* vertexAdditionalBinaryFunctions); + + NS::Array* tileAdditionalBinaryFunctions() const; + + NS::Array* vertexAdditionalBinaryFunctions() const; +}; + +class RenderPipelineDescriptor : public NS::Copying +{ +public: + static RenderPipelineDescriptor* alloc(); + + AlphaToCoverageState alphaToCoverageState() const; + + AlphaToOneState alphaToOneState() const; + + LogicalToPhysicalColorAttachmentMappingState colorAttachmentMappingState() const; + + RenderPipelineColorAttachmentDescriptorArray* colorAttachments() const; + + FunctionDescriptor* fragmentFunctionDescriptor() const; + + StaticLinkingDescriptor* fragmentStaticLinkingDescriptor() const; + + RenderPipelineDescriptor* init(); + + MTL::PrimitiveTopologyClass inputPrimitiveTopology() const; + + bool isRasterizationEnabled() const; + + NS::UInteger maxVertexAmplificationCount() const; + + NS::UInteger rasterSampleCount() const; + + [[deprecated("please use isRasterizationEnabled instead")]] + bool rasterizationEnabled() const; + + void reset(); + + void setAlphaToCoverageState(MTL4::AlphaToCoverageState alphaToCoverageState); + + void setAlphaToOneState(MTL4::AlphaToOneState alphaToOneState); + + void setColorAttachmentMappingState(MTL4::LogicalToPhysicalColorAttachmentMappingState colorAttachmentMappingState); + + void setFragmentFunctionDescriptor(const MTL4::FunctionDescriptor* fragmentFunctionDescriptor); + + void setFragmentStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* fragmentStaticLinkingDescriptor); + + void setInputPrimitiveTopology(MTL::PrimitiveTopologyClass inputPrimitiveTopology); + + void setMaxVertexAmplificationCount(NS::UInteger maxVertexAmplificationCount); + + void setRasterSampleCount(NS::UInteger rasterSampleCount); + + void setRasterizationEnabled(bool rasterizationEnabled); + + void setSupportFragmentBinaryLinking(bool supportFragmentBinaryLinking); + + void setSupportIndirectCommandBuffers(MTL4::IndirectCommandBufferSupportState supportIndirectCommandBuffers); + + void setSupportVertexBinaryLinking(bool supportVertexBinaryLinking); + + void setVertexDescriptor(const MTL::VertexDescriptor* vertexDescriptor); + + void setVertexFunctionDescriptor(const MTL4::FunctionDescriptor* vertexFunctionDescriptor); + + void setVertexStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* vertexStaticLinkingDescriptor); + + bool supportFragmentBinaryLinking() const; + + IndirectCommandBufferSupportState supportIndirectCommandBuffers() const; + + bool supportVertexBinaryLinking() const; + + MTL::VertexDescriptor* vertexDescriptor() const; + + FunctionDescriptor* vertexFunctionDescriptor() const; + + StaticLinkingDescriptor* vertexStaticLinkingDescriptor() const; +}; + +} +_MTL_INLINE MTL4::RenderPipelineColorAttachmentDescriptor* MTL4::RenderPipelineColorAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4RenderPipelineColorAttachmentDescriptor)); +} + +_MTL_INLINE MTL::BlendOperation MTL4::RenderPipelineColorAttachmentDescriptor::alphaBlendOperation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(alphaBlendOperation)); +} + +_MTL_INLINE MTL4::BlendState MTL4::RenderPipelineColorAttachmentDescriptor::blendingState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(blendingState)); +} + +_MTL_INLINE MTL::BlendFactor MTL4::RenderPipelineColorAttachmentDescriptor::destinationAlphaBlendFactor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(destinationAlphaBlendFactor)); +} + +_MTL_INLINE MTL::BlendFactor MTL4::RenderPipelineColorAttachmentDescriptor::destinationRGBBlendFactor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(destinationRGBBlendFactor)); +} + +_MTL_INLINE MTL4::RenderPipelineColorAttachmentDescriptor* MTL4::RenderPipelineColorAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::PixelFormat MTL4::RenderPipelineColorAttachmentDescriptor::pixelFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(pixelFormat)); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE MTL::BlendOperation MTL4::RenderPipelineColorAttachmentDescriptor::rgbBlendOperation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rgbBlendOperation)); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptor::setAlphaBlendOperation(MTL::BlendOperation alphaBlendOperation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAlphaBlendOperation_), alphaBlendOperation); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptor::setBlendingState(MTL4::BlendState blendingState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBlendingState_), blendingState); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptor::setDestinationAlphaBlendFactor(MTL::BlendFactor destinationAlphaBlendFactor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDestinationAlphaBlendFactor_), destinationAlphaBlendFactor); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptor::setDestinationRGBBlendFactor(MTL::BlendFactor destinationRGBBlendFactor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDestinationRGBBlendFactor_), destinationRGBBlendFactor); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptor::setPixelFormat(MTL::PixelFormat pixelFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPixelFormat_), pixelFormat); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptor::setRgbBlendOperation(MTL::BlendOperation rgbBlendOperation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRgbBlendOperation_), rgbBlendOperation); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptor::setSourceAlphaBlendFactor(MTL::BlendFactor sourceAlphaBlendFactor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSourceAlphaBlendFactor_), sourceAlphaBlendFactor); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptor::setSourceRGBBlendFactor(MTL::BlendFactor sourceRGBBlendFactor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSourceRGBBlendFactor_), sourceRGBBlendFactor); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptor::setWriteMask(MTL::ColorWriteMask writeMask) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setWriteMask_), writeMask); +} + +_MTL_INLINE MTL::BlendFactor MTL4::RenderPipelineColorAttachmentDescriptor::sourceAlphaBlendFactor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sourceAlphaBlendFactor)); +} + +_MTL_INLINE MTL::BlendFactor MTL4::RenderPipelineColorAttachmentDescriptor::sourceRGBBlendFactor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sourceRGBBlendFactor)); +} + +_MTL_INLINE MTL::ColorWriteMask MTL4::RenderPipelineColorAttachmentDescriptor::writeMask() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(writeMask)); +} + +_MTL_INLINE MTL4::RenderPipelineColorAttachmentDescriptorArray* MTL4::RenderPipelineColorAttachmentDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4RenderPipelineColorAttachmentDescriptorArray)); +} + +_MTL_INLINE MTL4::RenderPipelineColorAttachmentDescriptorArray* MTL4::RenderPipelineColorAttachmentDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL4::RenderPipelineColorAttachmentDescriptor* MTL4::RenderPipelineColorAttachmentDescriptorArray::object(NS::UInteger attachmentIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), attachmentIndex); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptorArray::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptorArray::setObject(const MTL4::RenderPipelineColorAttachmentDescriptor* attachment, NS::UInteger attachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), attachment, attachmentIndex); +} + +_MTL_INLINE MTL4::RenderPipelineBinaryFunctionsDescriptor* MTL4::RenderPipelineBinaryFunctionsDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4RenderPipelineBinaryFunctionsDescriptor)); +} + +_MTL_INLINE NS::Array* MTL4::RenderPipelineBinaryFunctionsDescriptor::fragmentAdditionalBinaryFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentAdditionalBinaryFunctions)); +} + +_MTL_INLINE MTL4::RenderPipelineBinaryFunctionsDescriptor* MTL4::RenderPipelineBinaryFunctionsDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::Array* MTL4::RenderPipelineBinaryFunctionsDescriptor::meshAdditionalBinaryFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshAdditionalBinaryFunctions)); +} + +_MTL_INLINE NS::Array* MTL4::RenderPipelineBinaryFunctionsDescriptor::objectAdditionalBinaryFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAdditionalBinaryFunctions)); +} + +_MTL_INLINE void MTL4::RenderPipelineBinaryFunctionsDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL4::RenderPipelineBinaryFunctionsDescriptor::setFragmentAdditionalBinaryFunctions(const NS::Array* fragmentAdditionalBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentAdditionalBinaryFunctions_), fragmentAdditionalBinaryFunctions); +} + +_MTL_INLINE void MTL4::RenderPipelineBinaryFunctionsDescriptor::setMeshAdditionalBinaryFunctions(const NS::Array* meshAdditionalBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshAdditionalBinaryFunctions_), meshAdditionalBinaryFunctions); +} + +_MTL_INLINE void MTL4::RenderPipelineBinaryFunctionsDescriptor::setObjectAdditionalBinaryFunctions(const NS::Array* objectAdditionalBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectAdditionalBinaryFunctions_), objectAdditionalBinaryFunctions); +} + +_MTL_INLINE void MTL4::RenderPipelineBinaryFunctionsDescriptor::setTileAdditionalBinaryFunctions(const NS::Array* tileAdditionalBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileAdditionalBinaryFunctions_), tileAdditionalBinaryFunctions); +} + +_MTL_INLINE void MTL4::RenderPipelineBinaryFunctionsDescriptor::setVertexAdditionalBinaryFunctions(const NS::Array* vertexAdditionalBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexAdditionalBinaryFunctions_), vertexAdditionalBinaryFunctions); +} + +_MTL_INLINE NS::Array* MTL4::RenderPipelineBinaryFunctionsDescriptor::tileAdditionalBinaryFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileAdditionalBinaryFunctions)); +} + +_MTL_INLINE NS::Array* MTL4::RenderPipelineBinaryFunctionsDescriptor::vertexAdditionalBinaryFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexAdditionalBinaryFunctions)); +} + +_MTL_INLINE MTL4::RenderPipelineDescriptor* MTL4::RenderPipelineDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4RenderPipelineDescriptor)); +} + +_MTL_INLINE MTL4::AlphaToCoverageState MTL4::RenderPipelineDescriptor::alphaToCoverageState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(alphaToCoverageState)); +} + +_MTL_INLINE MTL4::AlphaToOneState MTL4::RenderPipelineDescriptor::alphaToOneState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(alphaToOneState)); +} + +_MTL_INLINE MTL4::LogicalToPhysicalColorAttachmentMappingState MTL4::RenderPipelineDescriptor::colorAttachmentMappingState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(colorAttachmentMappingState)); +} + +_MTL_INLINE MTL4::RenderPipelineColorAttachmentDescriptorArray* MTL4::RenderPipelineDescriptor::colorAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(colorAttachments)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::RenderPipelineDescriptor::fragmentFunctionDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentFunctionDescriptor)); +} + +_MTL_INLINE MTL4::StaticLinkingDescriptor* MTL4::RenderPipelineDescriptor::fragmentStaticLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentStaticLinkingDescriptor)); +} + +_MTL_INLINE MTL4::RenderPipelineDescriptor* MTL4::RenderPipelineDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::PrimitiveTopologyClass MTL4::RenderPipelineDescriptor::inputPrimitiveTopology() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inputPrimitiveTopology)); +} + +_MTL_INLINE bool MTL4::RenderPipelineDescriptor::isRasterizationEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isRasterizationEnabled)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPipelineDescriptor::maxVertexAmplificationCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxVertexAmplificationCount)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPipelineDescriptor::rasterSampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rasterSampleCount)); +} + +_MTL_INLINE bool MTL4::RenderPipelineDescriptor::rasterizationEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isRasterizationEnabled)); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setAlphaToCoverageState(MTL4::AlphaToCoverageState alphaToCoverageState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAlphaToCoverageState_), alphaToCoverageState); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setAlphaToOneState(MTL4::AlphaToOneState alphaToOneState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAlphaToOneState_), alphaToOneState); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setColorAttachmentMappingState(MTL4::LogicalToPhysicalColorAttachmentMappingState colorAttachmentMappingState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setColorAttachmentMappingState_), colorAttachmentMappingState); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setFragmentFunctionDescriptor(const MTL4::FunctionDescriptor* fragmentFunctionDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentFunctionDescriptor_), fragmentFunctionDescriptor); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setFragmentStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* fragmentStaticLinkingDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentStaticLinkingDescriptor_), fragmentStaticLinkingDescriptor); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setInputPrimitiveTopology(MTL::PrimitiveTopologyClass inputPrimitiveTopology) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInputPrimitiveTopology_), inputPrimitiveTopology); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setMaxVertexAmplificationCount(NS::UInteger maxVertexAmplificationCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxVertexAmplificationCount_), maxVertexAmplificationCount); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setRasterSampleCount(NS::UInteger rasterSampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterSampleCount_), rasterSampleCount); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setRasterizationEnabled(bool rasterizationEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterizationEnabled_), rasterizationEnabled); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setSupportFragmentBinaryLinking(bool supportFragmentBinaryLinking) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportFragmentBinaryLinking_), supportFragmentBinaryLinking); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setSupportIndirectCommandBuffers(MTL4::IndirectCommandBufferSupportState supportIndirectCommandBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportIndirectCommandBuffers_), supportIndirectCommandBuffers); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setSupportVertexBinaryLinking(bool supportVertexBinaryLinking) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportVertexBinaryLinking_), supportVertexBinaryLinking); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setVertexDescriptor(const MTL::VertexDescriptor* vertexDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexDescriptor_), vertexDescriptor); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setVertexFunctionDescriptor(const MTL4::FunctionDescriptor* vertexFunctionDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexFunctionDescriptor_), vertexFunctionDescriptor); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setVertexStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* vertexStaticLinkingDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexStaticLinkingDescriptor_), vertexStaticLinkingDescriptor); +} + +_MTL_INLINE bool MTL4::RenderPipelineDescriptor::supportFragmentBinaryLinking() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportFragmentBinaryLinking)); +} + +_MTL_INLINE MTL4::IndirectCommandBufferSupportState MTL4::RenderPipelineDescriptor::supportIndirectCommandBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(supportIndirectCommandBuffers)); +} + +_MTL_INLINE bool MTL4::RenderPipelineDescriptor::supportVertexBinaryLinking() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportVertexBinaryLinking)); +} + +_MTL_INLINE MTL::VertexDescriptor* MTL4::RenderPipelineDescriptor::vertexDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexDescriptor)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::RenderPipelineDescriptor::vertexFunctionDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexFunctionDescriptor)); +} + +_MTL_INLINE MTL4::StaticLinkingDescriptor* MTL4::RenderPipelineDescriptor::vertexStaticLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexStaticLinkingDescriptor)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4SpecializedFunctionDescriptor.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4SpecializedFunctionDescriptor.hpp new file mode 100644 index 00000000..57c0094d --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4SpecializedFunctionDescriptor.hpp @@ -0,0 +1,100 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4SpecializedFunctionDescriptor.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4FunctionDescriptor.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class FunctionDescriptor; +class SpecializedFunctionDescriptor; +} + +namespace MTL +{ +class FunctionConstantValues; +} + +namespace MTL4 +{ +class SpecializedFunctionDescriptor : public NS::Copying +{ +public: + static SpecializedFunctionDescriptor* alloc(); + + MTL::FunctionConstantValues* constantValues() const; + + FunctionDescriptor* functionDescriptor() const; + + SpecializedFunctionDescriptor* init(); + + void setConstantValues(const MTL::FunctionConstantValues* constantValues); + + void setFunctionDescriptor(const MTL4::FunctionDescriptor* functionDescriptor); + + void setSpecializedName(const NS::String* specializedName); + NS::String* specializedName() const; +}; + +} +_MTL_INLINE MTL4::SpecializedFunctionDescriptor* MTL4::SpecializedFunctionDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4SpecializedFunctionDescriptor)); +} + +_MTL_INLINE MTL::FunctionConstantValues* MTL4::SpecializedFunctionDescriptor::constantValues() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(constantValues)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::SpecializedFunctionDescriptor::functionDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionDescriptor)); +} + +_MTL_INLINE MTL4::SpecializedFunctionDescriptor* MTL4::SpecializedFunctionDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL4::SpecializedFunctionDescriptor::setConstantValues(const MTL::FunctionConstantValues* constantValues) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setConstantValues_), constantValues); +} + +_MTL_INLINE void MTL4::SpecializedFunctionDescriptor::setFunctionDescriptor(const MTL4::FunctionDescriptor* functionDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctionDescriptor_), functionDescriptor); +} + +_MTL_INLINE void MTL4::SpecializedFunctionDescriptor::setSpecializedName(const NS::String* specializedName) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSpecializedName_), specializedName); +} + +_MTL_INLINE NS::String* MTL4::SpecializedFunctionDescriptor::specializedName() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(specializedName)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4StitchedFunctionDescriptor.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4StitchedFunctionDescriptor.hpp new file mode 100644 index 00000000..ca8ea5cf --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4StitchedFunctionDescriptor.hpp @@ -0,0 +1,86 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4StitchedFunctionDescriptor.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4FunctionDescriptor.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class StitchedFunctionDescriptor; +} + +namespace MTL +{ +class FunctionStitchingGraph; +} + +namespace MTL4 +{ +class StitchedFunctionDescriptor : public NS::Copying +{ +public: + static StitchedFunctionDescriptor* alloc(); + + NS::Array* functionDescriptors() const; + + MTL::FunctionStitchingGraph* functionGraph() const; + + StitchedFunctionDescriptor* init(); + + void setFunctionDescriptors(const NS::Array* functionDescriptors); + + void setFunctionGraph(const MTL::FunctionStitchingGraph* functionGraph); +}; + +} +_MTL_INLINE MTL4::StitchedFunctionDescriptor* MTL4::StitchedFunctionDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4StitchedFunctionDescriptor)); +} + +_MTL_INLINE NS::Array* MTL4::StitchedFunctionDescriptor::functionDescriptors() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionDescriptors)); +} + +_MTL_INLINE MTL::FunctionStitchingGraph* MTL4::StitchedFunctionDescriptor::functionGraph() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionGraph)); +} + +_MTL_INLINE MTL4::StitchedFunctionDescriptor* MTL4::StitchedFunctionDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL4::StitchedFunctionDescriptor::setFunctionDescriptors(const NS::Array* functionDescriptors) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctionDescriptors_), functionDescriptors); +} + +_MTL_INLINE void MTL4::StitchedFunctionDescriptor::setFunctionGraph(const MTL::FunctionStitchingGraph* functionGraph) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctionGraph_), functionGraph); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTL4TileRenderPipeline.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTL4TileRenderPipeline.hpp new file mode 100644 index 00000000..dc74f484 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTL4TileRenderPipeline.hpp @@ -0,0 +1,173 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4TileRenderPipeline.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4PipelineState.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" + +namespace MTL4 +{ +class FunctionDescriptor; +class StaticLinkingDescriptor; +class TileRenderPipelineDescriptor; +} + +namespace MTL +{ +class TileRenderPipelineColorAttachmentDescriptorArray; +} + +namespace MTL4 +{ +class TileRenderPipelineDescriptor : public NS::Copying +{ +public: + static TileRenderPipelineDescriptor* alloc(); + + MTL::TileRenderPipelineColorAttachmentDescriptorArray* colorAttachments() const; + + TileRenderPipelineDescriptor* init(); + + NS::UInteger maxTotalThreadsPerThreadgroup() const; + + NS::UInteger rasterSampleCount() const; + + MTL::Size requiredThreadsPerThreadgroup() const; + + void reset(); + + void setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup); + + void setRasterSampleCount(NS::UInteger rasterSampleCount); + + void setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup); + + void setStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* staticLinkingDescriptor); + + void setSupportBinaryLinking(bool supportBinaryLinking); + + void setThreadgroupSizeMatchesTileSize(bool threadgroupSizeMatchesTileSize); + + void setTileFunctionDescriptor(const MTL4::FunctionDescriptor* tileFunctionDescriptor); + + StaticLinkingDescriptor* staticLinkingDescriptor() const; + + bool supportBinaryLinking() const; + + bool threadgroupSizeMatchesTileSize() const; + + FunctionDescriptor* tileFunctionDescriptor() const; +}; + +} +_MTL_INLINE MTL4::TileRenderPipelineDescriptor* MTL4::TileRenderPipelineDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4TileRenderPipelineDescriptor)); +} + +_MTL_INLINE MTL::TileRenderPipelineColorAttachmentDescriptorArray* MTL4::TileRenderPipelineDescriptor::colorAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(colorAttachments)); +} + +_MTL_INLINE MTL4::TileRenderPipelineDescriptor* MTL4::TileRenderPipelineDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL4::TileRenderPipelineDescriptor::maxTotalThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerThreadgroup)); +} + +_MTL_INLINE NS::UInteger MTL4::TileRenderPipelineDescriptor::rasterSampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rasterSampleCount)); +} + +_MTL_INLINE MTL::Size MTL4::TileRenderPipelineDescriptor::requiredThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerThreadgroup)); +} + +_MTL_INLINE void MTL4::TileRenderPipelineDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL4::TileRenderPipelineDescriptor::setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadsPerThreadgroup_), maxTotalThreadsPerThreadgroup); +} + +_MTL_INLINE void MTL4::TileRenderPipelineDescriptor::setRasterSampleCount(NS::UInteger rasterSampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterSampleCount_), rasterSampleCount); +} + +_MTL_INLINE void MTL4::TileRenderPipelineDescriptor::setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRequiredThreadsPerThreadgroup_), requiredThreadsPerThreadgroup); +} + +_MTL_INLINE void MTL4::TileRenderPipelineDescriptor::setStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* staticLinkingDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStaticLinkingDescriptor_), staticLinkingDescriptor); +} + +_MTL_INLINE void MTL4::TileRenderPipelineDescriptor::setSupportBinaryLinking(bool supportBinaryLinking) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportBinaryLinking_), supportBinaryLinking); +} + +_MTL_INLINE void MTL4::TileRenderPipelineDescriptor::setThreadgroupSizeMatchesTileSize(bool threadgroupSizeMatchesTileSize) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadgroupSizeMatchesTileSize_), threadgroupSizeMatchesTileSize); +} + +_MTL_INLINE void MTL4::TileRenderPipelineDescriptor::setTileFunctionDescriptor(const MTL4::FunctionDescriptor* tileFunctionDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileFunctionDescriptor_), tileFunctionDescriptor); +} + +_MTL_INLINE MTL4::StaticLinkingDescriptor* MTL4::TileRenderPipelineDescriptor::staticLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(staticLinkingDescriptor)); +} + +_MTL_INLINE bool MTL4::TileRenderPipelineDescriptor::supportBinaryLinking() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportBinaryLinking)); +} + +_MTL_INLINE bool MTL4::TileRenderPipelineDescriptor::threadgroupSizeMatchesTileSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadgroupSizeMatchesTileSize)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::TileRenderPipelineDescriptor::tileFunctionDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileFunctionDescriptor)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLAccelerationStructure.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLAccelerationStructure.hpp new file mode 100644 index 00000000..d3457c39 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLAccelerationStructure.hpp @@ -0,0 +1,1887 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLAccelerationStructure.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLAccelerationStructureTypes.hpp" +#include "MTLArgument.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" +#include "MTLStageInputOutputDescriptor.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL +{ +class AccelerationStructureBoundingBoxGeometryDescriptor; +class AccelerationStructureCurveGeometryDescriptor; +class AccelerationStructureDescriptor; +class AccelerationStructureGeometryDescriptor; +class AccelerationStructureMotionBoundingBoxGeometryDescriptor; +class AccelerationStructureMotionCurveGeometryDescriptor; +class AccelerationStructureMotionTriangleGeometryDescriptor; +class AccelerationStructureTriangleGeometryDescriptor; +class Buffer; +class IndirectInstanceAccelerationStructureDescriptor; +class InstanceAccelerationStructureDescriptor; +class MotionKeyframeData; +class PrimitiveAccelerationStructureDescriptor; +} + +namespace MTL +{ +_MTL_ENUM(NS::Integer, MatrixLayout) { + MatrixLayoutColumnMajor = 0, + MatrixLayoutRowMajor = 1, +}; + +_MTL_ENUM(uint32_t, MotionBorderMode) { + MotionBorderModeClamp = 0, + MotionBorderModeVanish = 1, +}; + +_MTL_ENUM(NS::Integer, CurveType) { + CurveTypeRound = 0, + CurveTypeFlat = 1, +}; + +_MTL_ENUM(NS::Integer, CurveBasis) { + CurveBasisBSpline = 0, + CurveBasisCatmullRom = 1, + CurveBasisLinear = 2, + CurveBasisBezier = 3, +}; + +_MTL_ENUM(NS::Integer, CurveEndCaps) { + CurveEndCapsNone = 0, + CurveEndCapsDisk = 1, + CurveEndCapsSphere = 2, +}; + +_MTL_ENUM(NS::UInteger, AccelerationStructureInstanceDescriptorType) { + AccelerationStructureInstanceDescriptorTypeDefault = 0, + AccelerationStructureInstanceDescriptorTypeUserID = 1, + AccelerationStructureInstanceDescriptorTypeMotion = 2, + AccelerationStructureInstanceDescriptorTypeIndirect = 3, + AccelerationStructureInstanceDescriptorTypeIndirectMotion = 4, +}; + +_MTL_ENUM(NS::Integer, TransformType) { + TransformTypePackedFloat4x3 = 0, + TransformTypeComponent = 1, +}; + +_MTL_OPTIONS(NS::UInteger, AccelerationStructureRefitOptions) { + AccelerationStructureRefitOptionVertexData = 1, + AccelerationStructureRefitOptionPerPrimitiveData = 1 << 1, +}; + +_MTL_OPTIONS(NS::UInteger, AccelerationStructureUsage) { + AccelerationStructureUsageNone = 0, + AccelerationStructureUsageRefit = 1, + AccelerationStructureUsagePreferFastBuild = 1 << 1, + AccelerationStructureUsageExtendedLimits = 1 << 2, + AccelerationStructureUsagePreferFastIntersection = 1 << 4, + AccelerationStructureUsageMinimizeMemory = 1 << 5, +}; + +_MTL_OPTIONS(uint32_t, AccelerationStructureInstanceOptions) { + AccelerationStructureInstanceOptionNone = 0, + AccelerationStructureInstanceOptionDisableTriangleCulling = 1, + AccelerationStructureInstanceOptionTriangleFrontFacingWindingCounterClockwise = 1 << 1, + AccelerationStructureInstanceOptionOpaque = 1 << 2, + AccelerationStructureInstanceOptionNonOpaque = 1 << 3, +}; + +struct AccelerationStructureInstanceDescriptor +{ + MTL::PackedFloat4x3 transformationMatrix; + MTL::AccelerationStructureInstanceOptions options; + uint32_t mask; + uint32_t intersectionFunctionTableOffset; + uint32_t accelerationStructureIndex; +} _MTL_PACKED; + +struct AccelerationStructureUserIDInstanceDescriptor +{ + MTL::PackedFloat4x3 transformationMatrix; + MTL::AccelerationStructureInstanceOptions options; + uint32_t mask; + uint32_t intersectionFunctionTableOffset; + uint32_t accelerationStructureIndex; + uint32_t userID; +} _MTL_PACKED; + +struct AccelerationStructureMotionInstanceDescriptor +{ + MTL::AccelerationStructureInstanceOptions options; + uint32_t mask; + uint32_t intersectionFunctionTableOffset; + uint32_t accelerationStructureIndex; + uint32_t userID; + uint32_t motionTransformsStartIndex; + uint32_t motionTransformsCount; + MTL::MotionBorderMode motionStartBorderMode; + MTL::MotionBorderMode motionEndBorderMode; + float motionStartTime; + float motionEndTime; +} _MTL_PACKED; + +struct IndirectAccelerationStructureInstanceDescriptor +{ + MTL::PackedFloat4x3 transformationMatrix; + MTL::AccelerationStructureInstanceOptions options; + uint32_t mask; + uint32_t intersectionFunctionTableOffset; + uint32_t userID; + MTL::ResourceID accelerationStructureID; +} _MTL_PACKED; + +struct IndirectAccelerationStructureMotionInstanceDescriptor +{ + MTL::AccelerationStructureInstanceOptions options; + uint32_t mask; + uint32_t intersectionFunctionTableOffset; + uint32_t userID; + MTL::ResourceID accelerationStructureID; + uint32_t motionTransformsStartIndex; + uint32_t motionTransformsCount; + MTL::MotionBorderMode motionStartBorderMode; + MTL::MotionBorderMode motionEndBorderMode; + float motionStartTime; + float motionEndTime; +} _MTL_PACKED; + +class AccelerationStructureDescriptor : public NS::Copying +{ +public: + static AccelerationStructureDescriptor* alloc(); + + AccelerationStructureDescriptor* init(); + + void setUsage(MTL::AccelerationStructureUsage usage); + AccelerationStructureUsage usage() const; +}; +class AccelerationStructureGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureGeometryDescriptor* alloc(); + + bool allowDuplicateIntersectionFunctionInvocation() const; + + AccelerationStructureGeometryDescriptor* init(); + + NS::UInteger intersectionFunctionTableOffset() const; + + NS::String* label() const; + + bool opaque() const; + + Buffer* primitiveDataBuffer() const; + NS::UInteger primitiveDataBufferOffset() const; + + NS::UInteger primitiveDataElementSize() const; + + NS::UInteger primitiveDataStride() const; + + void setAllowDuplicateIntersectionFunctionInvocation(bool allowDuplicateIntersectionFunctionInvocation); + + void setIntersectionFunctionTableOffset(NS::UInteger intersectionFunctionTableOffset); + + void setLabel(const NS::String* label); + + void setOpaque(bool opaque); + + void setPrimitiveDataBuffer(const MTL::Buffer* primitiveDataBuffer); + void setPrimitiveDataBufferOffset(NS::UInteger primitiveDataBufferOffset); + + void setPrimitiveDataElementSize(NS::UInteger primitiveDataElementSize); + + void setPrimitiveDataStride(NS::UInteger primitiveDataStride); +}; +class PrimitiveAccelerationStructureDescriptor : public NS::Copying +{ +public: + static PrimitiveAccelerationStructureDescriptor* alloc(); + + static PrimitiveAccelerationStructureDescriptor* descriptor(); + NS::Array* geometryDescriptors() const; + + PrimitiveAccelerationStructureDescriptor* init(); + + MotionBorderMode motionEndBorderMode() const; + + float motionEndTime() const; + + NS::UInteger motionKeyframeCount() const; + + MotionBorderMode motionStartBorderMode() const; + + float motionStartTime() const; + + void setGeometryDescriptors(const NS::Array* geometryDescriptors); + + void setMotionEndBorderMode(MTL::MotionBorderMode motionEndBorderMode); + + void setMotionEndTime(float motionEndTime); + + void setMotionKeyframeCount(NS::UInteger motionKeyframeCount); + + void setMotionStartBorderMode(MTL::MotionBorderMode motionStartBorderMode); + + void setMotionStartTime(float motionStartTime); +}; +class AccelerationStructureTriangleGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureTriangleGeometryDescriptor* alloc(); + + static AccelerationStructureTriangleGeometryDescriptor* descriptor(); + + Buffer* indexBuffer() const; + NS::UInteger indexBufferOffset() const; + + IndexType indexType() const; + + AccelerationStructureTriangleGeometryDescriptor* init(); + + void setIndexBuffer(const MTL::Buffer* indexBuffer); + void setIndexBufferOffset(NS::UInteger indexBufferOffset); + + void setIndexType(MTL::IndexType indexType); + + void setTransformationMatrixBuffer(const MTL::Buffer* transformationMatrixBuffer); + void setTransformationMatrixBufferOffset(NS::UInteger transformationMatrixBufferOffset); + + void setTransformationMatrixLayout(MTL::MatrixLayout transformationMatrixLayout); + + void setTriangleCount(NS::UInteger triangleCount); + + void setVertexBuffer(const MTL::Buffer* vertexBuffer); + void setVertexBufferOffset(NS::UInteger vertexBufferOffset); + + void setVertexFormat(MTL::AttributeFormat vertexFormat); + + void setVertexStride(NS::UInteger vertexStride); + + Buffer* transformationMatrixBuffer() const; + NS::UInteger transformationMatrixBufferOffset() const; + + MatrixLayout transformationMatrixLayout() const; + + NS::UInteger triangleCount() const; + + Buffer* vertexBuffer() const; + NS::UInteger vertexBufferOffset() const; + + AttributeFormat vertexFormat() const; + + NS::UInteger vertexStride() const; +}; +class AccelerationStructureBoundingBoxGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureBoundingBoxGeometryDescriptor* alloc(); + + Buffer* boundingBoxBuffer() const; + NS::UInteger boundingBoxBufferOffset() const; + + NS::UInteger boundingBoxCount() const; + + NS::UInteger boundingBoxStride() const; + + static AccelerationStructureBoundingBoxGeometryDescriptor* descriptor(); + + AccelerationStructureBoundingBoxGeometryDescriptor* init(); + + void setBoundingBoxBuffer(const MTL::Buffer* boundingBoxBuffer); + void setBoundingBoxBufferOffset(NS::UInteger boundingBoxBufferOffset); + + void setBoundingBoxCount(NS::UInteger boundingBoxCount); + + void setBoundingBoxStride(NS::UInteger boundingBoxStride); +}; +class MotionKeyframeData : public NS::Referencing +{ +public: + static MotionKeyframeData* alloc(); + + Buffer* buffer() const; + + static MotionKeyframeData* data(); + + MotionKeyframeData* init(); + + NS::UInteger offset() const; + + void setBuffer(const MTL::Buffer* buffer); + + void setOffset(NS::UInteger offset); +}; +class AccelerationStructureMotionTriangleGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureMotionTriangleGeometryDescriptor* alloc(); + + static AccelerationStructureMotionTriangleGeometryDescriptor* descriptor(); + + Buffer* indexBuffer() const; + NS::UInteger indexBufferOffset() const; + + IndexType indexType() const; + + AccelerationStructureMotionTriangleGeometryDescriptor* init(); + + void setIndexBuffer(const MTL::Buffer* indexBuffer); + void setIndexBufferOffset(NS::UInteger indexBufferOffset); + + void setIndexType(MTL::IndexType indexType); + + void setTransformationMatrixBuffer(const MTL::Buffer* transformationMatrixBuffer); + void setTransformationMatrixBufferOffset(NS::UInteger transformationMatrixBufferOffset); + + void setTransformationMatrixLayout(MTL::MatrixLayout transformationMatrixLayout); + + void setTriangleCount(NS::UInteger triangleCount); + + void setVertexBuffers(const NS::Array* vertexBuffers); + + void setVertexFormat(MTL::AttributeFormat vertexFormat); + + void setVertexStride(NS::UInteger vertexStride); + + Buffer* transformationMatrixBuffer() const; + NS::UInteger transformationMatrixBufferOffset() const; + + MatrixLayout transformationMatrixLayout() const; + + NS::UInteger triangleCount() const; + + NS::Array* vertexBuffers() const; + + AttributeFormat vertexFormat() const; + + NS::UInteger vertexStride() const; +}; +class AccelerationStructureMotionBoundingBoxGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureMotionBoundingBoxGeometryDescriptor* alloc(); + + NS::Array* boundingBoxBuffers() const; + + NS::UInteger boundingBoxCount() const; + + NS::UInteger boundingBoxStride() const; + + static AccelerationStructureMotionBoundingBoxGeometryDescriptor* descriptor(); + + AccelerationStructureMotionBoundingBoxGeometryDescriptor* init(); + + void setBoundingBoxBuffers(const NS::Array* boundingBoxBuffers); + + void setBoundingBoxCount(NS::UInteger boundingBoxCount); + + void setBoundingBoxStride(NS::UInteger boundingBoxStride); +}; +class AccelerationStructureCurveGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureCurveGeometryDescriptor* alloc(); + + Buffer* controlPointBuffer() const; + NS::UInteger controlPointBufferOffset() const; + + NS::UInteger controlPointCount() const; + + AttributeFormat controlPointFormat() const; + + NS::UInteger controlPointStride() const; + + CurveBasis curveBasis() const; + + CurveEndCaps curveEndCaps() const; + + CurveType curveType() const; + + static AccelerationStructureCurveGeometryDescriptor* descriptor(); + + Buffer* indexBuffer() const; + NS::UInteger indexBufferOffset() const; + + IndexType indexType() const; + + AccelerationStructureCurveGeometryDescriptor* init(); + + Buffer* radiusBuffer() const; + NS::UInteger radiusBufferOffset() const; + + AttributeFormat radiusFormat() const; + + NS::UInteger radiusStride() const; + + NS::UInteger segmentControlPointCount() const; + + NS::UInteger segmentCount() const; + + void setControlPointBuffer(const MTL::Buffer* controlPointBuffer); + void setControlPointBufferOffset(NS::UInteger controlPointBufferOffset); + + void setControlPointCount(NS::UInteger controlPointCount); + + void setControlPointFormat(MTL::AttributeFormat controlPointFormat); + + void setControlPointStride(NS::UInteger controlPointStride); + + void setCurveBasis(MTL::CurveBasis curveBasis); + + void setCurveEndCaps(MTL::CurveEndCaps curveEndCaps); + + void setCurveType(MTL::CurveType curveType); + + void setIndexBuffer(const MTL::Buffer* indexBuffer); + void setIndexBufferOffset(NS::UInteger indexBufferOffset); + + void setIndexType(MTL::IndexType indexType); + + void setRadiusBuffer(const MTL::Buffer* radiusBuffer); + void setRadiusBufferOffset(NS::UInteger radiusBufferOffset); + + void setRadiusFormat(MTL::AttributeFormat radiusFormat); + + void setRadiusStride(NS::UInteger radiusStride); + + void setSegmentControlPointCount(NS::UInteger segmentControlPointCount); + + void setSegmentCount(NS::UInteger segmentCount); +}; +class AccelerationStructureMotionCurveGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureMotionCurveGeometryDescriptor* alloc(); + + NS::Array* controlPointBuffers() const; + + NS::UInteger controlPointCount() const; + + AttributeFormat controlPointFormat() const; + + NS::UInteger controlPointStride() const; + + CurveBasis curveBasis() const; + + CurveEndCaps curveEndCaps() const; + + CurveType curveType() const; + + static AccelerationStructureMotionCurveGeometryDescriptor* descriptor(); + + Buffer* indexBuffer() const; + NS::UInteger indexBufferOffset() const; + + IndexType indexType() const; + + AccelerationStructureMotionCurveGeometryDescriptor* init(); + + NS::Array* radiusBuffers() const; + + AttributeFormat radiusFormat() const; + + NS::UInteger radiusStride() const; + + NS::UInteger segmentControlPointCount() const; + + NS::UInteger segmentCount() const; + + void setControlPointBuffers(const NS::Array* controlPointBuffers); + + void setControlPointCount(NS::UInteger controlPointCount); + + void setControlPointFormat(MTL::AttributeFormat controlPointFormat); + + void setControlPointStride(NS::UInteger controlPointStride); + + void setCurveBasis(MTL::CurveBasis curveBasis); + + void setCurveEndCaps(MTL::CurveEndCaps curveEndCaps); + + void setCurveType(MTL::CurveType curveType); + + void setIndexBuffer(const MTL::Buffer* indexBuffer); + void setIndexBufferOffset(NS::UInteger indexBufferOffset); + + void setIndexType(MTL::IndexType indexType); + + void setRadiusBuffers(const NS::Array* radiusBuffers); + + void setRadiusFormat(MTL::AttributeFormat radiusFormat); + + void setRadiusStride(NS::UInteger radiusStride); + + void setSegmentControlPointCount(NS::UInteger segmentControlPointCount); + + void setSegmentCount(NS::UInteger segmentCount); +}; +class InstanceAccelerationStructureDescriptor : public NS::Copying +{ +public: + static InstanceAccelerationStructureDescriptor* alloc(); + + static InstanceAccelerationStructureDescriptor* descriptor(); + + InstanceAccelerationStructureDescriptor* init(); + + NS::UInteger instanceCount() const; + + Buffer* instanceDescriptorBuffer() const; + NS::UInteger instanceDescriptorBufferOffset() const; + + NS::UInteger instanceDescriptorStride() const; + + AccelerationStructureInstanceDescriptorType instanceDescriptorType() const; + + MatrixLayout instanceTransformationMatrixLayout() const; + + NS::Array* instancedAccelerationStructures() const; + + Buffer* motionTransformBuffer() const; + NS::UInteger motionTransformBufferOffset() const; + + NS::UInteger motionTransformCount() const; + + NS::UInteger motionTransformStride() const; + + TransformType motionTransformType() const; + + void setInstanceCount(NS::UInteger instanceCount); + + void setInstanceDescriptorBuffer(const MTL::Buffer* instanceDescriptorBuffer); + void setInstanceDescriptorBufferOffset(NS::UInteger instanceDescriptorBufferOffset); + + void setInstanceDescriptorStride(NS::UInteger instanceDescriptorStride); + + void setInstanceDescriptorType(MTL::AccelerationStructureInstanceDescriptorType instanceDescriptorType); + + void setInstanceTransformationMatrixLayout(MTL::MatrixLayout instanceTransformationMatrixLayout); + + void setInstancedAccelerationStructures(const NS::Array* instancedAccelerationStructures); + + void setMotionTransformBuffer(const MTL::Buffer* motionTransformBuffer); + void setMotionTransformBufferOffset(NS::UInteger motionTransformBufferOffset); + + void setMotionTransformCount(NS::UInteger motionTransformCount); + + void setMotionTransformStride(NS::UInteger motionTransformStride); + + void setMotionTransformType(MTL::TransformType motionTransformType); +}; +class IndirectInstanceAccelerationStructureDescriptor : public NS::Copying +{ +public: + static IndirectInstanceAccelerationStructureDescriptor* alloc(); + + static IndirectInstanceAccelerationStructureDescriptor* descriptor(); + + IndirectInstanceAccelerationStructureDescriptor* init(); + + Buffer* instanceCountBuffer() const; + NS::UInteger instanceCountBufferOffset() const; + + Buffer* instanceDescriptorBuffer() const; + NS::UInteger instanceDescriptorBufferOffset() const; + + NS::UInteger instanceDescriptorStride() const; + + AccelerationStructureInstanceDescriptorType instanceDescriptorType() const; + + MatrixLayout instanceTransformationMatrixLayout() const; + + NS::UInteger maxInstanceCount() const; + + NS::UInteger maxMotionTransformCount() const; + + Buffer* motionTransformBuffer() const; + NS::UInteger motionTransformBufferOffset() const; + + Buffer* motionTransformCountBuffer() const; + NS::UInteger motionTransformCountBufferOffset() const; + + NS::UInteger motionTransformStride() const; + + TransformType motionTransformType() const; + + void setInstanceCountBuffer(const MTL::Buffer* instanceCountBuffer); + void setInstanceCountBufferOffset(NS::UInteger instanceCountBufferOffset); + + void setInstanceDescriptorBuffer(const MTL::Buffer* instanceDescriptorBuffer); + void setInstanceDescriptorBufferOffset(NS::UInteger instanceDescriptorBufferOffset); + + void setInstanceDescriptorStride(NS::UInteger instanceDescriptorStride); + + void setInstanceDescriptorType(MTL::AccelerationStructureInstanceDescriptorType instanceDescriptorType); + + void setInstanceTransformationMatrixLayout(MTL::MatrixLayout instanceTransformationMatrixLayout); + + void setMaxInstanceCount(NS::UInteger maxInstanceCount); + + void setMaxMotionTransformCount(NS::UInteger maxMotionTransformCount); + + void setMotionTransformBuffer(const MTL::Buffer* motionTransformBuffer); + void setMotionTransformBufferOffset(NS::UInteger motionTransformBufferOffset); + + void setMotionTransformCountBuffer(const MTL::Buffer* motionTransformCountBuffer); + void setMotionTransformCountBufferOffset(NS::UInteger motionTransformCountBufferOffset); + + void setMotionTransformStride(NS::UInteger motionTransformStride); + + void setMotionTransformType(MTL::TransformType motionTransformType); +}; +class AccelerationStructure : public NS::Referencing +{ +public: + ResourceID gpuResourceID() const; + + NS::UInteger size() const; +}; + +} + +_MTL_INLINE MTL::AccelerationStructureDescriptor* MTL::AccelerationStructureDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructureDescriptor)); +} + +_MTL_INLINE MTL::AccelerationStructureDescriptor* MTL::AccelerationStructureDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::AccelerationStructureDescriptor::setUsage(MTL::AccelerationStructureUsage usage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setUsage_), usage); +} + +_MTL_INLINE MTL::AccelerationStructureUsage MTL::AccelerationStructureDescriptor::usage() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(usage)); +} + +_MTL_INLINE MTL::AccelerationStructureGeometryDescriptor* MTL::AccelerationStructureGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructureGeometryDescriptor)); +} + +_MTL_INLINE bool MTL::AccelerationStructureGeometryDescriptor::allowDuplicateIntersectionFunctionInvocation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allowDuplicateIntersectionFunctionInvocation)); +} + +_MTL_INLINE MTL::AccelerationStructureGeometryDescriptor* MTL::AccelerationStructureGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureGeometryDescriptor::intersectionFunctionTableOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(intersectionFunctionTableOffset)); +} + +_MTL_INLINE NS::String* MTL::AccelerationStructureGeometryDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE bool MTL::AccelerationStructureGeometryDescriptor::opaque() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(opaque)); +} + +_MTL_INLINE MTL::Buffer* MTL::AccelerationStructureGeometryDescriptor::primitiveDataBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(primitiveDataBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureGeometryDescriptor::primitiveDataBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(primitiveDataBufferOffset)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureGeometryDescriptor::primitiveDataElementSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(primitiveDataElementSize)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureGeometryDescriptor::primitiveDataStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(primitiveDataStride)); +} + +_MTL_INLINE void MTL::AccelerationStructureGeometryDescriptor::setAllowDuplicateIntersectionFunctionInvocation(bool allowDuplicateIntersectionFunctionInvocation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAllowDuplicateIntersectionFunctionInvocation_), allowDuplicateIntersectionFunctionInvocation); +} + +_MTL_INLINE void MTL::AccelerationStructureGeometryDescriptor::setIntersectionFunctionTableOffset(NS::UInteger intersectionFunctionTableOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIntersectionFunctionTableOffset_), intersectionFunctionTableOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureGeometryDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::AccelerationStructureGeometryDescriptor::setOpaque(bool opaque) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOpaque_), opaque); +} + +_MTL_INLINE void MTL::AccelerationStructureGeometryDescriptor::setPrimitiveDataBuffer(const MTL::Buffer* primitiveDataBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPrimitiveDataBuffer_), primitiveDataBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructureGeometryDescriptor::setPrimitiveDataBufferOffset(NS::UInteger primitiveDataBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPrimitiveDataBufferOffset_), primitiveDataBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureGeometryDescriptor::setPrimitiveDataElementSize(NS::UInteger primitiveDataElementSize) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPrimitiveDataElementSize_), primitiveDataElementSize); +} + +_MTL_INLINE void MTL::AccelerationStructureGeometryDescriptor::setPrimitiveDataStride(NS::UInteger primitiveDataStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPrimitiveDataStride_), primitiveDataStride); +} + +_MTL_INLINE MTL::PrimitiveAccelerationStructureDescriptor* MTL::PrimitiveAccelerationStructureDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLPrimitiveAccelerationStructureDescriptor)); +} + +_MTL_INLINE MTL::PrimitiveAccelerationStructureDescriptor* MTL::PrimitiveAccelerationStructureDescriptor::descriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLPrimitiveAccelerationStructureDescriptor), _MTL_PRIVATE_SEL(descriptor)); +} + +_MTL_INLINE NS::Array* MTL::PrimitiveAccelerationStructureDescriptor::geometryDescriptors() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(geometryDescriptors)); +} + +_MTL_INLINE MTL::PrimitiveAccelerationStructureDescriptor* MTL::PrimitiveAccelerationStructureDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::MotionBorderMode MTL::PrimitiveAccelerationStructureDescriptor::motionEndBorderMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionEndBorderMode)); +} + +_MTL_INLINE float MTL::PrimitiveAccelerationStructureDescriptor::motionEndTime() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionEndTime)); +} + +_MTL_INLINE NS::UInteger MTL::PrimitiveAccelerationStructureDescriptor::motionKeyframeCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionKeyframeCount)); +} + +_MTL_INLINE MTL::MotionBorderMode MTL::PrimitiveAccelerationStructureDescriptor::motionStartBorderMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionStartBorderMode)); +} + +_MTL_INLINE float MTL::PrimitiveAccelerationStructureDescriptor::motionStartTime() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionStartTime)); +} + +_MTL_INLINE void MTL::PrimitiveAccelerationStructureDescriptor::setGeometryDescriptors(const NS::Array* geometryDescriptors) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setGeometryDescriptors_), geometryDescriptors); +} + +_MTL_INLINE void MTL::PrimitiveAccelerationStructureDescriptor::setMotionEndBorderMode(MTL::MotionBorderMode motionEndBorderMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionEndBorderMode_), motionEndBorderMode); +} + +_MTL_INLINE void MTL::PrimitiveAccelerationStructureDescriptor::setMotionEndTime(float motionEndTime) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionEndTime_), motionEndTime); +} + +_MTL_INLINE void MTL::PrimitiveAccelerationStructureDescriptor::setMotionKeyframeCount(NS::UInteger motionKeyframeCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionKeyframeCount_), motionKeyframeCount); +} + +_MTL_INLINE void MTL::PrimitiveAccelerationStructureDescriptor::setMotionStartBorderMode(MTL::MotionBorderMode motionStartBorderMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionStartBorderMode_), motionStartBorderMode); +} + +_MTL_INLINE void MTL::PrimitiveAccelerationStructureDescriptor::setMotionStartTime(float motionStartTime) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionStartTime_), motionStartTime); +} + +_MTL_INLINE MTL::AccelerationStructureTriangleGeometryDescriptor* MTL::AccelerationStructureTriangleGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructureTriangleGeometryDescriptor)); +} + +_MTL_INLINE MTL::AccelerationStructureTriangleGeometryDescriptor* MTL::AccelerationStructureTriangleGeometryDescriptor::descriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLAccelerationStructureTriangleGeometryDescriptor), _MTL_PRIVATE_SEL(descriptor)); +} + +_MTL_INLINE MTL::Buffer* MTL::AccelerationStructureTriangleGeometryDescriptor::indexBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureTriangleGeometryDescriptor::indexBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBufferOffset)); +} + +_MTL_INLINE MTL::IndexType MTL::AccelerationStructureTriangleGeometryDescriptor::indexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE MTL::AccelerationStructureTriangleGeometryDescriptor* MTL::AccelerationStructureTriangleGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setIndexBuffer(const MTL::Buffer* indexBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBuffer_), indexBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setIndexBufferOffset(NS::UInteger indexBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBufferOffset_), indexBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setIndexType(MTL::IndexType indexType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexType_), indexType); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setTransformationMatrixBuffer(const MTL::Buffer* transformationMatrixBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixBuffer_), transformationMatrixBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setTransformationMatrixBufferOffset(NS::UInteger transformationMatrixBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixBufferOffset_), transformationMatrixBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setTransformationMatrixLayout(MTL::MatrixLayout transformationMatrixLayout) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixLayout_), transformationMatrixLayout); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setTriangleCount(NS::UInteger triangleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTriangleCount_), triangleCount); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setVertexBuffer(const MTL::Buffer* vertexBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBuffer_), vertexBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setVertexBufferOffset(NS::UInteger vertexBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBufferOffset_), vertexBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setVertexFormat(MTL::AttributeFormat vertexFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexFormat_), vertexFormat); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setVertexStride(NS::UInteger vertexStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexStride_), vertexStride); +} + +_MTL_INLINE MTL::Buffer* MTL::AccelerationStructureTriangleGeometryDescriptor::transformationMatrixBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureTriangleGeometryDescriptor::transformationMatrixBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixBufferOffset)); +} + +_MTL_INLINE MTL::MatrixLayout MTL::AccelerationStructureTriangleGeometryDescriptor::transformationMatrixLayout() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixLayout)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureTriangleGeometryDescriptor::triangleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(triangleCount)); +} + +_MTL_INLINE MTL::Buffer* MTL::AccelerationStructureTriangleGeometryDescriptor::vertexBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureTriangleGeometryDescriptor::vertexBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexBufferOffset)); +} + +_MTL_INLINE MTL::AttributeFormat MTL::AccelerationStructureTriangleGeometryDescriptor::vertexFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexFormat)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureTriangleGeometryDescriptor::vertexStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexStride)); +} + +_MTL_INLINE MTL::AccelerationStructureBoundingBoxGeometryDescriptor* MTL::AccelerationStructureBoundingBoxGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructureBoundingBoxGeometryDescriptor)); +} + +_MTL_INLINE MTL::Buffer* MTL::AccelerationStructureBoundingBoxGeometryDescriptor::boundingBoxBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureBoundingBoxGeometryDescriptor::boundingBoxBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxBufferOffset)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureBoundingBoxGeometryDescriptor::boundingBoxCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxCount)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureBoundingBoxGeometryDescriptor::boundingBoxStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxStride)); +} + +_MTL_INLINE MTL::AccelerationStructureBoundingBoxGeometryDescriptor* MTL::AccelerationStructureBoundingBoxGeometryDescriptor::descriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLAccelerationStructureBoundingBoxGeometryDescriptor), _MTL_PRIVATE_SEL(descriptor)); +} + +_MTL_INLINE MTL::AccelerationStructureBoundingBoxGeometryDescriptor* MTL::AccelerationStructureBoundingBoxGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::AccelerationStructureBoundingBoxGeometryDescriptor::setBoundingBoxBuffer(const MTL::Buffer* boundingBoxBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxBuffer_), boundingBoxBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructureBoundingBoxGeometryDescriptor::setBoundingBoxBufferOffset(NS::UInteger boundingBoxBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxBufferOffset_), boundingBoxBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureBoundingBoxGeometryDescriptor::setBoundingBoxCount(NS::UInteger boundingBoxCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxCount_), boundingBoxCount); +} + +_MTL_INLINE void MTL::AccelerationStructureBoundingBoxGeometryDescriptor::setBoundingBoxStride(NS::UInteger boundingBoxStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxStride_), boundingBoxStride); +} + +_MTL_INLINE MTL::MotionKeyframeData* MTL::MotionKeyframeData::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLMotionKeyframeData)); +} + +_MTL_INLINE MTL::Buffer* MTL::MotionKeyframeData::buffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(buffer)); +} + +_MTL_INLINE MTL::MotionKeyframeData* MTL::MotionKeyframeData::data() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLMotionKeyframeData), _MTL_PRIVATE_SEL(data)); +} + +_MTL_INLINE MTL::MotionKeyframeData* MTL::MotionKeyframeData::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL::MotionKeyframeData::offset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(offset)); +} + +_MTL_INLINE void MTL::MotionKeyframeData::setBuffer(const MTL::Buffer* buffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBuffer_), buffer); +} + +_MTL_INLINE void MTL::MotionKeyframeData::setOffset(NS::UInteger offset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOffset_), offset); +} + +_MTL_INLINE MTL::AccelerationStructureMotionTriangleGeometryDescriptor* MTL::AccelerationStructureMotionTriangleGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructureMotionTriangleGeometryDescriptor)); +} + +_MTL_INLINE MTL::AccelerationStructureMotionTriangleGeometryDescriptor* MTL::AccelerationStructureMotionTriangleGeometryDescriptor::descriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLAccelerationStructureMotionTriangleGeometryDescriptor), _MTL_PRIVATE_SEL(descriptor)); +} + +_MTL_INLINE MTL::Buffer* MTL::AccelerationStructureMotionTriangleGeometryDescriptor::indexBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionTriangleGeometryDescriptor::indexBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBufferOffset)); +} + +_MTL_INLINE MTL::IndexType MTL::AccelerationStructureMotionTriangleGeometryDescriptor::indexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE MTL::AccelerationStructureMotionTriangleGeometryDescriptor* MTL::AccelerationStructureMotionTriangleGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::setIndexBuffer(const MTL::Buffer* indexBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBuffer_), indexBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::setIndexBufferOffset(NS::UInteger indexBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBufferOffset_), indexBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::setIndexType(MTL::IndexType indexType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexType_), indexType); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::setTransformationMatrixBuffer(const MTL::Buffer* transformationMatrixBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixBuffer_), transformationMatrixBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::setTransformationMatrixBufferOffset(NS::UInteger transformationMatrixBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixBufferOffset_), transformationMatrixBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::setTransformationMatrixLayout(MTL::MatrixLayout transformationMatrixLayout) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixLayout_), transformationMatrixLayout); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::setTriangleCount(NS::UInteger triangleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTriangleCount_), triangleCount); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::setVertexBuffers(const NS::Array* vertexBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBuffers_), vertexBuffers); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::setVertexFormat(MTL::AttributeFormat vertexFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexFormat_), vertexFormat); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::setVertexStride(NS::UInteger vertexStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexStride_), vertexStride); +} + +_MTL_INLINE MTL::Buffer* MTL::AccelerationStructureMotionTriangleGeometryDescriptor::transformationMatrixBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionTriangleGeometryDescriptor::transformationMatrixBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixBufferOffset)); +} + +_MTL_INLINE MTL::MatrixLayout MTL::AccelerationStructureMotionTriangleGeometryDescriptor::transformationMatrixLayout() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixLayout)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionTriangleGeometryDescriptor::triangleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(triangleCount)); +} + +_MTL_INLINE NS::Array* MTL::AccelerationStructureMotionTriangleGeometryDescriptor::vertexBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexBuffers)); +} + +_MTL_INLINE MTL::AttributeFormat MTL::AccelerationStructureMotionTriangleGeometryDescriptor::vertexFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexFormat)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionTriangleGeometryDescriptor::vertexStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexStride)); +} + +_MTL_INLINE MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor* MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructureMotionBoundingBoxGeometryDescriptor)); +} + +_MTL_INLINE NS::Array* MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor::boundingBoxBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxBuffers)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor::boundingBoxCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxCount)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor::boundingBoxStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxStride)); +} + +_MTL_INLINE MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor* MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor::descriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLAccelerationStructureMotionBoundingBoxGeometryDescriptor), _MTL_PRIVATE_SEL(descriptor)); +} + +_MTL_INLINE MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor* MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor::setBoundingBoxBuffers(const NS::Array* boundingBoxBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxBuffers_), boundingBoxBuffers); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor::setBoundingBoxCount(NS::UInteger boundingBoxCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxCount_), boundingBoxCount); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor::setBoundingBoxStride(NS::UInteger boundingBoxStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxStride_), boundingBoxStride); +} + +_MTL_INLINE MTL::AccelerationStructureCurveGeometryDescriptor* MTL::AccelerationStructureCurveGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructureCurveGeometryDescriptor)); +} + +_MTL_INLINE MTL::Buffer* MTL::AccelerationStructureCurveGeometryDescriptor::controlPointBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureCurveGeometryDescriptor::controlPointBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointBufferOffset)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureCurveGeometryDescriptor::controlPointCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointCount)); +} + +_MTL_INLINE MTL::AttributeFormat MTL::AccelerationStructureCurveGeometryDescriptor::controlPointFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointFormat)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureCurveGeometryDescriptor::controlPointStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointStride)); +} + +_MTL_INLINE MTL::CurveBasis MTL::AccelerationStructureCurveGeometryDescriptor::curveBasis() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveBasis)); +} + +_MTL_INLINE MTL::CurveEndCaps MTL::AccelerationStructureCurveGeometryDescriptor::curveEndCaps() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveEndCaps)); +} + +_MTL_INLINE MTL::CurveType MTL::AccelerationStructureCurveGeometryDescriptor::curveType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveType)); +} + +_MTL_INLINE MTL::AccelerationStructureCurveGeometryDescriptor* MTL::AccelerationStructureCurveGeometryDescriptor::descriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLAccelerationStructureCurveGeometryDescriptor), _MTL_PRIVATE_SEL(descriptor)); +} + +_MTL_INLINE MTL::Buffer* MTL::AccelerationStructureCurveGeometryDescriptor::indexBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureCurveGeometryDescriptor::indexBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBufferOffset)); +} + +_MTL_INLINE MTL::IndexType MTL::AccelerationStructureCurveGeometryDescriptor::indexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE MTL::AccelerationStructureCurveGeometryDescriptor* MTL::AccelerationStructureCurveGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::Buffer* MTL::AccelerationStructureCurveGeometryDescriptor::radiusBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureCurveGeometryDescriptor::radiusBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusBufferOffset)); +} + +_MTL_INLINE MTL::AttributeFormat MTL::AccelerationStructureCurveGeometryDescriptor::radiusFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusFormat)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureCurveGeometryDescriptor::radiusStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusStride)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureCurveGeometryDescriptor::segmentControlPointCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(segmentControlPointCount)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureCurveGeometryDescriptor::segmentCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(segmentCount)); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setControlPointBuffer(const MTL::Buffer* controlPointBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointBuffer_), controlPointBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setControlPointBufferOffset(NS::UInteger controlPointBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointBufferOffset_), controlPointBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setControlPointCount(NS::UInteger controlPointCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointCount_), controlPointCount); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setControlPointFormat(MTL::AttributeFormat controlPointFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointFormat_), controlPointFormat); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setControlPointStride(NS::UInteger controlPointStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointStride_), controlPointStride); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setCurveBasis(MTL::CurveBasis curveBasis) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveBasis_), curveBasis); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setCurveEndCaps(MTL::CurveEndCaps curveEndCaps) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveEndCaps_), curveEndCaps); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setCurveType(MTL::CurveType curveType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveType_), curveType); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setIndexBuffer(const MTL::Buffer* indexBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBuffer_), indexBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setIndexBufferOffset(NS::UInteger indexBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBufferOffset_), indexBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setIndexType(MTL::IndexType indexType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexType_), indexType); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setRadiusBuffer(const MTL::Buffer* radiusBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusBuffer_), radiusBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setRadiusBufferOffset(NS::UInteger radiusBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusBufferOffset_), radiusBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setRadiusFormat(MTL::AttributeFormat radiusFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusFormat_), radiusFormat); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setRadiusStride(NS::UInteger radiusStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusStride_), radiusStride); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setSegmentControlPointCount(NS::UInteger segmentControlPointCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSegmentControlPointCount_), segmentControlPointCount); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setSegmentCount(NS::UInteger segmentCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSegmentCount_), segmentCount); +} + +_MTL_INLINE MTL::AccelerationStructureMotionCurveGeometryDescriptor* MTL::AccelerationStructureMotionCurveGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructureMotionCurveGeometryDescriptor)); +} + +_MTL_INLINE NS::Array* MTL::AccelerationStructureMotionCurveGeometryDescriptor::controlPointBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointBuffers)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionCurveGeometryDescriptor::controlPointCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointCount)); +} + +_MTL_INLINE MTL::AttributeFormat MTL::AccelerationStructureMotionCurveGeometryDescriptor::controlPointFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointFormat)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionCurveGeometryDescriptor::controlPointStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointStride)); +} + +_MTL_INLINE MTL::CurveBasis MTL::AccelerationStructureMotionCurveGeometryDescriptor::curveBasis() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveBasis)); +} + +_MTL_INLINE MTL::CurveEndCaps MTL::AccelerationStructureMotionCurveGeometryDescriptor::curveEndCaps() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveEndCaps)); +} + +_MTL_INLINE MTL::CurveType MTL::AccelerationStructureMotionCurveGeometryDescriptor::curveType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveType)); +} + +_MTL_INLINE MTL::AccelerationStructureMotionCurveGeometryDescriptor* MTL::AccelerationStructureMotionCurveGeometryDescriptor::descriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLAccelerationStructureMotionCurveGeometryDescriptor), _MTL_PRIVATE_SEL(descriptor)); +} + +_MTL_INLINE MTL::Buffer* MTL::AccelerationStructureMotionCurveGeometryDescriptor::indexBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionCurveGeometryDescriptor::indexBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBufferOffset)); +} + +_MTL_INLINE MTL::IndexType MTL::AccelerationStructureMotionCurveGeometryDescriptor::indexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE MTL::AccelerationStructureMotionCurveGeometryDescriptor* MTL::AccelerationStructureMotionCurveGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::Array* MTL::AccelerationStructureMotionCurveGeometryDescriptor::radiusBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusBuffers)); +} + +_MTL_INLINE MTL::AttributeFormat MTL::AccelerationStructureMotionCurveGeometryDescriptor::radiusFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusFormat)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionCurveGeometryDescriptor::radiusStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusStride)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionCurveGeometryDescriptor::segmentControlPointCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(segmentControlPointCount)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionCurveGeometryDescriptor::segmentCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(segmentCount)); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setControlPointBuffers(const NS::Array* controlPointBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointBuffers_), controlPointBuffers); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setControlPointCount(NS::UInteger controlPointCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointCount_), controlPointCount); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setControlPointFormat(MTL::AttributeFormat controlPointFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointFormat_), controlPointFormat); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setControlPointStride(NS::UInteger controlPointStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointStride_), controlPointStride); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setCurveBasis(MTL::CurveBasis curveBasis) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveBasis_), curveBasis); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setCurveEndCaps(MTL::CurveEndCaps curveEndCaps) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveEndCaps_), curveEndCaps); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setCurveType(MTL::CurveType curveType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveType_), curveType); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setIndexBuffer(const MTL::Buffer* indexBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBuffer_), indexBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setIndexBufferOffset(NS::UInteger indexBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBufferOffset_), indexBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setIndexType(MTL::IndexType indexType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexType_), indexType); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setRadiusBuffers(const NS::Array* radiusBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusBuffers_), radiusBuffers); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setRadiusFormat(MTL::AttributeFormat radiusFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusFormat_), radiusFormat); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setRadiusStride(NS::UInteger radiusStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusStride_), radiusStride); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setSegmentControlPointCount(NS::UInteger segmentControlPointCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSegmentControlPointCount_), segmentControlPointCount); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setSegmentCount(NS::UInteger segmentCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSegmentCount_), segmentCount); +} + +_MTL_INLINE MTL::InstanceAccelerationStructureDescriptor* MTL::InstanceAccelerationStructureDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLInstanceAccelerationStructureDescriptor)); +} + +_MTL_INLINE MTL::InstanceAccelerationStructureDescriptor* MTL::InstanceAccelerationStructureDescriptor::descriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLInstanceAccelerationStructureDescriptor), _MTL_PRIVATE_SEL(descriptor)); +} + +_MTL_INLINE MTL::InstanceAccelerationStructureDescriptor* MTL::InstanceAccelerationStructureDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL::InstanceAccelerationStructureDescriptor::instanceCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceCount)); +} + +_MTL_INLINE MTL::Buffer* MTL::InstanceAccelerationStructureDescriptor::instanceDescriptorBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::InstanceAccelerationStructureDescriptor::instanceDescriptorBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorBufferOffset)); +} + +_MTL_INLINE NS::UInteger MTL::InstanceAccelerationStructureDescriptor::instanceDescriptorStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorStride)); +} + +_MTL_INLINE MTL::AccelerationStructureInstanceDescriptorType MTL::InstanceAccelerationStructureDescriptor::instanceDescriptorType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorType)); +} + +_MTL_INLINE MTL::MatrixLayout MTL::InstanceAccelerationStructureDescriptor::instanceTransformationMatrixLayout() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceTransformationMatrixLayout)); +} + +_MTL_INLINE NS::Array* MTL::InstanceAccelerationStructureDescriptor::instancedAccelerationStructures() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instancedAccelerationStructures)); +} + +_MTL_INLINE MTL::Buffer* MTL::InstanceAccelerationStructureDescriptor::motionTransformBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::InstanceAccelerationStructureDescriptor::motionTransformBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformBufferOffset)); +} + +_MTL_INLINE NS::UInteger MTL::InstanceAccelerationStructureDescriptor::motionTransformCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformCount)); +} + +_MTL_INLINE NS::UInteger MTL::InstanceAccelerationStructureDescriptor::motionTransformStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformStride)); +} + +_MTL_INLINE MTL::TransformType MTL::InstanceAccelerationStructureDescriptor::motionTransformType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformType)); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setInstanceCount(NS::UInteger instanceCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceCount_), instanceCount); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setInstanceDescriptorBuffer(const MTL::Buffer* instanceDescriptorBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorBuffer_), instanceDescriptorBuffer); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setInstanceDescriptorBufferOffset(NS::UInteger instanceDescriptorBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorBufferOffset_), instanceDescriptorBufferOffset); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setInstanceDescriptorStride(NS::UInteger instanceDescriptorStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorStride_), instanceDescriptorStride); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setInstanceDescriptorType(MTL::AccelerationStructureInstanceDescriptorType instanceDescriptorType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorType_), instanceDescriptorType); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setInstanceTransformationMatrixLayout(MTL::MatrixLayout instanceTransformationMatrixLayout) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceTransformationMatrixLayout_), instanceTransformationMatrixLayout); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setInstancedAccelerationStructures(const NS::Array* instancedAccelerationStructures) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstancedAccelerationStructures_), instancedAccelerationStructures); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setMotionTransformBuffer(const MTL::Buffer* motionTransformBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformBuffer_), motionTransformBuffer); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setMotionTransformBufferOffset(NS::UInteger motionTransformBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformBufferOffset_), motionTransformBufferOffset); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setMotionTransformCount(NS::UInteger motionTransformCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformCount_), motionTransformCount); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setMotionTransformStride(NS::UInteger motionTransformStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformStride_), motionTransformStride); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setMotionTransformType(MTL::TransformType motionTransformType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformType_), motionTransformType); +} + +_MTL_INLINE MTL::IndirectInstanceAccelerationStructureDescriptor* MTL::IndirectInstanceAccelerationStructureDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLIndirectInstanceAccelerationStructureDescriptor)); +} + +_MTL_INLINE MTL::IndirectInstanceAccelerationStructureDescriptor* MTL::IndirectInstanceAccelerationStructureDescriptor::descriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLIndirectInstanceAccelerationStructureDescriptor), _MTL_PRIVATE_SEL(descriptor)); +} + +_MTL_INLINE MTL::IndirectInstanceAccelerationStructureDescriptor* MTL::IndirectInstanceAccelerationStructureDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::Buffer* MTL::IndirectInstanceAccelerationStructureDescriptor::instanceCountBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceCountBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectInstanceAccelerationStructureDescriptor::instanceCountBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceCountBufferOffset)); +} + +_MTL_INLINE MTL::Buffer* MTL::IndirectInstanceAccelerationStructureDescriptor::instanceDescriptorBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectInstanceAccelerationStructureDescriptor::instanceDescriptorBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorBufferOffset)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectInstanceAccelerationStructureDescriptor::instanceDescriptorStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorStride)); +} + +_MTL_INLINE MTL::AccelerationStructureInstanceDescriptorType MTL::IndirectInstanceAccelerationStructureDescriptor::instanceDescriptorType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorType)); +} + +_MTL_INLINE MTL::MatrixLayout MTL::IndirectInstanceAccelerationStructureDescriptor::instanceTransformationMatrixLayout() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceTransformationMatrixLayout)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectInstanceAccelerationStructureDescriptor::maxInstanceCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxInstanceCount)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectInstanceAccelerationStructureDescriptor::maxMotionTransformCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxMotionTransformCount)); +} + +_MTL_INLINE MTL::Buffer* MTL::IndirectInstanceAccelerationStructureDescriptor::motionTransformBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectInstanceAccelerationStructureDescriptor::motionTransformBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformBufferOffset)); +} + +_MTL_INLINE MTL::Buffer* MTL::IndirectInstanceAccelerationStructureDescriptor::motionTransformCountBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformCountBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectInstanceAccelerationStructureDescriptor::motionTransformCountBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformCountBufferOffset)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectInstanceAccelerationStructureDescriptor::motionTransformStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformStride)); +} + +_MTL_INLINE MTL::TransformType MTL::IndirectInstanceAccelerationStructureDescriptor::motionTransformType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformType)); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setInstanceCountBuffer(const MTL::Buffer* instanceCountBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceCountBuffer_), instanceCountBuffer); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setInstanceCountBufferOffset(NS::UInteger instanceCountBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceCountBufferOffset_), instanceCountBufferOffset); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setInstanceDescriptorBuffer(const MTL::Buffer* instanceDescriptorBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorBuffer_), instanceDescriptorBuffer); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setInstanceDescriptorBufferOffset(NS::UInteger instanceDescriptorBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorBufferOffset_), instanceDescriptorBufferOffset); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setInstanceDescriptorStride(NS::UInteger instanceDescriptorStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorStride_), instanceDescriptorStride); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setInstanceDescriptorType(MTL::AccelerationStructureInstanceDescriptorType instanceDescriptorType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorType_), instanceDescriptorType); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setInstanceTransformationMatrixLayout(MTL::MatrixLayout instanceTransformationMatrixLayout) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceTransformationMatrixLayout_), instanceTransformationMatrixLayout); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setMaxInstanceCount(NS::UInteger maxInstanceCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxInstanceCount_), maxInstanceCount); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setMaxMotionTransformCount(NS::UInteger maxMotionTransformCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxMotionTransformCount_), maxMotionTransformCount); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformBuffer(const MTL::Buffer* motionTransformBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformBuffer_), motionTransformBuffer); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformBufferOffset(NS::UInteger motionTransformBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformBufferOffset_), motionTransformBufferOffset); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformCountBuffer(const MTL::Buffer* motionTransformCountBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformCountBuffer_), motionTransformCountBuffer); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformCountBufferOffset(NS::UInteger motionTransformCountBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformCountBufferOffset_), motionTransformCountBufferOffset); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformStride(NS::UInteger motionTransformStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformStride_), motionTransformStride); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformType(MTL::TransformType motionTransformType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformType_), motionTransformType); +} + +_MTL_INLINE MTL::ResourceID MTL::AccelerationStructure::gpuResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuResourceID)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructure::size() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(size)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLAccelerationStructureCommandEncoder.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLAccelerationStructureCommandEncoder.hpp new file mode 100644 index 00000000..5f82344a --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLAccelerationStructureCommandEncoder.hpp @@ -0,0 +1,260 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLAccelerationStructureCommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLAccelerationStructure.hpp" +#include "MTLCommandEncoder.hpp" +#include "MTLDataType.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class AccelerationStructure; +class AccelerationStructureDescriptor; +class AccelerationStructurePassDescriptor; +class AccelerationStructurePassSampleBufferAttachmentDescriptor; +class AccelerationStructurePassSampleBufferAttachmentDescriptorArray; +class Buffer; +class CounterSampleBuffer; +class Fence; +class Heap; +class Resource; + +class AccelerationStructureCommandEncoder : public NS::Referencing +{ +public: + void buildAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, const MTL::AccelerationStructureDescriptor* descriptor, const MTL::Buffer* scratchBuffer, NS::UInteger scratchBufferOffset); + + void copyAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructure* destinationAccelerationStructure); + + void copyAndCompactAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructure* destinationAccelerationStructure); + + void refitAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructureDescriptor* descriptor, const MTL::AccelerationStructure* destinationAccelerationStructure, const MTL::Buffer* scratchBuffer, NS::UInteger scratchBufferOffset); + void refitAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructureDescriptor* descriptor, const MTL::AccelerationStructure* destinationAccelerationStructure, const MTL::Buffer* scratchBuffer, NS::UInteger scratchBufferOffset, MTL::AccelerationStructureRefitOptions options); + + void sampleCountersInBuffer(const MTL::CounterSampleBuffer* sampleBuffer, NS::UInteger sampleIndex, bool barrier); + + void updateFence(const MTL::Fence* fence); + + void useHeap(const MTL::Heap* heap); + void useHeaps(const MTL::Heap* const heaps[], NS::UInteger count); + + void useResource(const MTL::Resource* resource, MTL::ResourceUsage usage); + void useResources(const MTL::Resource* const resources[], NS::UInteger count, MTL::ResourceUsage usage); + + void waitForFence(const MTL::Fence* fence); + + void writeCompactedAccelerationStructureSize(const MTL::AccelerationStructure* accelerationStructure, const MTL::Buffer* buffer, NS::UInteger offset); + void writeCompactedAccelerationStructureSize(const MTL::AccelerationStructure* accelerationStructure, const MTL::Buffer* buffer, NS::UInteger offset, MTL::DataType sizeDataType); +}; +class AccelerationStructurePassSampleBufferAttachmentDescriptor : public NS::Copying +{ +public: + static AccelerationStructurePassSampleBufferAttachmentDescriptor* alloc(); + + NS::UInteger endOfEncoderSampleIndex() const; + + AccelerationStructurePassSampleBufferAttachmentDescriptor* init(); + + CounterSampleBuffer* sampleBuffer() const; + + void setEndOfEncoderSampleIndex(NS::UInteger endOfEncoderSampleIndex); + + void setSampleBuffer(const MTL::CounterSampleBuffer* sampleBuffer); + + void setStartOfEncoderSampleIndex(NS::UInteger startOfEncoderSampleIndex); + NS::UInteger startOfEncoderSampleIndex() const; +}; +class AccelerationStructurePassSampleBufferAttachmentDescriptorArray : public NS::Referencing +{ +public: + static AccelerationStructurePassSampleBufferAttachmentDescriptorArray* alloc(); + + AccelerationStructurePassSampleBufferAttachmentDescriptorArray* init(); + + AccelerationStructurePassSampleBufferAttachmentDescriptor* object(NS::UInteger attachmentIndex); + void setObject(const MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor* attachment, NS::UInteger attachmentIndex); +}; +class AccelerationStructurePassDescriptor : public NS::Copying +{ +public: + static AccelerationStructurePassDescriptor* accelerationStructurePassDescriptor(); + + static AccelerationStructurePassDescriptor* alloc(); + + AccelerationStructurePassDescriptor* init(); + + AccelerationStructurePassSampleBufferAttachmentDescriptorArray* sampleBufferAttachments() const; +}; + +} +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::buildAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, const MTL::AccelerationStructureDescriptor* descriptor, const MTL::Buffer* scratchBuffer, NS::UInteger scratchBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(buildAccelerationStructure_descriptor_scratchBuffer_scratchBufferOffset_), accelerationStructure, descriptor, scratchBuffer, scratchBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::copyAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructure* destinationAccelerationStructure) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyAccelerationStructure_toAccelerationStructure_), sourceAccelerationStructure, destinationAccelerationStructure); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::copyAndCompactAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructure* destinationAccelerationStructure) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyAndCompactAccelerationStructure_toAccelerationStructure_), sourceAccelerationStructure, destinationAccelerationStructure); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::refitAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructureDescriptor* descriptor, const MTL::AccelerationStructure* destinationAccelerationStructure, const MTL::Buffer* scratchBuffer, NS::UInteger scratchBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(refitAccelerationStructure_descriptor_destination_scratchBuffer_scratchBufferOffset_), sourceAccelerationStructure, descriptor, destinationAccelerationStructure, scratchBuffer, scratchBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::refitAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructureDescriptor* descriptor, const MTL::AccelerationStructure* destinationAccelerationStructure, const MTL::Buffer* scratchBuffer, NS::UInteger scratchBufferOffset, MTL::AccelerationStructureRefitOptions options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(refitAccelerationStructure_descriptor_destination_scratchBuffer_scratchBufferOffset_options_), sourceAccelerationStructure, descriptor, destinationAccelerationStructure, scratchBuffer, scratchBufferOffset, options); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::sampleCountersInBuffer(const MTL::CounterSampleBuffer* sampleBuffer, NS::UInteger sampleIndex, bool barrier) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleCountersInBuffer_atSampleIndex_withBarrier_), sampleBuffer, sampleIndex, barrier); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::updateFence(const MTL::Fence* fence) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateFence_), fence); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::useHeap(const MTL::Heap* heap) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useHeap_), heap); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::useHeaps(const MTL::Heap* const heaps[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useHeaps_count_), heaps, count); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::useResource(const MTL::Resource* resource, MTL::ResourceUsage usage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResource_usage_), resource, usage); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::useResources(const MTL::Resource* const resources[], NS::UInteger count, MTL::ResourceUsage usage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResources_count_usage_), resources, count, usage); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::waitForFence(const MTL::Fence* fence) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitForFence_), fence); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::writeCompactedAccelerationStructureSize(const MTL::AccelerationStructure* accelerationStructure, const MTL::Buffer* buffer, NS::UInteger offset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(writeCompactedAccelerationStructureSize_toBuffer_offset_), accelerationStructure, buffer, offset); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::writeCompactedAccelerationStructureSize(const MTL::AccelerationStructure* accelerationStructure, const MTL::Buffer* buffer, NS::UInteger offset, MTL::DataType sizeDataType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(writeCompactedAccelerationStructureSize_toBuffer_offset_sizeDataType_), accelerationStructure, buffer, offset, sizeDataType); +} + +_MTL_INLINE MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor* MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructurePassSampleBufferAttachmentDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor::endOfEncoderSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(endOfEncoderSampleIndex)); +} + +_MTL_INLINE MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor* MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::CounterSampleBuffer* MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor::sampleBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleBuffer)); +} + +_MTL_INLINE void MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor::setEndOfEncoderSampleIndex(NS::UInteger endOfEncoderSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setEndOfEncoderSampleIndex_), endOfEncoderSampleIndex); +} + +_MTL_INLINE void MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor::setSampleBuffer(const MTL::CounterSampleBuffer* sampleBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSampleBuffer_), sampleBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor::setStartOfEncoderSampleIndex(NS::UInteger startOfEncoderSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStartOfEncoderSampleIndex_), startOfEncoderSampleIndex); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor::startOfEncoderSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(startOfEncoderSampleIndex)); +} + +_MTL_INLINE MTL::AccelerationStructurePassSampleBufferAttachmentDescriptorArray* MTL::AccelerationStructurePassSampleBufferAttachmentDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructurePassSampleBufferAttachmentDescriptorArray)); +} + +_MTL_INLINE MTL::AccelerationStructurePassSampleBufferAttachmentDescriptorArray* MTL::AccelerationStructurePassSampleBufferAttachmentDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor* MTL::AccelerationStructurePassSampleBufferAttachmentDescriptorArray::object(NS::UInteger attachmentIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), attachmentIndex); +} + +_MTL_INLINE void MTL::AccelerationStructurePassSampleBufferAttachmentDescriptorArray::setObject(const MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor* attachment, NS::UInteger attachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), attachment, attachmentIndex); +} + +_MTL_INLINE MTL::AccelerationStructurePassDescriptor* MTL::AccelerationStructurePassDescriptor::accelerationStructurePassDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLAccelerationStructurePassDescriptor), _MTL_PRIVATE_SEL(accelerationStructurePassDescriptor)); +} + +_MTL_INLINE MTL::AccelerationStructurePassDescriptor* MTL::AccelerationStructurePassDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructurePassDescriptor)); +} + +_MTL_INLINE MTL::AccelerationStructurePassDescriptor* MTL::AccelerationStructurePassDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::AccelerationStructurePassSampleBufferAttachmentDescriptorArray* MTL::AccelerationStructurePassDescriptor::sampleBufferAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleBufferAttachments)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLAccelerationStructureTypes.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLAccelerationStructureTypes.hpp new file mode 100644 index 00000000..a08b1e96 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLAccelerationStructureTypes.hpp @@ -0,0 +1,292 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLAccelerationStructureTypes.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLDefines.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" +#include "MTLStageInputOutputDescriptor.hpp" + +#include "../Foundation/Foundation.hpp" +#include + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL +{ + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wnested-anon-types" +struct PackedFloat3 +{ + PackedFloat3(); + PackedFloat3(float x, float y, float z); + + float& operator[](int idx); + float operator[](int idx) const; + + union + { + struct + { + float x; + float y; + float z; + }; + + float elements[3]; + }; +} _MTL_PACKED; +#pragma clang diagnostic pop + +struct PackedFloat4x3 +{ + PackedFloat4x3(); + PackedFloat4x3(const PackedFloat3& col0, const PackedFloat3& col1, const PackedFloat3& col2, const PackedFloat3& col3); + + PackedFloat3& operator[](int idx); + const PackedFloat3& operator[](int idx) const; + + PackedFloat3 columns[4]; +} _MTL_PACKED; + +struct AxisAlignedBoundingBox +{ + AxisAlignedBoundingBox(); + AxisAlignedBoundingBox(PackedFloat3 p); + AxisAlignedBoundingBox(PackedFloat3 min, PackedFloat3 max); + + PackedFloat3 min; + PackedFloat3 max; +} _MTL_PACKED; + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wnested-anon-types" +struct PackedFloatQuaternion +{ + PackedFloatQuaternion(); + PackedFloatQuaternion(float x, float y, float z, float w); + + float& operator[](int idx); + const float& operator[](int idx) const; + + union + { + struct + { + float x; + float y; + float z; + float w; + }; + + float elements[4]; + }; + +} _MTL_PACKED; +#pragma clang diagnostic pop + +struct ComponentTransform +{ + PackedFloat3 scale; + PackedFloat3 shear; + PackedFloat3 pivot; + PackedFloatQuaternion rotation; + PackedFloat3 translation; +} _MTL_PACKED; + +} + +namespace MTL4 +{ + +struct BufferRange +{ + BufferRange() = default; + BufferRange(uint64_t bufferAddress); + BufferRange(uint64_t bufferAddress, uint64_t length); + + static MTL4::BufferRange Make(uint64_t bufferAddress, uint64_t length); + + uint64_t bufferAddress; + uint64_t length; +} _MTL_PACKED; + +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::PackedFloat3::PackedFloat3() + : x(0.0f) + , y(0.0f) + , z(0.0f) +{ +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::PackedFloat3::PackedFloat3(float _x, float _y, float _z) + : x(_x) + , y(_y) + , z(_z) +{ +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE float& MTL::PackedFloat3::operator[](int idx) +{ + return elements[idx]; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE float MTL::PackedFloat3::operator[](int idx) const +{ + return elements[idx]; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::PackedFloat4x3::PackedFloat4x3() +{ + columns[0] = PackedFloat3(0.0f, 0.0f, 0.0f); + columns[1] = PackedFloat3(0.0f, 0.0f, 0.0f); + columns[2] = PackedFloat3(0.0f, 0.0f, 0.0f); + columns[3] = PackedFloat3(0.0f, 0.0f, 0.0f); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::PackedFloat4x3::PackedFloat4x3(const PackedFloat3& col0, const PackedFloat3& col1, const PackedFloat3& col2, const PackedFloat3& col3) +{ + columns[0] = col0; + columns[1] = col1; + columns[2] = col2; + columns[3] = col3; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::PackedFloat3& MTL::PackedFloat4x3::operator[](int idx) +{ + return columns[idx]; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE const MTL::PackedFloat3& MTL::PackedFloat4x3::operator[](int idx) const +{ + return columns[idx]; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#if __apple_build_version__ > 16000026 +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wnan-infinity-disabled" +#endif // __apple_build_version__ > 16000026 +_MTL_INLINE MTL::AxisAlignedBoundingBox::AxisAlignedBoundingBox() + : min(INFINITY, INFINITY, INFINITY) + , max(-INFINITY, -INFINITY, -INFINITY) +{ +} +#if __apple_build_version__ > 16000026 +#pragma clang diagnostic pop +#endif // if __apple_build_version__ > 16000026 + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::AxisAlignedBoundingBox::AxisAlignedBoundingBox(PackedFloat3 p) + : min(p) + , max(p) +{ +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::AxisAlignedBoundingBox::AxisAlignedBoundingBox(PackedFloat3 _min, PackedFloat3 _max) + : min(_min) + , max(_max) +{ +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::PackedFloatQuaternion::PackedFloatQuaternion() + : x(0.0f) + , y(0.0f) + , z(0.0f) + , w(0.0f) +{ +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::PackedFloatQuaternion::PackedFloatQuaternion(float x, float y, float z, float w) + : x(x) + , y(y) + , z(z) + , w(w) +{ +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE float& MTL::PackedFloatQuaternion::operator[](int idx) +{ + return elements[idx]; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE const float& MTL::PackedFloatQuaternion::operator[](int idx) const +{ + return elements[idx]; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL4::BufferRange::BufferRange(uint64_t bufferAddress) +: bufferAddress(bufferAddress) +, length(-1) +{ +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL4::BufferRange::BufferRange(uint64_t bufferAddress, uint64_t length) +: bufferAddress(bufferAddress) +, length(length) +{ +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL4::BufferRange MTL4::BufferRange::Make(uint64_t bufferAddress, uint64_t length) +{ + return MTL4::BufferRange(bufferAddress, length); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLAllocation.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLAllocation.hpp new file mode 100644 index 00000000..ba201058 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLAllocation.hpp @@ -0,0 +1,40 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLAllocation.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class Allocation : public NS::Referencing +{ +public: + NS::UInteger allocatedSize() const; +}; + +} +_MTL_INLINE NS::UInteger MTL::Allocation::allocatedSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allocatedSize)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLArgument.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLArgument.hpp new file mode 100644 index 00000000..f91bd917 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLArgument.hpp @@ -0,0 +1,787 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLArgument.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDataType.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTensor.hpp" +#include "MTLTexture.hpp" + +namespace MTL +{ +class Argument; +class ArrayType; +class PointerType; +class StructMember; +class StructType; +class TensorExtents; +class TensorReferenceType; +class TextureReferenceType; +class Type; +_MTL_ENUM(NS::UInteger, IndexType) { + IndexTypeUInt16 = 0, + IndexTypeUInt32 = 1, +}; + +_MTL_ENUM(NS::Integer, BindingType) { + BindingTypeBuffer = 0, + BindingTypeThreadgroupMemory = 1, + BindingTypeTexture = 2, + BindingTypeSampler = 3, + BindingTypeImageblockData = 16, + BindingTypeImageblock = 17, + BindingTypeVisibleFunctionTable = 24, + BindingTypePrimitiveAccelerationStructure = 25, + BindingTypeInstanceAccelerationStructure = 26, + BindingTypeIntersectionFunctionTable = 27, + BindingTypeObjectPayload = 34, + BindingTypeTensor = 37, +}; + +_MTL_ENUM(NS::UInteger, ArgumentType) { + ArgumentTypeBuffer = 0, + ArgumentTypeThreadgroupMemory = 1, + ArgumentTypeTexture = 2, + ArgumentTypeSampler = 3, + ArgumentTypeImageblockData = 16, + ArgumentTypeImageblock = 17, + ArgumentTypeVisibleFunctionTable = 24, + ArgumentTypePrimitiveAccelerationStructure = 25, + ArgumentTypeInstanceAccelerationStructure = 26, + ArgumentTypeIntersectionFunctionTable = 27, +}; + +_MTL_ENUM(NS::UInteger, BindingAccess) { + BindingAccessReadOnly = 0, + BindingAccessReadWrite = 1, + BindingAccessWriteOnly = 2, + ArgumentAccessReadOnly = 0, + ArgumentAccessReadWrite = 1, + ArgumentAccessWriteOnly = 2, +}; + +class Type : public NS::Referencing +{ +public: + static Type* alloc(); + + DataType dataType() const; + + Type* init(); +}; +class StructMember : public NS::Referencing +{ +public: + static StructMember* alloc(); + + NS::UInteger argumentIndex() const; + + ArrayType* arrayType(); + + DataType dataType() const; + + StructMember* init(); + + NS::String* name() const; + + NS::UInteger offset() const; + + PointerType* pointerType(); + + StructType* structType(); + + TensorReferenceType* tensorReferenceType(); + + TextureReferenceType* textureReferenceType(); +}; +class StructType : public NS::Referencing +{ +public: + static StructType* alloc(); + + StructType* init(); + + StructMember* memberByName(const NS::String* name); + + NS::Array* members() const; +}; +class ArrayType : public NS::Referencing +{ +public: + static ArrayType* alloc(); + + NS::UInteger argumentIndexStride() const; + + NS::UInteger arrayLength() const; + + ArrayType* elementArrayType(); + + PointerType* elementPointerType(); + + StructType* elementStructType(); + + TensorReferenceType* elementTensorReferenceType(); + + TextureReferenceType* elementTextureReferenceType(); + + DataType elementType() const; + + ArrayType* init(); + + NS::UInteger stride() const; +}; +class PointerType : public NS::Referencing +{ +public: + BindingAccess access() const; + + NS::UInteger alignment() const; + + static PointerType* alloc(); + + NS::UInteger dataSize() const; + + ArrayType* elementArrayType(); + + bool elementIsArgumentBuffer() const; + + StructType* elementStructType(); + + DataType elementType() const; + + PointerType* init(); +}; +class TextureReferenceType : public NS::Referencing +{ +public: + BindingAccess access() const; + + static TextureReferenceType* alloc(); + + TextureReferenceType* init(); + + bool isDepthTexture() const; + + DataType textureDataType() const; + + TextureType textureType() const; +}; +class TensorReferenceType : public NS::Referencing +{ +public: + BindingAccess access() const; + + static TensorReferenceType* alloc(); + + TensorExtents* dimensions() const; + + DataType indexType() const; + + TensorReferenceType* init(); + + TensorDataType tensorDataType() const; +}; +class Argument : public NS::Referencing +{ +public: + BindingAccess access() const; + + [[deprecated("please use isActive instead")]] + bool active() const; + + static Argument* alloc(); + + NS::UInteger arrayLength() const; + + NS::UInteger bufferAlignment() const; + + NS::UInteger bufferDataSize() const; + + DataType bufferDataType() const; + + PointerType* bufferPointerType() const; + + StructType* bufferStructType() const; + + NS::UInteger index() const; + + Argument* init(); + + bool isActive() const; + + bool isDepthTexture() const; + + NS::String* name() const; + + DataType textureDataType() const; + + TextureType textureType() const; + + NS::UInteger threadgroupMemoryAlignment() const; + + NS::UInteger threadgroupMemoryDataSize() const; + + ArgumentType type() const; +}; +class Binding : public NS::Referencing +{ +public: + BindingAccess access() const; + + [[deprecated("please use isArgument instead")]] + bool argument() const; + + NS::UInteger index() const; + + bool isArgument() const; + + bool isUsed() const; + + NS::String* name() const; + + BindingType type() const; + + [[deprecated("please use isUsed instead")]] + bool used() const; +}; +class BufferBinding : public NS::Referencing +{ +public: + NS::UInteger bufferAlignment() const; + + NS::UInteger bufferDataSize() const; + + DataType bufferDataType() const; + + PointerType* bufferPointerType() const; + + StructType* bufferStructType() const; +}; +class ThreadgroupBinding : public NS::Referencing +{ +public: + NS::UInteger threadgroupMemoryAlignment() const; + + NS::UInteger threadgroupMemoryDataSize() const; +}; +class TextureBinding : public NS::Referencing +{ +public: + NS::UInteger arrayLength() const; + + [[deprecated("please use isDepthTexture instead")]] + bool depthTexture() const; + bool isDepthTexture() const; + + DataType textureDataType() const; + + TextureType textureType() const; +}; +class ObjectPayloadBinding : public NS::Referencing +{ +public: + NS::UInteger objectPayloadAlignment() const; + + NS::UInteger objectPayloadDataSize() const; +}; +class TensorBinding : public NS::Referencing +{ +public: + TensorExtents* dimensions() const; + + DataType indexType() const; + + TensorDataType tensorDataType() const; +}; + +} +_MTL_INLINE MTL::Type* MTL::Type::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLType)); +} + +_MTL_INLINE MTL::DataType MTL::Type::dataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dataType)); +} + +_MTL_INLINE MTL::Type* MTL::Type::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::StructMember* MTL::StructMember::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLStructMember)); +} + +_MTL_INLINE NS::UInteger MTL::StructMember::argumentIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(argumentIndex)); +} + +_MTL_INLINE MTL::ArrayType* MTL::StructMember::arrayType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(arrayType)); +} + +_MTL_INLINE MTL::DataType MTL::StructMember::dataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dataType)); +} + +_MTL_INLINE MTL::StructMember* MTL::StructMember::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::StructMember::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE NS::UInteger MTL::StructMember::offset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(offset)); +} + +_MTL_INLINE MTL::PointerType* MTL::StructMember::pointerType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(pointerType)); +} + +_MTL_INLINE MTL::StructType* MTL::StructMember::structType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(structType)); +} + +_MTL_INLINE MTL::TensorReferenceType* MTL::StructMember::tensorReferenceType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tensorReferenceType)); +} + +_MTL_INLINE MTL::TextureReferenceType* MTL::StructMember::textureReferenceType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(textureReferenceType)); +} + +_MTL_INLINE MTL::StructType* MTL::StructType::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLStructType)); +} + +_MTL_INLINE MTL::StructType* MTL::StructType::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::StructMember* MTL::StructType::memberByName(const NS::String* name) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(memberByName_), name); +} + +_MTL_INLINE NS::Array* MTL::StructType::members() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(members)); +} + +_MTL_INLINE MTL::ArrayType* MTL::ArrayType::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLArrayType)); +} + +_MTL_INLINE NS::UInteger MTL::ArrayType::argumentIndexStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(argumentIndexStride)); +} + +_MTL_INLINE NS::UInteger MTL::ArrayType::arrayLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(arrayLength)); +} + +_MTL_INLINE MTL::ArrayType* MTL::ArrayType::elementArrayType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(elementArrayType)); +} + +_MTL_INLINE MTL::PointerType* MTL::ArrayType::elementPointerType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(elementPointerType)); +} + +_MTL_INLINE MTL::StructType* MTL::ArrayType::elementStructType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(elementStructType)); +} + +_MTL_INLINE MTL::TensorReferenceType* MTL::ArrayType::elementTensorReferenceType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(elementTensorReferenceType)); +} + +_MTL_INLINE MTL::TextureReferenceType* MTL::ArrayType::elementTextureReferenceType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(elementTextureReferenceType)); +} + +_MTL_INLINE MTL::DataType MTL::ArrayType::elementType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(elementType)); +} + +_MTL_INLINE MTL::ArrayType* MTL::ArrayType::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL::ArrayType::stride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stride)); +} + +_MTL_INLINE MTL::BindingAccess MTL::PointerType::access() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(access)); +} + +_MTL_INLINE NS::UInteger MTL::PointerType::alignment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(alignment)); +} + +_MTL_INLINE MTL::PointerType* MTL::PointerType::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLPointerType)); +} + +_MTL_INLINE NS::UInteger MTL::PointerType::dataSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dataSize)); +} + +_MTL_INLINE MTL::ArrayType* MTL::PointerType::elementArrayType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(elementArrayType)); +} + +_MTL_INLINE bool MTL::PointerType::elementIsArgumentBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(elementIsArgumentBuffer)); +} + +_MTL_INLINE MTL::StructType* MTL::PointerType::elementStructType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(elementStructType)); +} + +_MTL_INLINE MTL::DataType MTL::PointerType::elementType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(elementType)); +} + +_MTL_INLINE MTL::PointerType* MTL::PointerType::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::BindingAccess MTL::TextureReferenceType::access() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(access)); +} + +_MTL_INLINE MTL::TextureReferenceType* MTL::TextureReferenceType::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLTextureReferenceType)); +} + +_MTL_INLINE MTL::TextureReferenceType* MTL::TextureReferenceType::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE bool MTL::TextureReferenceType::isDepthTexture() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isDepthTexture)); +} + +_MTL_INLINE MTL::DataType MTL::TextureReferenceType::textureDataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(textureDataType)); +} + +_MTL_INLINE MTL::TextureType MTL::TextureReferenceType::textureType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(textureType)); +} + +_MTL_INLINE MTL::BindingAccess MTL::TensorReferenceType::access() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(access)); +} + +_MTL_INLINE MTL::TensorReferenceType* MTL::TensorReferenceType::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLTensorReferenceType)); +} + +_MTL_INLINE MTL::TensorExtents* MTL::TensorReferenceType::dimensions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dimensions)); +} + +_MTL_INLINE MTL::DataType MTL::TensorReferenceType::indexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE MTL::TensorReferenceType* MTL::TensorReferenceType::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::TensorDataType MTL::TensorReferenceType::tensorDataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tensorDataType)); +} + +_MTL_INLINE MTL::BindingAccess MTL::Argument::access() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(access)); +} + +_MTL_INLINE bool MTL::Argument::active() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isActive)); +} + +_MTL_INLINE MTL::Argument* MTL::Argument::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLArgument)); +} + +_MTL_INLINE NS::UInteger MTL::Argument::arrayLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(arrayLength)); +} + +_MTL_INLINE NS::UInteger MTL::Argument::bufferAlignment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferAlignment)); +} + +_MTL_INLINE NS::UInteger MTL::Argument::bufferDataSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferDataSize)); +} + +_MTL_INLINE MTL::DataType MTL::Argument::bufferDataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferDataType)); +} + +_MTL_INLINE MTL::PointerType* MTL::Argument::bufferPointerType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferPointerType)); +} + +_MTL_INLINE MTL::StructType* MTL::Argument::bufferStructType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferStructType)); +} + +_MTL_INLINE NS::UInteger MTL::Argument::index() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(index)); +} + +_MTL_INLINE MTL::Argument* MTL::Argument::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE bool MTL::Argument::isActive() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isActive)); +} + +_MTL_INLINE bool MTL::Argument::isDepthTexture() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isDepthTexture)); +} + +_MTL_INLINE NS::String* MTL::Argument::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE MTL::DataType MTL::Argument::textureDataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(textureDataType)); +} + +_MTL_INLINE MTL::TextureType MTL::Argument::textureType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(textureType)); +} + +_MTL_INLINE NS::UInteger MTL::Argument::threadgroupMemoryAlignment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadgroupMemoryAlignment)); +} + +_MTL_INLINE NS::UInteger MTL::Argument::threadgroupMemoryDataSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadgroupMemoryDataSize)); +} + +_MTL_INLINE MTL::ArgumentType MTL::Argument::type() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(type)); +} + +_MTL_INLINE MTL::BindingAccess MTL::Binding::access() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(access)); +} + +_MTL_INLINE bool MTL::Binding::argument() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isArgument)); +} + +_MTL_INLINE NS::UInteger MTL::Binding::index() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(index)); +} + +_MTL_INLINE bool MTL::Binding::isArgument() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isArgument)); +} + +_MTL_INLINE bool MTL::Binding::isUsed() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isUsed)); +} + +_MTL_INLINE NS::String* MTL::Binding::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE MTL::BindingType MTL::Binding::type() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(type)); +} + +_MTL_INLINE bool MTL::Binding::used() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isUsed)); +} + +_MTL_INLINE NS::UInteger MTL::BufferBinding::bufferAlignment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferAlignment)); +} + +_MTL_INLINE NS::UInteger MTL::BufferBinding::bufferDataSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferDataSize)); +} + +_MTL_INLINE MTL::DataType MTL::BufferBinding::bufferDataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferDataType)); +} + +_MTL_INLINE MTL::PointerType* MTL::BufferBinding::bufferPointerType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferPointerType)); +} + +_MTL_INLINE MTL::StructType* MTL::BufferBinding::bufferStructType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferStructType)); +} + +_MTL_INLINE NS::UInteger MTL::ThreadgroupBinding::threadgroupMemoryAlignment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadgroupMemoryAlignment)); +} + +_MTL_INLINE NS::UInteger MTL::ThreadgroupBinding::threadgroupMemoryDataSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadgroupMemoryDataSize)); +} + +_MTL_INLINE NS::UInteger MTL::TextureBinding::arrayLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(arrayLength)); +} + +_MTL_INLINE bool MTL::TextureBinding::depthTexture() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isDepthTexture)); +} + +_MTL_INLINE bool MTL::TextureBinding::isDepthTexture() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isDepthTexture)); +} + +_MTL_INLINE MTL::DataType MTL::TextureBinding::textureDataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(textureDataType)); +} + +_MTL_INLINE MTL::TextureType MTL::TextureBinding::textureType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(textureType)); +} + +_MTL_INLINE NS::UInteger MTL::ObjectPayloadBinding::objectPayloadAlignment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectPayloadAlignment)); +} + +_MTL_INLINE NS::UInteger MTL::ObjectPayloadBinding::objectPayloadDataSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectPayloadDataSize)); +} + +_MTL_INLINE MTL::TensorExtents* MTL::TensorBinding::dimensions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dimensions)); +} + +_MTL_INLINE MTL::DataType MTL::TensorBinding::indexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE MTL::TensorDataType MTL::TensorBinding::tensorDataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tensorDataType)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLArgumentEncoder.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLArgumentEncoder.hpp new file mode 100644 index 00000000..83dbbc20 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLArgumentEncoder.hpp @@ -0,0 +1,235 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLArgumentEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLDepthStencil.hpp" + +namespace MTL +{ +class AccelerationStructure; +class ArgumentEncoder; +class Buffer; +class ComputePipelineState; +class Device; +class IndirectCommandBuffer; +class IntersectionFunctionTable; +class RenderPipelineState; +class SamplerState; +class Texture; +class VisibleFunctionTable; + +static const NS::UInteger AttributeStrideStatic = NS::UIntegerMax; + +class ArgumentEncoder : public NS::Referencing +{ +public: + NS::UInteger alignment() const; + + void* constantData(NS::UInteger index); + + Device* device() const; + + NS::UInteger encodedLength() const; + + NS::String* label() const; + + ArgumentEncoder* newArgumentEncoder(NS::UInteger index); + + void setAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, NS::UInteger index); + + void setArgumentBuffer(const MTL::Buffer* argumentBuffer, NS::UInteger offset); + void setArgumentBuffer(const MTL::Buffer* argumentBuffer, NS::UInteger startOffset, NS::UInteger arrayElement); + + void setBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + void setBuffers(const MTL::Buffer* const buffers[], const NS::UInteger offsets[], NS::Range range); + + void setComputePipelineState(const MTL::ComputePipelineState* pipeline, NS::UInteger index); + void setComputePipelineStates(const MTL::ComputePipelineState* const pipelines[], NS::Range range); + + void setDepthStencilState(const MTL::DepthStencilState* depthStencilState, NS::UInteger index); + void setDepthStencilStates(const MTL::DepthStencilState* const depthStencilStates[], NS::Range range); + + void setIndirectCommandBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::UInteger index); + void setIndirectCommandBuffers(const MTL::IndirectCommandBuffer* const buffers[], NS::Range range); + + void setIntersectionFunctionTable(const MTL::IntersectionFunctionTable* intersectionFunctionTable, NS::UInteger index); + void setIntersectionFunctionTables(const MTL::IntersectionFunctionTable* const intersectionFunctionTables[], NS::Range range); + + void setLabel(const NS::String* label); + + void setRenderPipelineState(const MTL::RenderPipelineState* pipeline, NS::UInteger index); + void setRenderPipelineStates(const MTL::RenderPipelineState* const pipelines[], NS::Range range); + + void setSamplerState(const MTL::SamplerState* sampler, NS::UInteger index); + void setSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range); + + void setTexture(const MTL::Texture* texture, NS::UInteger index); + void setTextures(const MTL::Texture* const textures[], NS::Range range); + + void setVisibleFunctionTable(const MTL::VisibleFunctionTable* visibleFunctionTable, NS::UInteger index); + void setVisibleFunctionTables(const MTL::VisibleFunctionTable* const visibleFunctionTables[], NS::Range range); +}; + +} + +_MTL_INLINE NS::UInteger MTL::ArgumentEncoder::alignment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(alignment)); +} + +_MTL_INLINE void* MTL::ArgumentEncoder::constantData(NS::UInteger index) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(constantDataAtIndex_), index); +} + +_MTL_INLINE MTL::Device* MTL::ArgumentEncoder::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::UInteger MTL::ArgumentEncoder::encodedLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(encodedLength)); +} + +_MTL_INLINE NS::String* MTL::ArgumentEncoder::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL::ArgumentEncoder* MTL::ArgumentEncoder::newArgumentEncoder(NS::UInteger index) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newArgumentEncoderForBufferAtIndex_), index); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAccelerationStructure_atIndex_), accelerationStructure, index); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setArgumentBuffer(const MTL::Buffer* argumentBuffer, NS::UInteger offset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setArgumentBuffer_offset_), argumentBuffer, offset); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setArgumentBuffer(const MTL::Buffer* argumentBuffer, NS::UInteger startOffset, NS::UInteger arrayElement) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setArgumentBuffer_startOffset_arrayElement_), argumentBuffer, startOffset, arrayElement); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setBuffers(const MTL::Buffer* const buffers[], const NS::UInteger offsets[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBuffers_offsets_withRange_), buffers, offsets, range); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setComputePipelineState(const MTL::ComputePipelineState* pipeline, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setComputePipelineState_atIndex_), pipeline, index); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setComputePipelineStates(const MTL::ComputePipelineState* const pipelines[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setComputePipelineStates_withRange_), pipelines, range); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setDepthStencilState(const MTL::DepthStencilState* depthStencilState, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStencilState_atIndex_), depthStencilState, index); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setDepthStencilStates(const MTL::DepthStencilState* const depthStencilStates[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStencilStates_withRange_), depthStencilStates, range); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setIndirectCommandBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndirectCommandBuffer_atIndex_), indirectCommandBuffer, index); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setIndirectCommandBuffers(const MTL::IndirectCommandBuffer* const buffers[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndirectCommandBuffers_withRange_), buffers, range); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setIntersectionFunctionTable(const MTL::IntersectionFunctionTable* intersectionFunctionTable, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIntersectionFunctionTable_atIndex_), intersectionFunctionTable, index); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setIntersectionFunctionTables(const MTL::IntersectionFunctionTable* const intersectionFunctionTables[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIntersectionFunctionTables_withRange_), intersectionFunctionTables, range); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setRenderPipelineState(const MTL::RenderPipelineState* pipeline, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRenderPipelineState_atIndex_), pipeline, index); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setRenderPipelineStates(const MTL::RenderPipelineState* const pipelines[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRenderPipelineStates_withRange_), pipelines, range); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setSamplerState(const MTL::SamplerState* sampler, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSamplerState_atIndex_), sampler, index); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSamplerStates_withRange_), samplers, range); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setTexture(const MTL::Texture* texture, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTexture_atIndex_), texture, index); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setTextures(const MTL::Texture* const textures[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTextures_withRange_), textures, range); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setVisibleFunctionTable(const MTL::VisibleFunctionTable* visibleFunctionTable, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibleFunctionTable_atIndex_), visibleFunctionTable, index); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setVisibleFunctionTables(const MTL::VisibleFunctionTable* const visibleFunctionTables[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibleFunctionTables_withRange_), visibleFunctionTables, range); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLBinaryArchive.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLBinaryArchive.hpp new file mode 100644 index 00000000..c3f16895 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLBinaryArchive.hpp @@ -0,0 +1,152 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLBinaryArchive.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class BinaryArchiveDescriptor; +class ComputePipelineDescriptor; +class Device; +class FunctionDescriptor; +class Library; +class MeshRenderPipelineDescriptor; +class RenderPipelineDescriptor; +class StitchedLibraryDescriptor; +class TileRenderPipelineDescriptor; +_MTL_ENUM(NS::UInteger, BinaryArchiveError) { + BinaryArchiveErrorNone = 0, + BinaryArchiveErrorInvalidFile = 1, + BinaryArchiveErrorUnexpectedElement = 2, + BinaryArchiveErrorCompilationFailure = 3, + BinaryArchiveErrorInternalError = 4, +}; + +_MTL_CONST(NS::ErrorDomain, BinaryArchiveDomain); +class BinaryArchiveDescriptor : public NS::Copying +{ +public: + static BinaryArchiveDescriptor* alloc(); + + BinaryArchiveDescriptor* init(); + + void setUrl(const NS::URL* url); + NS::URL* url() const; +}; +class BinaryArchive : public NS::Referencing +{ +public: + bool addComputePipelineFunctions(const MTL::ComputePipelineDescriptor* descriptor, NS::Error** error); + + bool addFunction(const MTL::FunctionDescriptor* descriptor, const MTL::Library* library, NS::Error** error); + + bool addLibrary(const MTL::StitchedLibraryDescriptor* descriptor, NS::Error** error); + + bool addMeshRenderPipelineFunctions(const MTL::MeshRenderPipelineDescriptor* descriptor, NS::Error** error); + + bool addRenderPipelineFunctions(const MTL::RenderPipelineDescriptor* descriptor, NS::Error** error); + + bool addTileRenderPipelineFunctions(const MTL::TileRenderPipelineDescriptor* descriptor, NS::Error** error); + + Device* device() const; + + NS::String* label() const; + + bool serializeToURL(const NS::URL* url, NS::Error** error); + + void setLabel(const NS::String* label); +}; + +} +_MTL_PRIVATE_DEF_CONST(NS::ErrorDomain, BinaryArchiveDomain); +_MTL_INLINE MTL::BinaryArchiveDescriptor* MTL::BinaryArchiveDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLBinaryArchiveDescriptor)); +} + +_MTL_INLINE MTL::BinaryArchiveDescriptor* MTL::BinaryArchiveDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::BinaryArchiveDescriptor::setUrl(const NS::URL* url) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setUrl_), url); +} + +_MTL_INLINE NS::URL* MTL::BinaryArchiveDescriptor::url() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(url)); +} + +_MTL_INLINE bool MTL::BinaryArchive::addComputePipelineFunctions(const MTL::ComputePipelineDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(addComputePipelineFunctionsWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE bool MTL::BinaryArchive::addFunction(const MTL::FunctionDescriptor* descriptor, const MTL::Library* library, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(addFunctionWithDescriptor_library_error_), descriptor, library, error); +} + +_MTL_INLINE bool MTL::BinaryArchive::addLibrary(const MTL::StitchedLibraryDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(addLibraryWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE bool MTL::BinaryArchive::addMeshRenderPipelineFunctions(const MTL::MeshRenderPipelineDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(addMeshRenderPipelineFunctionsWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE bool MTL::BinaryArchive::addRenderPipelineFunctions(const MTL::RenderPipelineDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(addRenderPipelineFunctionsWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE bool MTL::BinaryArchive::addTileRenderPipelineFunctions(const MTL::TileRenderPipelineDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(addTileRenderPipelineFunctionsWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::Device* MTL::BinaryArchive::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL::BinaryArchive::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE bool MTL::BinaryArchive::serializeToURL(const NS::URL* url, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(serializeToURL_error_), url, error); +} + +_MTL_INLINE void MTL::BinaryArchive::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLBlitCommandEncoder.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLBlitCommandEncoder.hpp new file mode 100644 index 00000000..319f05ef --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLBlitCommandEncoder.hpp @@ -0,0 +1,226 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLBlitCommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLCommandEncoder.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL +{ +class Buffer; +class CounterSampleBuffer; +class Fence; +class IndirectCommandBuffer; +class Resource; +class Tensor; +class TensorExtents; +class Texture; + +_MTL_OPTIONS(NS::UInteger, BlitOption) { + BlitOptionNone = 0, + BlitOptionDepthFromDepthStencil = 1, + BlitOptionStencilFromDepthStencil = 1 << 1, + BlitOptionRowLinearPVRTC = 1 << 2, +}; + +class BlitCommandEncoder : public NS::Referencing +{ +public: + void copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, NS::UInteger sourceBytesPerRow, NS::UInteger sourceBytesPerImage, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin); + void copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, NS::UInteger sourceBytesPerRow, NS::UInteger sourceBytesPerImage, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin, MTL::BlitOption options); + void copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger size); + + void copyFromTensor(const MTL::Tensor* sourceTensor, const MTL::TensorExtents* sourceOrigin, const MTL::TensorExtents* sourceDimensions, const MTL::Tensor* destinationTensor, const MTL::TensorExtents* destinationOrigin, const MTL::TensorExtents* destinationDimensions); + + void copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin); + void copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger destinationBytesPerRow, NS::UInteger destinationBytesPerImage); + void copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger destinationBytesPerRow, NS::UInteger destinationBytesPerImage, MTL::BlitOption options); + void copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, NS::UInteger sliceCount, NS::UInteger levelCount); + void copyFromTexture(const MTL::Texture* sourceTexture, const MTL::Texture* destinationTexture); + + void copyIndirectCommandBuffer(const MTL::IndirectCommandBuffer* source, NS::Range sourceRange, const MTL::IndirectCommandBuffer* destination, NS::UInteger destinationIndex); + + void fillBuffer(const MTL::Buffer* buffer, NS::Range range, uint8_t value); + + void generateMipmaps(const MTL::Texture* texture); + + void getTextureAccessCounters(const MTL::Texture* texture, MTL::Region region, NS::UInteger mipLevel, NS::UInteger slice, bool resetCounters, const MTL::Buffer* countersBuffer, NS::UInteger countersBufferOffset); + + void optimizeContentsForCPUAccess(const MTL::Texture* texture); + void optimizeContentsForCPUAccess(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level); + + void optimizeContentsForGPUAccess(const MTL::Texture* texture); + void optimizeContentsForGPUAccess(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level); + + void optimizeIndirectCommandBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range range); + + void resetCommandsInBuffer(const MTL::IndirectCommandBuffer* buffer, NS::Range range); + + void resetTextureAccessCounters(const MTL::Texture* texture, MTL::Region region, NS::UInteger mipLevel, NS::UInteger slice); + + void resolveCounters(const MTL::CounterSampleBuffer* sampleBuffer, NS::Range range, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset); + + void sampleCountersInBuffer(const MTL::CounterSampleBuffer* sampleBuffer, NS::UInteger sampleIndex, bool barrier); + + void synchronizeResource(const MTL::Resource* resource); + + void synchronizeTexture(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level); + + void updateFence(const MTL::Fence* fence); + + void waitForFence(const MTL::Fence* fence); +}; + +} +_MTL_INLINE void MTL::BlitCommandEncoder::copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, NS::UInteger sourceBytesPerRow, NS::UInteger sourceBytesPerImage, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromBuffer_sourceOffset_sourceBytesPerRow_sourceBytesPerImage_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_), sourceBuffer, sourceOffset, sourceBytesPerRow, sourceBytesPerImage, sourceSize, destinationTexture, destinationSlice, destinationLevel, destinationOrigin); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, NS::UInteger sourceBytesPerRow, NS::UInteger sourceBytesPerImage, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin, MTL::BlitOption options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromBuffer_sourceOffset_sourceBytesPerRow_sourceBytesPerImage_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_options_), sourceBuffer, sourceOffset, sourceBytesPerRow, sourceBytesPerImage, sourceSize, destinationTexture, destinationSlice, destinationLevel, destinationOrigin, options); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger size) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_), sourceBuffer, sourceOffset, destinationBuffer, destinationOffset, size); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::copyFromTensor(const MTL::Tensor* sourceTensor, const MTL::TensorExtents* sourceOrigin, const MTL::TensorExtents* sourceDimensions, const MTL::Tensor* destinationTensor, const MTL::TensorExtents* destinationOrigin, const MTL::TensorExtents* destinationDimensions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTensor_sourceOrigin_sourceDimensions_toTensor_destinationOrigin_destinationDimensions_), sourceTensor, sourceOrigin, sourceDimensions, destinationTensor, destinationOrigin, destinationDimensions); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_), sourceTexture, sourceSlice, sourceLevel, sourceOrigin, sourceSize, destinationTexture, destinationSlice, destinationLevel, destinationOrigin); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger destinationBytesPerRow, NS::UInteger destinationBytesPerImage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toBuffer_destinationOffset_destinationBytesPerRow_destinationBytesPerImage_), sourceTexture, sourceSlice, sourceLevel, sourceOrigin, sourceSize, destinationBuffer, destinationOffset, destinationBytesPerRow, destinationBytesPerImage); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger destinationBytesPerRow, NS::UInteger destinationBytesPerImage, MTL::BlitOption options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toBuffer_destinationOffset_destinationBytesPerRow_destinationBytesPerImage_options_), sourceTexture, sourceSlice, sourceLevel, sourceOrigin, sourceSize, destinationBuffer, destinationOffset, destinationBytesPerRow, destinationBytesPerImage, options); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, NS::UInteger sliceCount, NS::UInteger levelCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTexture_sourceSlice_sourceLevel_toTexture_destinationSlice_destinationLevel_sliceCount_levelCount_), sourceTexture, sourceSlice, sourceLevel, destinationTexture, destinationSlice, destinationLevel, sliceCount, levelCount); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::copyFromTexture(const MTL::Texture* sourceTexture, const MTL::Texture* destinationTexture) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTexture_toTexture_), sourceTexture, destinationTexture); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::copyIndirectCommandBuffer(const MTL::IndirectCommandBuffer* source, NS::Range sourceRange, const MTL::IndirectCommandBuffer* destination, NS::UInteger destinationIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyIndirectCommandBuffer_sourceRange_destination_destinationIndex_), source, sourceRange, destination, destinationIndex); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::fillBuffer(const MTL::Buffer* buffer, NS::Range range, uint8_t value) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(fillBuffer_range_value_), buffer, range, value); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::generateMipmaps(const MTL::Texture* texture) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(generateMipmapsForTexture_), texture); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::getTextureAccessCounters(const MTL::Texture* texture, MTL::Region region, NS::UInteger mipLevel, NS::UInteger slice, bool resetCounters, const MTL::Buffer* countersBuffer, NS::UInteger countersBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(getTextureAccessCounters_region_mipLevel_slice_resetCounters_countersBuffer_countersBufferOffset_), texture, region, mipLevel, slice, resetCounters, countersBuffer, countersBufferOffset); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::optimizeContentsForCPUAccess(const MTL::Texture* texture) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizeContentsForCPUAccess_), texture); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::optimizeContentsForCPUAccess(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizeContentsForCPUAccess_slice_level_), texture, slice, level); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::optimizeContentsForGPUAccess(const MTL::Texture* texture) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizeContentsForGPUAccess_), texture); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::optimizeContentsForGPUAccess(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizeContentsForGPUAccess_slice_level_), texture, slice, level); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::optimizeIndirectCommandBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizeIndirectCommandBuffer_withRange_), indirectCommandBuffer, range); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::resetCommandsInBuffer(const MTL::IndirectCommandBuffer* buffer, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(resetCommandsInBuffer_withRange_), buffer, range); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::resetTextureAccessCounters(const MTL::Texture* texture, MTL::Region region, NS::UInteger mipLevel, NS::UInteger slice) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(resetTextureAccessCounters_region_mipLevel_slice_), texture, region, mipLevel, slice); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::resolveCounters(const MTL::CounterSampleBuffer* sampleBuffer, NS::Range range, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(resolveCounters_inRange_destinationBuffer_destinationOffset_), sampleBuffer, range, destinationBuffer, destinationOffset); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::sampleCountersInBuffer(const MTL::CounterSampleBuffer* sampleBuffer, NS::UInteger sampleIndex, bool barrier) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleCountersInBuffer_atSampleIndex_withBarrier_), sampleBuffer, sampleIndex, barrier); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::synchronizeResource(const MTL::Resource* resource) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(synchronizeResource_), resource); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::synchronizeTexture(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(synchronizeTexture_slice_level_), texture, slice, level); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::updateFence(const MTL::Fence* fence) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateFence_), fence); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::waitForFence(const MTL::Fence* fence) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitForFence_), fence); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLBlitPass.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLBlitPass.hpp new file mode 100644 index 00000000..6b15e0b5 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLBlitPass.hpp @@ -0,0 +1,154 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLBlitPass.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class BlitPassDescriptor; +class BlitPassSampleBufferAttachmentDescriptor; +class BlitPassSampleBufferAttachmentDescriptorArray; +class CounterSampleBuffer; + +class BlitPassSampleBufferAttachmentDescriptor : public NS::Copying +{ +public: + static BlitPassSampleBufferAttachmentDescriptor* alloc(); + + NS::UInteger endOfEncoderSampleIndex() const; + + BlitPassSampleBufferAttachmentDescriptor* init(); + + CounterSampleBuffer* sampleBuffer() const; + + void setEndOfEncoderSampleIndex(NS::UInteger endOfEncoderSampleIndex); + + void setSampleBuffer(const MTL::CounterSampleBuffer* sampleBuffer); + + void setStartOfEncoderSampleIndex(NS::UInteger startOfEncoderSampleIndex); + NS::UInteger startOfEncoderSampleIndex() const; +}; +class BlitPassSampleBufferAttachmentDescriptorArray : public NS::Referencing +{ +public: + static BlitPassSampleBufferAttachmentDescriptorArray* alloc(); + + BlitPassSampleBufferAttachmentDescriptorArray* init(); + + BlitPassSampleBufferAttachmentDescriptor* object(NS::UInteger attachmentIndex); + void setObject(const MTL::BlitPassSampleBufferAttachmentDescriptor* attachment, NS::UInteger attachmentIndex); +}; +class BlitPassDescriptor : public NS::Copying +{ +public: + static BlitPassDescriptor* alloc(); + + static BlitPassDescriptor* blitPassDescriptor(); + + BlitPassDescriptor* init(); + + BlitPassSampleBufferAttachmentDescriptorArray* sampleBufferAttachments() const; +}; + +} +_MTL_INLINE MTL::BlitPassSampleBufferAttachmentDescriptor* MTL::BlitPassSampleBufferAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLBlitPassSampleBufferAttachmentDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::BlitPassSampleBufferAttachmentDescriptor::endOfEncoderSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(endOfEncoderSampleIndex)); +} + +_MTL_INLINE MTL::BlitPassSampleBufferAttachmentDescriptor* MTL::BlitPassSampleBufferAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::CounterSampleBuffer* MTL::BlitPassSampleBufferAttachmentDescriptor::sampleBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleBuffer)); +} + +_MTL_INLINE void MTL::BlitPassSampleBufferAttachmentDescriptor::setEndOfEncoderSampleIndex(NS::UInteger endOfEncoderSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setEndOfEncoderSampleIndex_), endOfEncoderSampleIndex); +} + +_MTL_INLINE void MTL::BlitPassSampleBufferAttachmentDescriptor::setSampleBuffer(const MTL::CounterSampleBuffer* sampleBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSampleBuffer_), sampleBuffer); +} + +_MTL_INLINE void MTL::BlitPassSampleBufferAttachmentDescriptor::setStartOfEncoderSampleIndex(NS::UInteger startOfEncoderSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStartOfEncoderSampleIndex_), startOfEncoderSampleIndex); +} + +_MTL_INLINE NS::UInteger MTL::BlitPassSampleBufferAttachmentDescriptor::startOfEncoderSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(startOfEncoderSampleIndex)); +} + +_MTL_INLINE MTL::BlitPassSampleBufferAttachmentDescriptorArray* MTL::BlitPassSampleBufferAttachmentDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLBlitPassSampleBufferAttachmentDescriptorArray)); +} + +_MTL_INLINE MTL::BlitPassSampleBufferAttachmentDescriptorArray* MTL::BlitPassSampleBufferAttachmentDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::BlitPassSampleBufferAttachmentDescriptor* MTL::BlitPassSampleBufferAttachmentDescriptorArray::object(NS::UInteger attachmentIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), attachmentIndex); +} + +_MTL_INLINE void MTL::BlitPassSampleBufferAttachmentDescriptorArray::setObject(const MTL::BlitPassSampleBufferAttachmentDescriptor* attachment, NS::UInteger attachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), attachment, attachmentIndex); +} + +_MTL_INLINE MTL::BlitPassDescriptor* MTL::BlitPassDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLBlitPassDescriptor)); +} + +_MTL_INLINE MTL::BlitPassDescriptor* MTL::BlitPassDescriptor::blitPassDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLBlitPassDescriptor), _MTL_PRIVATE_SEL(blitPassDescriptor)); +} + +_MTL_INLINE MTL::BlitPassDescriptor* MTL::BlitPassDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::BlitPassSampleBufferAttachmentDescriptorArray* MTL::BlitPassDescriptor::sampleBufferAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleBufferAttachments)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLBuffer.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLBuffer.hpp new file mode 100644 index 00000000..a93be1b0 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLBuffer.hpp @@ -0,0 +1,119 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLBuffer.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLGPUAddress.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" + +namespace MTL +{ +class Buffer; +class Device; +class Tensor; +class TensorDescriptor; +class Texture; +class TextureDescriptor; + +class Buffer : public NS::Referencing +{ +public: + void addDebugMarker(const NS::String* marker, NS::Range range); + + void* contents(); + + void didModifyRange(NS::Range range); + + GPUAddress gpuAddress() const; + + NS::UInteger length() const; + + Buffer* newRemoteBufferViewForDevice(const MTL::Device* device); + + Tensor* newTensor(const MTL::TensorDescriptor* descriptor, NS::UInteger offset, NS::Error** error); + + Texture* newTexture(const MTL::TextureDescriptor* descriptor, NS::UInteger offset, NS::UInteger bytesPerRow); + + Buffer* remoteStorageBuffer() const; + + void removeAllDebugMarkers(); + + BufferSparseTier sparseBufferTier() const; +}; + +} +_MTL_INLINE void MTL::Buffer::addDebugMarker(const NS::String* marker, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addDebugMarker_range_), marker, range); +} + +_MTL_INLINE void* MTL::Buffer::contents() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(contents)); +} + +_MTL_INLINE void MTL::Buffer::didModifyRange(NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(didModifyRange_), range); +} + +_MTL_INLINE MTL::GPUAddress MTL::Buffer::gpuAddress() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuAddress)); +} + +_MTL_INLINE NS::UInteger MTL::Buffer::length() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(length)); +} + +_MTL_INLINE MTL::Buffer* MTL::Buffer::newRemoteBufferViewForDevice(const MTL::Device* device) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRemoteBufferViewForDevice_), device); +} + +_MTL_INLINE MTL::Tensor* MTL::Buffer::newTensor(const MTL::TensorDescriptor* descriptor, NS::UInteger offset, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTensorWithDescriptor_offset_error_), descriptor, offset, error); +} + +_MTL_INLINE MTL::Texture* MTL::Buffer::newTexture(const MTL::TextureDescriptor* descriptor, NS::UInteger offset, NS::UInteger bytesPerRow) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTextureWithDescriptor_offset_bytesPerRow_), descriptor, offset, bytesPerRow); +} + +_MTL_INLINE MTL::Buffer* MTL::Buffer::remoteStorageBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(remoteStorageBuffer)); +} + +_MTL_INLINE void MTL::Buffer::removeAllDebugMarkers() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(removeAllDebugMarkers)); +} + +_MTL_INLINE MTL::BufferSparseTier MTL::Buffer::sparseBufferTier() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sparseBufferTier)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLCaptureManager.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLCaptureManager.hpp new file mode 100644 index 00000000..a7622418 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLCaptureManager.hpp @@ -0,0 +1,217 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLCaptureManager.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class CaptureDescriptor; +class CaptureManager; +class CaptureScope; +class CommandQueue; +class Device; +} + +namespace MTL4 +{ +class CommandQueue; +} + +namespace MTL +{ +_MTL_ENUM(NS::Integer, CaptureError) { + CaptureErrorNotSupported = 1, + CaptureErrorAlreadyCapturing = 2, + CaptureErrorInvalidDescriptor = 3, +}; + +_MTL_ENUM(NS::Integer, CaptureDestination) { + CaptureDestinationDeveloperTools = 1, + CaptureDestinationGPUTraceDocument = 2, +}; + +class CaptureDescriptor : public NS::Copying +{ +public: + static CaptureDescriptor* alloc(); + + NS::Object* captureObject() const; + + CaptureDestination destination() const; + + CaptureDescriptor* init(); + + NS::URL* outputURL() const; + + void setCaptureObject(NS::Object* captureObject); + + void setDestination(MTL::CaptureDestination destination); + + void setOutputURL(const NS::URL* outputURL); +}; +class CaptureManager : public NS::Referencing +{ +public: + static CaptureManager* alloc(); + + CaptureScope* defaultCaptureScope() const; + + CaptureManager* init(); + + bool isCapturing() const; + + CaptureScope* newCaptureScope(const MTL::Device* device); + CaptureScope* newCaptureScope(const MTL::CommandQueue* commandQueue); + CaptureScope* newCaptureScope(const MTL4::CommandQueue* commandQueue); + + void setDefaultCaptureScope(const MTL::CaptureScope* defaultCaptureScope); + + static CaptureManager* sharedCaptureManager(); + + bool startCapture(const MTL::CaptureDescriptor* descriptor, NS::Error** error); + void startCapture(const MTL::Device* device); + void startCapture(const MTL::CommandQueue* commandQueue); + void startCapture(const MTL::CaptureScope* captureScope); + + void stopCapture(); + + bool supportsDestination(MTL::CaptureDestination destination); +}; + +} +_MTL_INLINE MTL::CaptureDescriptor* MTL::CaptureDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLCaptureDescriptor)); +} + +_MTL_INLINE NS::Object* MTL::CaptureDescriptor::captureObject() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(captureObject)); +} + +_MTL_INLINE MTL::CaptureDestination MTL::CaptureDescriptor::destination() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(destination)); +} + +_MTL_INLINE MTL::CaptureDescriptor* MTL::CaptureDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::URL* MTL::CaptureDescriptor::outputURL() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(outputURL)); +} + +_MTL_INLINE void MTL::CaptureDescriptor::setCaptureObject(NS::Object* captureObject) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCaptureObject_), captureObject); +} + +_MTL_INLINE void MTL::CaptureDescriptor::setDestination(MTL::CaptureDestination destination) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDestination_), destination); +} + +_MTL_INLINE void MTL::CaptureDescriptor::setOutputURL(const NS::URL* outputURL) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOutputURL_), outputURL); +} + +_MTL_INLINE MTL::CaptureManager* MTL::CaptureManager::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLCaptureManager)); +} + +_MTL_INLINE MTL::CaptureScope* MTL::CaptureManager::defaultCaptureScope() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(defaultCaptureScope)); +} + +_MTL_INLINE MTL::CaptureManager* MTL::CaptureManager::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE bool MTL::CaptureManager::isCapturing() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isCapturing)); +} + +_MTL_INLINE MTL::CaptureScope* MTL::CaptureManager::newCaptureScope(const MTL::Device* device) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCaptureScopeWithDevice_), device); +} + +_MTL_INLINE MTL::CaptureScope* MTL::CaptureManager::newCaptureScope(const MTL::CommandQueue* commandQueue) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCaptureScopeWithCommandQueue_), commandQueue); +} + +_MTL_INLINE MTL::CaptureScope* MTL::CaptureManager::newCaptureScope(const MTL4::CommandQueue* commandQueue) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCaptureScopeWithMTL4CommandQueue_), commandQueue); +} + +_MTL_INLINE void MTL::CaptureManager::setDefaultCaptureScope(const MTL::CaptureScope* defaultCaptureScope) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDefaultCaptureScope_), defaultCaptureScope); +} + +_MTL_INLINE MTL::CaptureManager* MTL::CaptureManager::sharedCaptureManager() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLCaptureManager), _MTL_PRIVATE_SEL(sharedCaptureManager)); +} + +_MTL_INLINE bool MTL::CaptureManager::startCapture(const MTL::CaptureDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(startCaptureWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE void MTL::CaptureManager::startCapture(const MTL::Device* device) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(startCaptureWithDevice_), device); +} + +_MTL_INLINE void MTL::CaptureManager::startCapture(const MTL::CommandQueue* commandQueue) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(startCaptureWithCommandQueue_), commandQueue); +} + +_MTL_INLINE void MTL::CaptureManager::startCapture(const MTL::CaptureScope* captureScope) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(startCaptureWithScope_), captureScope); +} + +_MTL_INLINE void MTL::CaptureManager::stopCapture() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(stopCapture)); +} + +_MTL_INLINE bool MTL::CaptureManager::supportsDestination(MTL::CaptureDestination destination) +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsDestination_), destination); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLCaptureScope.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLCaptureScope.hpp new file mode 100644 index 00000000..96ade67d --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLCaptureScope.hpp @@ -0,0 +1,91 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLCaptureScope.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLDefines.hpp" +#include "MTLPrivate.hpp" + +#include "../Foundation/Foundation.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL +{ +class CaptureScope : public NS::Referencing +{ +public: + class Device* device() const; + + NS::String* label() const; + void setLabel(const NS::String* pLabel); + + class CommandQueue* commandQueue() const; + + void beginScope(); + void endScope(); +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::Device* MTL::CaptureScope::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE NS::String* MTL::CaptureScope::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE void MTL::CaptureScope::setLabel(const NS::String* pLabel) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), pLabel); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::CommandQueue* MTL::CaptureScope::commandQueue() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(commandQueue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE void MTL::CaptureScope::beginScope() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(beginScope)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE void MTL::CaptureScope::endScope() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(endScope)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLCommandBuffer.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLCommandBuffer.hpp new file mode 100644 index 00000000..c504573d --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLCommandBuffer.hpp @@ -0,0 +1,464 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLCommandBuffer.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include +#include + +#include + +namespace MTL +{ +class AccelerationStructureCommandEncoder; +class AccelerationStructurePassDescriptor; +class BlitCommandEncoder; +class BlitPassDescriptor; +class CommandBuffer; +class CommandBufferDescriptor; +class CommandQueue; +class ComputeCommandEncoder; +class ComputePassDescriptor; +class Device; +class Drawable; +class Event; +class LogContainer; +class LogState; +class ParallelRenderCommandEncoder; +class RenderCommandEncoder; +class RenderPassDescriptor; +class ResidencySet; +class ResourceStateCommandEncoder; +class ResourceStatePassDescriptor; +_MTL_ENUM(NS::UInteger, CommandBufferStatus) { + CommandBufferStatusNotEnqueued = 0, + CommandBufferStatusEnqueued = 1, + CommandBufferStatusCommitted = 2, + CommandBufferStatusScheduled = 3, + CommandBufferStatusCompleted = 4, + CommandBufferStatusError = 5, +}; + +_MTL_ENUM(NS::UInteger, CommandBufferError) { + CommandBufferErrorNone = 0, + CommandBufferErrorInternal = 1, + CommandBufferErrorTimeout = 2, + CommandBufferErrorPageFault = 3, + CommandBufferErrorBlacklisted = 4, + CommandBufferErrorAccessRevoked = 4, + CommandBufferErrorNotPermitted = 7, + CommandBufferErrorOutOfMemory = 8, + CommandBufferErrorInvalidResource = 9, + CommandBufferErrorMemoryless = 10, + CommandBufferErrorDeviceRemoved = 11, + CommandBufferErrorStackOverflow = 12, +}; + +_MTL_ENUM(NS::Integer, CommandEncoderErrorState) { + CommandEncoderErrorStateUnknown = 0, + CommandEncoderErrorStateCompleted = 1, + CommandEncoderErrorStateAffected = 2, + CommandEncoderErrorStatePending = 3, + CommandEncoderErrorStateFaulted = 4, +}; + +_MTL_ENUM(NS::UInteger, DispatchType) { + DispatchTypeSerial = 0, + DispatchTypeConcurrent = 1, +}; + +_MTL_OPTIONS(NS::UInteger, CommandBufferErrorOption) { + CommandBufferErrorOptionNone = 0, + CommandBufferErrorOptionEncoderExecutionStatus = 1, +}; + +using CommandBufferHandler = void (^)(CommandBuffer*); +using HandlerFunction = std::function; + +class CommandBufferDescriptor : public NS::Copying +{ +public: + static CommandBufferDescriptor* alloc(); + + CommandBufferErrorOption errorOptions() const; + + CommandBufferDescriptor* init(); + + LogState* logState() const; + + bool retainedReferences() const; + + void setErrorOptions(MTL::CommandBufferErrorOption errorOptions); + + void setLogState(const MTL::LogState* logState); + + void setRetainedReferences(bool retainedReferences); +}; +class CommandBufferEncoderInfo : public NS::Referencing +{ +public: + NS::Array* debugSignposts() const; + + CommandEncoderErrorState errorState() const; + + NS::String* label() const; +}; +class CommandBuffer : public NS::Referencing +{ +public: + CFTimeInterval GPUEndTime() const; + + CFTimeInterval GPUStartTime() const; + + AccelerationStructureCommandEncoder* accelerationStructureCommandEncoder(); + AccelerationStructureCommandEncoder* accelerationStructureCommandEncoder(const MTL::AccelerationStructurePassDescriptor* descriptor); + + void addCompletedHandler(const MTL::CommandBufferHandler block); + void addCompletedHandler(const MTL::HandlerFunction& function); + + void addScheduledHandler(const MTL::CommandBufferHandler block); + void addScheduledHandler(const MTL::HandlerFunction& function); + + BlitCommandEncoder* blitCommandEncoder(); + BlitCommandEncoder* blitCommandEncoder(const MTL::BlitPassDescriptor* blitPassDescriptor); + + CommandQueue* commandQueue() const; + + void commit(); + + ComputeCommandEncoder* computeCommandEncoder(const MTL::ComputePassDescriptor* computePassDescriptor); + ComputeCommandEncoder* computeCommandEncoder(); + ComputeCommandEncoder* computeCommandEncoder(MTL::DispatchType dispatchType); + + Device* device() const; + + void encodeSignalEvent(const MTL::Event* event, uint64_t value); + + void encodeWait(const MTL::Event* event, uint64_t value); + + void enqueue(); + + NS::Error* error() const; + CommandBufferErrorOption errorOptions() const; + + CFTimeInterval kernelEndTime() const; + + CFTimeInterval kernelStartTime() const; + + NS::String* label() const; + + LogContainer* logs() const; + + ParallelRenderCommandEncoder* parallelRenderCommandEncoder(const MTL::RenderPassDescriptor* renderPassDescriptor); + + void popDebugGroup(); + + void presentDrawable(const MTL::Drawable* drawable); + void presentDrawableAfterMinimumDuration(const MTL::Drawable* drawable, CFTimeInterval duration); + + void presentDrawableAtTime(const MTL::Drawable* drawable, CFTimeInterval presentationTime); + + void pushDebugGroup(const NS::String* string); + + RenderCommandEncoder* renderCommandEncoder(const MTL::RenderPassDescriptor* renderPassDescriptor); + + ResourceStateCommandEncoder* resourceStateCommandEncoder(); + ResourceStateCommandEncoder* resourceStateCommandEncoder(const MTL::ResourceStatePassDescriptor* resourceStatePassDescriptor); + + bool retainedReferences() const; + + void setLabel(const NS::String* label); + + CommandBufferStatus status() const; + + void useResidencySet(const MTL::ResidencySet* residencySet); + void useResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count); + + void waitUntilCompleted(); + + void waitUntilScheduled(); +}; + +} +_MTL_INLINE MTL::CommandBufferDescriptor* MTL::CommandBufferDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLCommandBufferDescriptor)); +} + +_MTL_INLINE MTL::CommandBufferErrorOption MTL::CommandBufferDescriptor::errorOptions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(errorOptions)); +} + +_MTL_INLINE MTL::CommandBufferDescriptor* MTL::CommandBufferDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::LogState* MTL::CommandBufferDescriptor::logState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(logState)); +} + +_MTL_INLINE bool MTL::CommandBufferDescriptor::retainedReferences() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(retainedReferences)); +} + +_MTL_INLINE void MTL::CommandBufferDescriptor::setErrorOptions(MTL::CommandBufferErrorOption errorOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setErrorOptions_), errorOptions); +} + +_MTL_INLINE void MTL::CommandBufferDescriptor::setLogState(const MTL::LogState* logState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLogState_), logState); +} + +_MTL_INLINE void MTL::CommandBufferDescriptor::setRetainedReferences(bool retainedReferences) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRetainedReferences_), retainedReferences); +} + +_MTL_INLINE NS::Array* MTL::CommandBufferEncoderInfo::debugSignposts() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(debugSignposts)); +} + +_MTL_INLINE MTL::CommandEncoderErrorState MTL::CommandBufferEncoderInfo::errorState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(errorState)); +} + +_MTL_INLINE NS::String* MTL::CommandBufferEncoderInfo::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE CFTimeInterval MTL::CommandBuffer::GPUEndTime() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(GPUEndTime)); +} + +_MTL_INLINE CFTimeInterval MTL::CommandBuffer::GPUStartTime() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(GPUStartTime)); +} + +_MTL_INLINE MTL::AccelerationStructureCommandEncoder* MTL::CommandBuffer::accelerationStructureCommandEncoder() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(accelerationStructureCommandEncoder)); +} + +_MTL_INLINE MTL::AccelerationStructureCommandEncoder* MTL::CommandBuffer::accelerationStructureCommandEncoder(const MTL::AccelerationStructurePassDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(accelerationStructureCommandEncoderWithDescriptor_), descriptor); +} + +_MTL_INLINE void MTL::CommandBuffer::addCompletedHandler(const MTL::CommandBufferHandler block) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addCompletedHandler_), block); +} + +_MTL_INLINE void MTL::CommandBuffer::addCompletedHandler(const MTL::HandlerFunction& function) +{ + __block HandlerFunction blockFunction = function; + addCompletedHandler(^(MTL::CommandBuffer* pCommandBuffer) { blockFunction(pCommandBuffer); }); +} + +_MTL_INLINE void MTL::CommandBuffer::addScheduledHandler(const MTL::CommandBufferHandler block) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addScheduledHandler_), block); +} + +_MTL_INLINE void MTL::CommandBuffer::addScheduledHandler(const MTL::HandlerFunction& function) +{ + __block HandlerFunction blockFunction = function; + addScheduledHandler(^(MTL::CommandBuffer* pCommandBuffer) { blockFunction(pCommandBuffer); }); +} + +_MTL_INLINE MTL::BlitCommandEncoder* MTL::CommandBuffer::blitCommandEncoder() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(blitCommandEncoder)); +} + +_MTL_INLINE MTL::BlitCommandEncoder* MTL::CommandBuffer::blitCommandEncoder(const MTL::BlitPassDescriptor* blitPassDescriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(blitCommandEncoderWithDescriptor_), blitPassDescriptor); +} + +_MTL_INLINE MTL::CommandQueue* MTL::CommandBuffer::commandQueue() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(commandQueue)); +} + +_MTL_INLINE void MTL::CommandBuffer::commit() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(commit)); +} + +_MTL_INLINE MTL::ComputeCommandEncoder* MTL::CommandBuffer::computeCommandEncoder(const MTL::ComputePassDescriptor* computePassDescriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(computeCommandEncoderWithDescriptor_), computePassDescriptor); +} + +_MTL_INLINE MTL::ComputeCommandEncoder* MTL::CommandBuffer::computeCommandEncoder() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(computeCommandEncoder)); +} + +_MTL_INLINE MTL::ComputeCommandEncoder* MTL::CommandBuffer::computeCommandEncoder(MTL::DispatchType dispatchType) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(computeCommandEncoderWithDispatchType_), dispatchType); +} + +_MTL_INLINE MTL::Device* MTL::CommandBuffer::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE void MTL::CommandBuffer::encodeSignalEvent(const MTL::Event* event, uint64_t value) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(encodeSignalEvent_value_), event, value); +} + +_MTL_INLINE void MTL::CommandBuffer::encodeWait(const MTL::Event* event, uint64_t value) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(encodeWaitForEvent_value_), event, value); +} + +_MTL_INLINE void MTL::CommandBuffer::enqueue() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(enqueue)); +} + +_MTL_INLINE NS::Error* MTL::CommandBuffer::error() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(error)); +} + +_MTL_INLINE MTL::CommandBufferErrorOption MTL::CommandBuffer::errorOptions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(errorOptions)); +} + +_MTL_INLINE CFTimeInterval MTL::CommandBuffer::kernelEndTime() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(kernelEndTime)); +} + +_MTL_INLINE CFTimeInterval MTL::CommandBuffer::kernelStartTime() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(kernelStartTime)); +} + +_MTL_INLINE NS::String* MTL::CommandBuffer::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL::LogContainer* MTL::CommandBuffer::logs() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(logs)); +} + +_MTL_INLINE MTL::ParallelRenderCommandEncoder* MTL::CommandBuffer::parallelRenderCommandEncoder(const MTL::RenderPassDescriptor* renderPassDescriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(parallelRenderCommandEncoderWithDescriptor_), renderPassDescriptor); +} + +_MTL_INLINE void MTL::CommandBuffer::popDebugGroup() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(popDebugGroup)); +} + +_MTL_INLINE void MTL::CommandBuffer::presentDrawable(const MTL::Drawable* drawable) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(presentDrawable_), drawable); +} + +_MTL_INLINE void MTL::CommandBuffer::presentDrawableAfterMinimumDuration(const MTL::Drawable* drawable, CFTimeInterval duration) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(presentDrawable_afterMinimumDuration_), drawable, duration); +} + +_MTL_INLINE void MTL::CommandBuffer::presentDrawableAtTime(const MTL::Drawable* drawable, CFTimeInterval presentationTime) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(presentDrawable_atTime_), drawable, presentationTime); +} + +_MTL_INLINE void MTL::CommandBuffer::pushDebugGroup(const NS::String* string) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(pushDebugGroup_), string); +} + +_MTL_INLINE MTL::RenderCommandEncoder* MTL::CommandBuffer::renderCommandEncoder(const MTL::RenderPassDescriptor* renderPassDescriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(renderCommandEncoderWithDescriptor_), renderPassDescriptor); +} + +_MTL_INLINE MTL::ResourceStateCommandEncoder* MTL::CommandBuffer::resourceStateCommandEncoder() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resourceStateCommandEncoder)); +} + +_MTL_INLINE MTL::ResourceStateCommandEncoder* MTL::CommandBuffer::resourceStateCommandEncoder(const MTL::ResourceStatePassDescriptor* resourceStatePassDescriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resourceStateCommandEncoderWithDescriptor_), resourceStatePassDescriptor); +} + +_MTL_INLINE bool MTL::CommandBuffer::retainedReferences() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(retainedReferences)); +} + +_MTL_INLINE void MTL::CommandBuffer::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE MTL::CommandBufferStatus MTL::CommandBuffer::status() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(status)); +} + +_MTL_INLINE void MTL::CommandBuffer::useResidencySet(const MTL::ResidencySet* residencySet) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResidencySet_), residencySet); +} + +_MTL_INLINE void MTL::CommandBuffer::useResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResidencySets_count_), residencySets, count); +} + +_MTL_INLINE void MTL::CommandBuffer::waitUntilCompleted() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitUntilCompleted)); +} + +_MTL_INLINE void MTL::CommandBuffer::waitUntilScheduled() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitUntilScheduled)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLCommandEncoder.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLCommandEncoder.hpp new file mode 100644 index 00000000..a230ff5d --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLCommandEncoder.hpp @@ -0,0 +1,117 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLCommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class Device; + +_MTL_OPTIONS(NS::UInteger, ResourceUsage) { + ResourceUsageRead = 1, + ResourceUsageWrite = 1 << 1, + ResourceUsageSample = 1 << 2, +}; + +_MTL_OPTIONS(NS::UInteger, BarrierScope) { + BarrierScopeBuffers = 1, + BarrierScopeTextures = 1 << 1, + BarrierScopeRenderTargets = 1 << 2, +}; + +_MTL_OPTIONS(NS::UInteger, Stages) { + StageVertex = 1, + StageFragment = 1 << 1, + StageTile = 1 << 2, + StageObject = 1 << 3, + StageMesh = 1 << 4, + StageResourceState = 1 << 26, + StageDispatch = 1 << 27, + StageBlit = 1 << 28, + StageAccelerationStructure = 1 << 29, + StageMachineLearning = 1 << 30, + StageAll = 9223372036854775807, +}; + +class CommandEncoder : public NS::Referencing +{ +public: + void barrierAfterQueueStages(MTL::Stages afterQueueStages, MTL::Stages beforeStages); + + Device* device() const; + + void endEncoding(); + + void insertDebugSignpost(const NS::String* string); + + NS::String* label() const; + + void popDebugGroup(); + + void pushDebugGroup(const NS::String* string); + + void setLabel(const NS::String* label); +}; + +} +_MTL_INLINE void MTL::CommandEncoder::barrierAfterQueueStages(MTL::Stages afterQueueStages, MTL::Stages beforeStages) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(barrierAfterQueueStages_beforeStages_), afterQueueStages, beforeStages); +} + +_MTL_INLINE MTL::Device* MTL::CommandEncoder::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE void MTL::CommandEncoder::endEncoding() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(endEncoding)); +} + +_MTL_INLINE void MTL::CommandEncoder::insertDebugSignpost(const NS::String* string) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(insertDebugSignpost_), string); +} + +_MTL_INLINE NS::String* MTL::CommandEncoder::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::CommandEncoder::popDebugGroup() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(popDebugGroup)); +} + +_MTL_INLINE void MTL::CommandEncoder::pushDebugGroup(const NS::String* string) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(pushDebugGroup_), string); +} + +_MTL_INLINE void MTL::CommandEncoder::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLCommandQueue.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLCommandQueue.hpp new file mode 100644 index 00000000..5d3bf164 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLCommandQueue.hpp @@ -0,0 +1,158 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLCommandQueue.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class CommandBuffer; +class CommandBufferDescriptor; +class CommandQueueDescriptor; +class Device; +class LogState; +class ResidencySet; + +class CommandQueue : public NS::Referencing +{ +public: + void addResidencySet(const MTL::ResidencySet* residencySet); + void addResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count); + + CommandBuffer* commandBuffer(); + CommandBuffer* commandBuffer(const MTL::CommandBufferDescriptor* descriptor); + CommandBuffer* commandBufferWithUnretainedReferences(); + + Device* device() const; + + void insertDebugCaptureBoundary(); + + NS::String* label() const; + + void removeResidencySet(const MTL::ResidencySet* residencySet); + void removeResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count); + + void setLabel(const NS::String* label); +}; +class CommandQueueDescriptor : public NS::Copying +{ +public: + static CommandQueueDescriptor* alloc(); + + CommandQueueDescriptor* init(); + + LogState* logState() const; + + NS::UInteger maxCommandBufferCount() const; + + void setLogState(const MTL::LogState* logState); + + void setMaxCommandBufferCount(NS::UInteger maxCommandBufferCount); +}; + +} +_MTL_INLINE void MTL::CommandQueue::addResidencySet(const MTL::ResidencySet* residencySet) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addResidencySet_), residencySet); +} + +_MTL_INLINE void MTL::CommandQueue::addResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addResidencySets_count_), residencySets, count); +} + +_MTL_INLINE MTL::CommandBuffer* MTL::CommandQueue::commandBuffer() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(commandBuffer)); +} + +_MTL_INLINE MTL::CommandBuffer* MTL::CommandQueue::commandBuffer(const MTL::CommandBufferDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(commandBufferWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::CommandBuffer* MTL::CommandQueue::commandBufferWithUnretainedReferences() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(commandBufferWithUnretainedReferences)); +} + +_MTL_INLINE MTL::Device* MTL::CommandQueue::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE void MTL::CommandQueue::insertDebugCaptureBoundary() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(insertDebugCaptureBoundary)); +} + +_MTL_INLINE NS::String* MTL::CommandQueue::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::CommandQueue::removeResidencySet(const MTL::ResidencySet* residencySet) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(removeResidencySet_), residencySet); +} + +_MTL_INLINE void MTL::CommandQueue::removeResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(removeResidencySets_count_), residencySets, count); +} + +_MTL_INLINE void MTL::CommandQueue::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE MTL::CommandQueueDescriptor* MTL::CommandQueueDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLCommandQueueDescriptor)); +} + +_MTL_INLINE MTL::CommandQueueDescriptor* MTL::CommandQueueDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::LogState* MTL::CommandQueueDescriptor::logState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(logState)); +} + +_MTL_INLINE NS::UInteger MTL::CommandQueueDescriptor::maxCommandBufferCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxCommandBufferCount)); +} + +_MTL_INLINE void MTL::CommandQueueDescriptor::setLogState(const MTL::LogState* logState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLogState_), logState); +} + +_MTL_INLINE void MTL::CommandQueueDescriptor::setMaxCommandBufferCount(NS::UInteger maxCommandBufferCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxCommandBufferCount_), maxCommandBufferCount); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLComputeCommandEncoder.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLComputeCommandEncoder.hpp new file mode 100644 index 00000000..2f555e57 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLComputeCommandEncoder.hpp @@ -0,0 +1,324 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLComputeCommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLCommandBuffer.hpp" +#include "MTLCommandEncoder.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL +{ +class AccelerationStructure; +class Buffer; +class ComputePipelineState; +class CounterSampleBuffer; +class Fence; +class Heap; +class IndirectCommandBuffer; +class IntersectionFunctionTable; +class Resource; +class SamplerState; +class Texture; +class VisibleFunctionTable; + +struct DispatchThreadgroupsIndirectArguments +{ + uint32_t threadgroupsPerGrid[3]; +} _MTL_PACKED; + +struct DispatchThreadsIndirectArguments +{ + uint32_t threadsPerGrid[3]; + uint32_t threadsPerThreadgroup[3]; +} _MTL_PACKED; + +struct StageInRegionIndirectArguments +{ + uint32_t stageInOrigin[3]; + uint32_t stageInSize[3]; +} _MTL_PACKED; + +class ComputeCommandEncoder : public NS::Referencing +{ +public: + void dispatchThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerThreadgroup); + void dispatchThreadgroups(const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset, MTL::Size threadsPerThreadgroup); + + void dispatchThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerThreadgroup); + + DispatchType dispatchType() const; + + void executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range executionRange); + void executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandbuffer, const MTL::Buffer* indirectRangeBuffer, NS::UInteger indirectBufferOffset); + + void memoryBarrier(MTL::BarrierScope scope); + void memoryBarrier(const MTL::Resource* const resources[], NS::UInteger count); + + void sampleCountersInBuffer(const MTL::CounterSampleBuffer* sampleBuffer, NS::UInteger sampleIndex, bool barrier); + + void setAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, NS::UInteger bufferIndex); + + void setBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + void setBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger stride, NS::UInteger index); + void setBufferOffset(NS::UInteger offset, NS::UInteger index); + void setBufferOffset(NS::UInteger offset, NS::UInteger stride, NS::UInteger index); + + void setBuffers(const MTL::Buffer* const buffers[], const NS::UInteger offsets[], NS::Range range); + void setBuffers(const MTL::Buffer* const buffers[], const NS::UInteger* offsets, const NS::UInteger* strides, NS::Range range); + + void setBytes(const void* bytes, NS::UInteger length, NS::UInteger index); + void setBytes(const void* bytes, NS::UInteger length, NS::UInteger stride, NS::UInteger index); + + void setComputePipelineState(const MTL::ComputePipelineState* state); + + void setImageblockWidth(NS::UInteger width, NS::UInteger height); + + void setIntersectionFunctionTable(const MTL::IntersectionFunctionTable* intersectionFunctionTable, NS::UInteger bufferIndex); + void setIntersectionFunctionTables(const MTL::IntersectionFunctionTable* const intersectionFunctionTables[], NS::Range range); + + void setSamplerState(const MTL::SamplerState* sampler, NS::UInteger index); + void setSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index); + void setSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range); + void setSamplerStates(const MTL::SamplerState* const samplers[], const float lodMinClamps[], const float lodMaxClamps[], NS::Range range); + + void setStageInRegion(MTL::Region region); + void setStageInRegion(const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset); + + void setTexture(const MTL::Texture* texture, NS::UInteger index); + void setTextures(const MTL::Texture* const textures[], NS::Range range); + + void setThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index); + + void setVisibleFunctionTable(const MTL::VisibleFunctionTable* visibleFunctionTable, NS::UInteger bufferIndex); + void setVisibleFunctionTables(const MTL::VisibleFunctionTable* const visibleFunctionTables[], NS::Range range); + + void updateFence(const MTL::Fence* fence); + + void useHeap(const MTL::Heap* heap); + void useHeaps(const MTL::Heap* const heaps[], NS::UInteger count); + + void useResource(const MTL::Resource* resource, MTL::ResourceUsage usage); + void useResources(const MTL::Resource* const resources[], NS::UInteger count, MTL::ResourceUsage usage); + + void waitForFence(const MTL::Fence* fence); +}; + +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::dispatchThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchThreadgroups_threadsPerThreadgroup_), threadgroupsPerGrid, threadsPerThreadgroup); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::dispatchThreadgroups(const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset, MTL::Size threadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchThreadgroupsWithIndirectBuffer_indirectBufferOffset_threadsPerThreadgroup_), indirectBuffer, indirectBufferOffset, threadsPerThreadgroup); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::dispatchThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchThreads_threadsPerThreadgroup_), threadsPerGrid, threadsPerThreadgroup); +} + +_MTL_INLINE MTL::DispatchType MTL::ComputeCommandEncoder::dispatchType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchType)); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range executionRange) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(executeCommandsInBuffer_withRange_), indirectCommandBuffer, executionRange); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandbuffer, const MTL::Buffer* indirectRangeBuffer, NS::UInteger indirectBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(executeCommandsInBuffer_indirectBuffer_indirectBufferOffset_), indirectCommandbuffer, indirectRangeBuffer, indirectBufferOffset); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::memoryBarrier(MTL::BarrierScope scope) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(memoryBarrierWithScope_), scope); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::memoryBarrier(const MTL::Resource* const resources[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(memoryBarrierWithResources_count_), resources, count); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::sampleCountersInBuffer(const MTL::CounterSampleBuffer* sampleBuffer, NS::UInteger sampleIndex, bool barrier) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleCountersInBuffer_atSampleIndex_withBarrier_), sampleBuffer, sampleIndex, barrier); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAccelerationStructure_atBufferIndex_), accelerationStructure, bufferIndex); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger stride, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBuffer_offset_attributeStride_atIndex_), buffer, offset, stride, index); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setBufferOffset(NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBufferOffset_atIndex_), offset, index); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setBufferOffset(NS::UInteger offset, NS::UInteger stride, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBufferOffset_attributeStride_atIndex_), offset, stride, index); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setBuffers(const MTL::Buffer* const buffers[], const NS::UInteger offsets[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBuffers_offsets_withRange_), buffers, offsets, range); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setBuffers(const MTL::Buffer* const buffers[], const NS::UInteger* offsets, const NS::UInteger* strides, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBuffers_offsets_attributeStrides_withRange_), buffers, offsets, strides, range); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setBytes(const void* bytes, NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBytes_length_atIndex_), bytes, length, index); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setBytes(const void* bytes, NS::UInteger length, NS::UInteger stride, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBytes_length_attributeStride_atIndex_), bytes, length, stride, index); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setComputePipelineState(const MTL::ComputePipelineState* state) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setComputePipelineState_), state); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setImageblockWidth(NS::UInteger width, NS::UInteger height) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setImageblockWidth_height_), width, height); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setIntersectionFunctionTable(const MTL::IntersectionFunctionTable* intersectionFunctionTable, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIntersectionFunctionTable_atBufferIndex_), intersectionFunctionTable, bufferIndex); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setIntersectionFunctionTables(const MTL::IntersectionFunctionTable* const intersectionFunctionTables[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIntersectionFunctionTables_withBufferRange_), intersectionFunctionTables, range); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setSamplerState(const MTL::SamplerState* sampler, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSamplerState_atIndex_), sampler, index); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSamplerState_lodMinClamp_lodMaxClamp_atIndex_), sampler, lodMinClamp, lodMaxClamp, index); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSamplerStates_withRange_), samplers, range); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setSamplerStates(const MTL::SamplerState* const samplers[], const float lodMinClamps[], const float lodMaxClamps[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSamplerStates_lodMinClamps_lodMaxClamps_withRange_), samplers, lodMinClamps, lodMaxClamps, range); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setStageInRegion(MTL::Region region) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStageInRegion_), region); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setStageInRegion(const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStageInRegionWithIndirectBuffer_indirectBufferOffset_), indirectBuffer, indirectBufferOffset); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setTexture(const MTL::Texture* texture, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTexture_atIndex_), texture, index); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setTextures(const MTL::Texture* const textures[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTextures_withRange_), textures, range); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadgroupMemoryLength_atIndex_), length, index); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setVisibleFunctionTable(const MTL::VisibleFunctionTable* visibleFunctionTable, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibleFunctionTable_atBufferIndex_), visibleFunctionTable, bufferIndex); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setVisibleFunctionTables(const MTL::VisibleFunctionTable* const visibleFunctionTables[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibleFunctionTables_withBufferRange_), visibleFunctionTables, range); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::updateFence(const MTL::Fence* fence) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateFence_), fence); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::useHeap(const MTL::Heap* heap) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useHeap_), heap); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::useHeaps(const MTL::Heap* const heaps[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useHeaps_count_), heaps, count); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::useResource(const MTL::Resource* resource, MTL::ResourceUsage usage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResource_usage_), resource, usage); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::useResources(const MTL::Resource* const resources[], NS::UInteger count, MTL::ResourceUsage usage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResources_count_usage_), resources, count, usage); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::waitForFence(const MTL::Fence* fence) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitForFence_), fence); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLComputePass.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLComputePass.hpp new file mode 100644 index 00000000..fb34f7d8 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLComputePass.hpp @@ -0,0 +1,169 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLComputePass.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLCommandBuffer.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class ComputePassDescriptor; +class ComputePassSampleBufferAttachmentDescriptor; +class ComputePassSampleBufferAttachmentDescriptorArray; +class CounterSampleBuffer; + +class ComputePassSampleBufferAttachmentDescriptor : public NS::Copying +{ +public: + static ComputePassSampleBufferAttachmentDescriptor* alloc(); + + NS::UInteger endOfEncoderSampleIndex() const; + + ComputePassSampleBufferAttachmentDescriptor* init(); + + CounterSampleBuffer* sampleBuffer() const; + + void setEndOfEncoderSampleIndex(NS::UInteger endOfEncoderSampleIndex); + + void setSampleBuffer(const MTL::CounterSampleBuffer* sampleBuffer); + + void setStartOfEncoderSampleIndex(NS::UInteger startOfEncoderSampleIndex); + NS::UInteger startOfEncoderSampleIndex() const; +}; +class ComputePassSampleBufferAttachmentDescriptorArray : public NS::Referencing +{ +public: + static ComputePassSampleBufferAttachmentDescriptorArray* alloc(); + + ComputePassSampleBufferAttachmentDescriptorArray* init(); + + ComputePassSampleBufferAttachmentDescriptor* object(NS::UInteger attachmentIndex); + void setObject(const MTL::ComputePassSampleBufferAttachmentDescriptor* attachment, NS::UInteger attachmentIndex); +}; +class ComputePassDescriptor : public NS::Copying +{ +public: + static ComputePassDescriptor* alloc(); + + static ComputePassDescriptor* computePassDescriptor(); + + DispatchType dispatchType() const; + + ComputePassDescriptor* init(); + + ComputePassSampleBufferAttachmentDescriptorArray* sampleBufferAttachments() const; + + void setDispatchType(MTL::DispatchType dispatchType); +}; + +} +_MTL_INLINE MTL::ComputePassSampleBufferAttachmentDescriptor* MTL::ComputePassSampleBufferAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLComputePassSampleBufferAttachmentDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::ComputePassSampleBufferAttachmentDescriptor::endOfEncoderSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(endOfEncoderSampleIndex)); +} + +_MTL_INLINE MTL::ComputePassSampleBufferAttachmentDescriptor* MTL::ComputePassSampleBufferAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::CounterSampleBuffer* MTL::ComputePassSampleBufferAttachmentDescriptor::sampleBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleBuffer)); +} + +_MTL_INLINE void MTL::ComputePassSampleBufferAttachmentDescriptor::setEndOfEncoderSampleIndex(NS::UInteger endOfEncoderSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setEndOfEncoderSampleIndex_), endOfEncoderSampleIndex); +} + +_MTL_INLINE void MTL::ComputePassSampleBufferAttachmentDescriptor::setSampleBuffer(const MTL::CounterSampleBuffer* sampleBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSampleBuffer_), sampleBuffer); +} + +_MTL_INLINE void MTL::ComputePassSampleBufferAttachmentDescriptor::setStartOfEncoderSampleIndex(NS::UInteger startOfEncoderSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStartOfEncoderSampleIndex_), startOfEncoderSampleIndex); +} + +_MTL_INLINE NS::UInteger MTL::ComputePassSampleBufferAttachmentDescriptor::startOfEncoderSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(startOfEncoderSampleIndex)); +} + +_MTL_INLINE MTL::ComputePassSampleBufferAttachmentDescriptorArray* MTL::ComputePassSampleBufferAttachmentDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLComputePassSampleBufferAttachmentDescriptorArray)); +} + +_MTL_INLINE MTL::ComputePassSampleBufferAttachmentDescriptorArray* MTL::ComputePassSampleBufferAttachmentDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::ComputePassSampleBufferAttachmentDescriptor* MTL::ComputePassSampleBufferAttachmentDescriptorArray::object(NS::UInteger attachmentIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), attachmentIndex); +} + +_MTL_INLINE void MTL::ComputePassSampleBufferAttachmentDescriptorArray::setObject(const MTL::ComputePassSampleBufferAttachmentDescriptor* attachment, NS::UInteger attachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), attachment, attachmentIndex); +} + +_MTL_INLINE MTL::ComputePassDescriptor* MTL::ComputePassDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLComputePassDescriptor)); +} + +_MTL_INLINE MTL::ComputePassDescriptor* MTL::ComputePassDescriptor::computePassDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLComputePassDescriptor), _MTL_PRIVATE_SEL(computePassDescriptor)); +} + +_MTL_INLINE MTL::DispatchType MTL::ComputePassDescriptor::dispatchType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchType)); +} + +_MTL_INLINE MTL::ComputePassDescriptor* MTL::ComputePassDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::ComputePassSampleBufferAttachmentDescriptorArray* MTL::ComputePassDescriptor::sampleBufferAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleBufferAttachments)); +} + +_MTL_INLINE void MTL::ComputePassDescriptor::setDispatchType(MTL::DispatchType dispatchType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDispatchType_), dispatchType); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLComputePipeline.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLComputePipeline.hpp new file mode 100644 index 00000000..d200af75 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLComputePipeline.hpp @@ -0,0 +1,439 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLComputePipeline.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLAllocation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPipeline.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class ComputePipelineDescriptor; +class ComputePipelineReflection; +class ComputePipelineState; +class Device; +class Function; +class FunctionHandle; +class IntersectionFunctionTable; +class IntersectionFunctionTableDescriptor; +class LinkedFunctions; +class PipelineBufferDescriptorArray; +class StageInputOutputDescriptor; +class VisibleFunctionTable; +class VisibleFunctionTableDescriptor; + +} +namespace MTL4 +{ +class BinaryFunction; + +} +namespace MTL +{ +class ComputePipelineReflection : public NS::Referencing +{ +public: + static ComputePipelineReflection* alloc(); + + NS::Array* arguments() const; + + NS::Array* bindings() const; + + ComputePipelineReflection* init(); +}; +class ComputePipelineDescriptor : public NS::Copying +{ +public: + static ComputePipelineDescriptor* alloc(); + + NS::Array* binaryArchives() const; + + PipelineBufferDescriptorArray* buffers() const; + + Function* computeFunction() const; + + ComputePipelineDescriptor* init(); + + NS::Array* insertLibraries() const; + + NS::String* label() const; + + LinkedFunctions* linkedFunctions() const; + + NS::UInteger maxCallStackDepth() const; + + NS::UInteger maxTotalThreadsPerThreadgroup() const; + + NS::Array* preloadedLibraries() const; + + Size requiredThreadsPerThreadgroup() const; + + void reset(); + + void setBinaryArchives(const NS::Array* binaryArchives); + + void setComputeFunction(const MTL::Function* computeFunction); + + void setInsertLibraries(const NS::Array* insertLibraries); + + void setLabel(const NS::String* label); + + void setLinkedFunctions(const MTL::LinkedFunctions* linkedFunctions); + + void setMaxCallStackDepth(NS::UInteger maxCallStackDepth); + + void setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup); + + void setPreloadedLibraries(const NS::Array* preloadedLibraries); + + void setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup); + + void setShaderValidation(MTL::ShaderValidation shaderValidation); + + void setStageInputDescriptor(const MTL::StageInputOutputDescriptor* stageInputDescriptor); + + void setSupportAddingBinaryFunctions(bool supportAddingBinaryFunctions); + + void setSupportIndirectCommandBuffers(bool supportIndirectCommandBuffers); + + void setThreadGroupSizeIsMultipleOfThreadExecutionWidth(bool threadGroupSizeIsMultipleOfThreadExecutionWidth); + + ShaderValidation shaderValidation() const; + + StageInputOutputDescriptor* stageInputDescriptor() const; + + bool supportAddingBinaryFunctions() const; + + bool supportIndirectCommandBuffers() const; + + bool threadGroupSizeIsMultipleOfThreadExecutionWidth() const; +}; +class ComputePipelineState : public NS::Referencing +{ +public: + Device* device() const; + + FunctionHandle* functionHandle(const NS::String* name); + FunctionHandle* functionHandle(const MTL4::BinaryFunction* function); + FunctionHandle* functionHandle(const MTL::Function* function); + + ResourceID gpuResourceID() const; + + NS::UInteger imageblockMemoryLength(MTL::Size imageblockDimensions); + + NS::String* label() const; + + NS::UInteger maxTotalThreadsPerThreadgroup() const; + + ComputePipelineState* newComputePipelineStateWithBinaryFunctions(const NS::Array* additionalBinaryFunctions, NS::Error** error); + ComputePipelineState* newComputePipelineState(const NS::Array* functions, NS::Error** error); + + IntersectionFunctionTable* newIntersectionFunctionTable(const MTL::IntersectionFunctionTableDescriptor* descriptor); + + VisibleFunctionTable* newVisibleFunctionTable(const MTL::VisibleFunctionTableDescriptor* descriptor); + + ComputePipelineReflection* reflection() const; + + Size requiredThreadsPerThreadgroup() const; + + ShaderValidation shaderValidation() const; + + NS::UInteger staticThreadgroupMemoryLength() const; + + bool supportIndirectCommandBuffers() const; + + NS::UInteger threadExecutionWidth() const; +}; + +} +_MTL_INLINE MTL::ComputePipelineReflection* MTL::ComputePipelineReflection::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLComputePipelineReflection)); +} + +_MTL_INLINE NS::Array* MTL::ComputePipelineReflection::arguments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(arguments)); +} + +_MTL_INLINE NS::Array* MTL::ComputePipelineReflection::bindings() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bindings)); +} + +_MTL_INLINE MTL::ComputePipelineReflection* MTL::ComputePipelineReflection::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::ComputePipelineDescriptor* MTL::ComputePipelineDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLComputePipelineDescriptor)); +} + +_MTL_INLINE NS::Array* MTL::ComputePipelineDescriptor::binaryArchives() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(binaryArchives)); +} + +_MTL_INLINE MTL::PipelineBufferDescriptorArray* MTL::ComputePipelineDescriptor::buffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(buffers)); +} + +_MTL_INLINE MTL::Function* MTL::ComputePipelineDescriptor::computeFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(computeFunction)); +} + +_MTL_INLINE MTL::ComputePipelineDescriptor* MTL::ComputePipelineDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::Array* MTL::ComputePipelineDescriptor::insertLibraries() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(insertLibraries)); +} + +_MTL_INLINE NS::String* MTL::ComputePipelineDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL::LinkedFunctions* MTL::ComputePipelineDescriptor::linkedFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(linkedFunctions)); +} + +_MTL_INLINE NS::UInteger MTL::ComputePipelineDescriptor::maxCallStackDepth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxCallStackDepth)); +} + +_MTL_INLINE NS::UInteger MTL::ComputePipelineDescriptor::maxTotalThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerThreadgroup)); +} + +_MTL_INLINE NS::Array* MTL::ComputePipelineDescriptor::preloadedLibraries() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(preloadedLibraries)); +} + +_MTL_INLINE MTL::Size MTL::ComputePipelineDescriptor::requiredThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerThreadgroup)); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setBinaryArchives(const NS::Array* binaryArchives) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBinaryArchives_), binaryArchives); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setComputeFunction(const MTL::Function* computeFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setComputeFunction_), computeFunction); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setInsertLibraries(const NS::Array* insertLibraries) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInsertLibraries_), insertLibraries); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setLinkedFunctions(const MTL::LinkedFunctions* linkedFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLinkedFunctions_), linkedFunctions); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setMaxCallStackDepth(NS::UInteger maxCallStackDepth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxCallStackDepth_), maxCallStackDepth); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadsPerThreadgroup_), maxTotalThreadsPerThreadgroup); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setPreloadedLibraries(const NS::Array* preloadedLibraries) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPreloadedLibraries_), preloadedLibraries); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRequiredThreadsPerThreadgroup_), requiredThreadsPerThreadgroup); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setShaderValidation(MTL::ShaderValidation shaderValidation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setShaderValidation_), shaderValidation); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setStageInputDescriptor(const MTL::StageInputOutputDescriptor* stageInputDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStageInputDescriptor_), stageInputDescriptor); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setSupportAddingBinaryFunctions(bool supportAddingBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportAddingBinaryFunctions_), supportAddingBinaryFunctions); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setSupportIndirectCommandBuffers(bool supportIndirectCommandBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportIndirectCommandBuffers_), supportIndirectCommandBuffers); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setThreadGroupSizeIsMultipleOfThreadExecutionWidth(bool threadGroupSizeIsMultipleOfThreadExecutionWidth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadGroupSizeIsMultipleOfThreadExecutionWidth_), threadGroupSizeIsMultipleOfThreadExecutionWidth); +} + +_MTL_INLINE MTL::ShaderValidation MTL::ComputePipelineDescriptor::shaderValidation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderValidation)); +} + +_MTL_INLINE MTL::StageInputOutputDescriptor* MTL::ComputePipelineDescriptor::stageInputDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stageInputDescriptor)); +} + +_MTL_INLINE bool MTL::ComputePipelineDescriptor::supportAddingBinaryFunctions() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportAddingBinaryFunctions)); +} + +_MTL_INLINE bool MTL::ComputePipelineDescriptor::supportIndirectCommandBuffers() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportIndirectCommandBuffers)); +} + +_MTL_INLINE bool MTL::ComputePipelineDescriptor::threadGroupSizeIsMultipleOfThreadExecutionWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadGroupSizeIsMultipleOfThreadExecutionWidth)); +} + +_MTL_INLINE MTL::Device* MTL::ComputePipelineState::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE MTL::FunctionHandle* MTL::ComputePipelineState::functionHandle(const NS::String* name) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionHandleWithName_), name); +} + +_MTL_INLINE MTL::FunctionHandle* MTL::ComputePipelineState::functionHandle(const MTL4::BinaryFunction* function) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionHandleWithBinaryFunction_), function); +} + +_MTL_INLINE MTL::FunctionHandle* MTL::ComputePipelineState::functionHandle(const MTL::Function* function) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionHandleWithFunction_), function); +} + +_MTL_INLINE MTL::ResourceID MTL::ComputePipelineState::gpuResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuResourceID)); +} + +_MTL_INLINE NS::UInteger MTL::ComputePipelineState::imageblockMemoryLength(MTL::Size imageblockDimensions) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(imageblockMemoryLengthForDimensions_), imageblockDimensions); +} + +_MTL_INLINE NS::String* MTL::ComputePipelineState::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::UInteger MTL::ComputePipelineState::maxTotalThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerThreadgroup)); +} + +_MTL_INLINE MTL::ComputePipelineState* MTL::ComputePipelineState::newComputePipelineStateWithBinaryFunctions(const NS::Array* additionalBinaryFunctions, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithBinaryFunctions_error_), additionalBinaryFunctions, error); +} + +_MTL_INLINE MTL::ComputePipelineState* MTL::ComputePipelineState::newComputePipelineState(const NS::Array* functions, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithAdditionalBinaryFunctions_error_), functions, error); +} + +_MTL_INLINE MTL::IntersectionFunctionTable* MTL::ComputePipelineState::newIntersectionFunctionTable(const MTL::IntersectionFunctionTableDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newIntersectionFunctionTableWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::VisibleFunctionTable* MTL::ComputePipelineState::newVisibleFunctionTable(const MTL::VisibleFunctionTableDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newVisibleFunctionTableWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::ComputePipelineReflection* MTL::ComputePipelineState::reflection() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(reflection)); +} + +_MTL_INLINE MTL::Size MTL::ComputePipelineState::requiredThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerThreadgroup)); +} + +_MTL_INLINE MTL::ShaderValidation MTL::ComputePipelineState::shaderValidation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderValidation)); +} + +_MTL_INLINE NS::UInteger MTL::ComputePipelineState::staticThreadgroupMemoryLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(staticThreadgroupMemoryLength)); +} + +_MTL_INLINE bool MTL::ComputePipelineState::supportIndirectCommandBuffers() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportIndirectCommandBuffers)); +} + +_MTL_INLINE NS::UInteger MTL::ComputePipelineState::threadExecutionWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadExecutionWidth)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLCounters.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLCounters.hpp new file mode 100644 index 00000000..6d655f15 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLCounters.hpp @@ -0,0 +1,243 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLCounters.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" +#include + +namespace MTL +{ +class CounterSampleBufferDescriptor; +class CounterSet; +class Device; +_MTL_ENUM(NS::Integer, CounterSampleBufferError) { + CounterSampleBufferErrorOutOfMemory = 0, + CounterSampleBufferErrorInvalid = 1, + CounterSampleBufferErrorInternal = 2, +}; + +using CommonCounter = NS::String*; +using CommonCounterSet = NS::String*; + +static const NS::UInteger CounterErrorValue = static_cast(~0ULL); +static const NS::UInteger CounterDontSample = static_cast(-1); +_MTL_CONST(NS::ErrorDomain, CounterErrorDomain); +_MTL_CONST(CommonCounter, CommonCounterTimestamp); +_MTL_CONST(CommonCounter, CommonCounterTessellationInputPatches); +_MTL_CONST(CommonCounter, CommonCounterVertexInvocations); +_MTL_CONST(CommonCounter, CommonCounterPostTessellationVertexInvocations); +_MTL_CONST(CommonCounter, CommonCounterClipperInvocations); +_MTL_CONST(CommonCounter, CommonCounterClipperPrimitivesOut); +_MTL_CONST(CommonCounter, CommonCounterFragmentInvocations); +_MTL_CONST(CommonCounter, CommonCounterFragmentsPassed); +_MTL_CONST(CommonCounter, CommonCounterComputeKernelInvocations); +_MTL_CONST(CommonCounter, CommonCounterTotalCycles); +_MTL_CONST(CommonCounter, CommonCounterVertexCycles); +_MTL_CONST(CommonCounter, CommonCounterTessellationCycles); +_MTL_CONST(CommonCounter, CommonCounterPostTessellationVertexCycles); +_MTL_CONST(CommonCounter, CommonCounterFragmentCycles); +_MTL_CONST(CommonCounter, CommonCounterRenderTargetWriteCycles); +_MTL_CONST(CommonCounterSet, CommonCounterSetTimestamp); +_MTL_CONST(CommonCounterSet, CommonCounterSetStageUtilization); +_MTL_CONST(CommonCounterSet, CommonCounterSetStatistic); +struct CounterResultTimestamp +{ + uint64_t timestamp; +} _MTL_PACKED; + +struct CounterResultStageUtilization +{ + uint64_t totalCycles; + uint64_t vertexCycles; + uint64_t tessellationCycles; + uint64_t postTessellationVertexCycles; + uint64_t fragmentCycles; + uint64_t renderTargetCycles; +} _MTL_PACKED; + +struct CounterResultStatistic +{ + uint64_t tessellationInputPatches; + uint64_t vertexInvocations; + uint64_t postTessellationVertexInvocations; + uint64_t clipperInvocations; + uint64_t clipperPrimitivesOut; + uint64_t fragmentInvocations; + uint64_t fragmentsPassed; + uint64_t computeKernelInvocations; +} _MTL_PACKED; + +class Counter : public NS::Referencing +{ +public: + NS::String* name() const; +}; +class CounterSet : public NS::Referencing +{ +public: + NS::Array* counters() const; + + NS::String* name() const; +}; +class CounterSampleBufferDescriptor : public NS::Copying +{ +public: + static CounterSampleBufferDescriptor* alloc(); + + CounterSet* counterSet() const; + + CounterSampleBufferDescriptor* init(); + + NS::String* label() const; + + NS::UInteger sampleCount() const; + + void setCounterSet(const MTL::CounterSet* counterSet); + + void setLabel(const NS::String* label); + + void setSampleCount(NS::UInteger sampleCount); + + void setStorageMode(MTL::StorageMode storageMode); + StorageMode storageMode() const; +}; +class CounterSampleBuffer : public NS::Referencing +{ +public: + Device* device() const; + + NS::String* label() const; + + NS::Data* resolveCounterRange(NS::Range range); + + NS::UInteger sampleCount() const; +}; + +} + +_MTL_PRIVATE_DEF_CONST(NS::ErrorDomain, CounterErrorDomain); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterTimestamp); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterTessellationInputPatches); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterVertexInvocations); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterPostTessellationVertexInvocations); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterClipperInvocations); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterClipperPrimitivesOut); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterFragmentInvocations); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterFragmentsPassed); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterComputeKernelInvocations); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterTotalCycles); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterVertexCycles); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterTessellationCycles); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterPostTessellationVertexCycles); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterFragmentCycles); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterRenderTargetWriteCycles); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounterSet, CommonCounterSetTimestamp); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounterSet, CommonCounterSetStageUtilization); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounterSet, CommonCounterSetStatistic); + +_MTL_INLINE NS::String* MTL::Counter::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE NS::Array* MTL::CounterSet::counters() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(counters)); +} + +_MTL_INLINE NS::String* MTL::CounterSet::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE MTL::CounterSampleBufferDescriptor* MTL::CounterSampleBufferDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLCounterSampleBufferDescriptor)); +} + +_MTL_INLINE MTL::CounterSet* MTL::CounterSampleBufferDescriptor::counterSet() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(counterSet)); +} + +_MTL_INLINE MTL::CounterSampleBufferDescriptor* MTL::CounterSampleBufferDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::CounterSampleBufferDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::UInteger MTL::CounterSampleBufferDescriptor::sampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleCount)); +} + +_MTL_INLINE void MTL::CounterSampleBufferDescriptor::setCounterSet(const MTL::CounterSet* counterSet) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCounterSet_), counterSet); +} + +_MTL_INLINE void MTL::CounterSampleBufferDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::CounterSampleBufferDescriptor::setSampleCount(NS::UInteger sampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSampleCount_), sampleCount); +} + +_MTL_INLINE void MTL::CounterSampleBufferDescriptor::setStorageMode(MTL::StorageMode storageMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStorageMode_), storageMode); +} + +_MTL_INLINE MTL::StorageMode MTL::CounterSampleBufferDescriptor::storageMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(storageMode)); +} + +_MTL_INLINE MTL::Device* MTL::CounterSampleBuffer::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL::CounterSampleBuffer::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::Data* MTL::CounterSampleBuffer::resolveCounterRange(NS::Range range) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resolveCounterRange_), range); +} + +_MTL_INLINE NS::UInteger MTL::CounterSampleBuffer::sampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleCount)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLDataType.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLDataType.hpp new file mode 100644 index 00000000..f0e9b25a --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLDataType.hpp @@ -0,0 +1,129 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLDataType.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +_MTL_ENUM(NS::UInteger, DataType) { + DataTypeNone = 0, + DataTypeStruct = 1, + DataTypeArray = 2, + DataTypeFloat = 3, + DataTypeFloat2 = 4, + DataTypeFloat3 = 5, + DataTypeFloat4 = 6, + DataTypeFloat2x2 = 7, + DataTypeFloat2x3 = 8, + DataTypeFloat2x4 = 9, + DataTypeFloat3x2 = 10, + DataTypeFloat3x3 = 11, + DataTypeFloat3x4 = 12, + DataTypeFloat4x2 = 13, + DataTypeFloat4x3 = 14, + DataTypeFloat4x4 = 15, + DataTypeHalf = 16, + DataTypeHalf2 = 17, + DataTypeHalf3 = 18, + DataTypeHalf4 = 19, + DataTypeHalf2x2 = 20, + DataTypeHalf2x3 = 21, + DataTypeHalf2x4 = 22, + DataTypeHalf3x2 = 23, + DataTypeHalf3x3 = 24, + DataTypeHalf3x4 = 25, + DataTypeHalf4x2 = 26, + DataTypeHalf4x3 = 27, + DataTypeHalf4x4 = 28, + DataTypeInt = 29, + DataTypeInt2 = 30, + DataTypeInt3 = 31, + DataTypeInt4 = 32, + DataTypeUInt = 33, + DataTypeUInt2 = 34, + DataTypeUInt3 = 35, + DataTypeUInt4 = 36, + DataTypeShort = 37, + DataTypeShort2 = 38, + DataTypeShort3 = 39, + DataTypeShort4 = 40, + DataTypeUShort = 41, + DataTypeUShort2 = 42, + DataTypeUShort3 = 43, + DataTypeUShort4 = 44, + DataTypeChar = 45, + DataTypeChar2 = 46, + DataTypeChar3 = 47, + DataTypeChar4 = 48, + DataTypeUChar = 49, + DataTypeUChar2 = 50, + DataTypeUChar3 = 51, + DataTypeUChar4 = 52, + DataTypeBool = 53, + DataTypeBool2 = 54, + DataTypeBool3 = 55, + DataTypeBool4 = 56, + DataTypeTexture = 58, + DataTypeSampler = 59, + DataTypePointer = 60, + DataTypeR8Unorm = 62, + DataTypeR8Snorm = 63, + DataTypeR16Unorm = 64, + DataTypeR16Snorm = 65, + DataTypeRG8Unorm = 66, + DataTypeRG8Snorm = 67, + DataTypeRG16Unorm = 68, + DataTypeRG16Snorm = 69, + DataTypeRGBA8Unorm = 70, + DataTypeRGBA8Unorm_sRGB = 71, + DataTypeRGBA8Snorm = 72, + DataTypeRGBA16Unorm = 73, + DataTypeRGBA16Snorm = 74, + DataTypeRGB10A2Unorm = 75, + DataTypeRG11B10Float = 76, + DataTypeRGB9E5Float = 77, + DataTypeRenderPipeline = 78, + DataTypeComputePipeline = 79, + DataTypeIndirectCommandBuffer = 80, + DataTypeLong = 81, + DataTypeLong2 = 82, + DataTypeLong3 = 83, + DataTypeLong4 = 84, + DataTypeULong = 85, + DataTypeULong2 = 86, + DataTypeULong3 = 87, + DataTypeULong4 = 88, + DataTypeVisibleFunctionTable = 115, + DataTypeIntersectionFunctionTable = 116, + DataTypePrimitiveAccelerationStructure = 117, + DataTypeInstanceAccelerationStructure = 118, + DataTypeBFloat = 121, + DataTypeBFloat2 = 122, + DataTypeBFloat3 = 123, + DataTypeBFloat4 = 124, + DataTypeDepthStencilState = 139, + DataTypeTensor = 140, +}; + +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLDefines.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLDefines.hpp new file mode 100644 index 00000000..4260a2b1 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLDefines.hpp @@ -0,0 +1,41 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLDefines.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "../Foundation/NSDefines.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#define _MTL_EXPORT _NS_EXPORT +#define _MTL_EXTERN _NS_EXTERN +#define _MTL_INLINE _NS_INLINE +#define _MTL_PACKED _NS_PACKED + +#define _MTL_CONST(type, name) _NS_CONST(type, name) +#define _MTL_ENUM(type, name) _NS_ENUM(type, name) +#define _MTL_OPTIONS(type, name) _NS_OPTIONS(type, name) + +#define _MTL_VALIDATE_SIZE(ns, name) _NS_VALIDATE_SIZE(ns, name) +#define _MTL_VALIDATE_ENUM(ns, name) _NS_VALIDATE_ENUM(ns, name) + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLDepthStencil.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLDepthStencil.hpp new file mode 100644 index 00000000..e1116175 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLDepthStencil.hpp @@ -0,0 +1,277 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLDepthStencil.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL +{ +class DepthStencilDescriptor; +class Device; +class StencilDescriptor; +_MTL_ENUM(NS::UInteger, CompareFunction) { + CompareFunctionNever = 0, + CompareFunctionLess = 1, + CompareFunctionEqual = 2, + CompareFunctionLessEqual = 3, + CompareFunctionGreater = 4, + CompareFunctionNotEqual = 5, + CompareFunctionGreaterEqual = 6, + CompareFunctionAlways = 7, +}; + +_MTL_ENUM(NS::UInteger, StencilOperation) { + StencilOperationKeep = 0, + StencilOperationZero = 1, + StencilOperationReplace = 2, + StencilOperationIncrementClamp = 3, + StencilOperationDecrementClamp = 4, + StencilOperationInvert = 5, + StencilOperationIncrementWrap = 6, + StencilOperationDecrementWrap = 7, +}; + +class StencilDescriptor : public NS::Copying +{ +public: + static StencilDescriptor* alloc(); + + StencilOperation depthFailureOperation() const; + + StencilOperation depthStencilPassOperation() const; + + StencilDescriptor* init(); + + uint32_t readMask() const; + + void setDepthFailureOperation(MTL::StencilOperation depthFailureOperation); + + void setDepthStencilPassOperation(MTL::StencilOperation depthStencilPassOperation); + + void setReadMask(uint32_t readMask); + + void setStencilCompareFunction(MTL::CompareFunction stencilCompareFunction); + + void setStencilFailureOperation(MTL::StencilOperation stencilFailureOperation); + + void setWriteMask(uint32_t writeMask); + + CompareFunction stencilCompareFunction() const; + + StencilOperation stencilFailureOperation() const; + + uint32_t writeMask() const; +}; +class DepthStencilDescriptor : public NS::Copying +{ +public: + static DepthStencilDescriptor* alloc(); + + StencilDescriptor* backFaceStencil() const; + + CompareFunction depthCompareFunction() const; + + [[deprecated("please use isDepthWriteEnabled instead")]] + bool depthWriteEnabled() const; + + StencilDescriptor* frontFaceStencil() const; + + DepthStencilDescriptor* init(); + + bool isDepthWriteEnabled() const; + + NS::String* label() const; + + void setBackFaceStencil(const MTL::StencilDescriptor* backFaceStencil); + + void setDepthCompareFunction(MTL::CompareFunction depthCompareFunction); + + void setDepthWriteEnabled(bool depthWriteEnabled); + + void setFrontFaceStencil(const MTL::StencilDescriptor* frontFaceStencil); + + void setLabel(const NS::String* label); +}; +class DepthStencilState : public NS::Referencing +{ +public: + Device* device() const; + + ResourceID gpuResourceID() const; + + NS::String* label() const; +}; + +} +_MTL_INLINE MTL::StencilDescriptor* MTL::StencilDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLStencilDescriptor)); +} + +_MTL_INLINE MTL::StencilOperation MTL::StencilDescriptor::depthFailureOperation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(depthFailureOperation)); +} + +_MTL_INLINE MTL::StencilOperation MTL::StencilDescriptor::depthStencilPassOperation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(depthStencilPassOperation)); +} + +_MTL_INLINE MTL::StencilDescriptor* MTL::StencilDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE uint32_t MTL::StencilDescriptor::readMask() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(readMask)); +} + +_MTL_INLINE void MTL::StencilDescriptor::setDepthFailureOperation(MTL::StencilOperation depthFailureOperation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthFailureOperation_), depthFailureOperation); +} + +_MTL_INLINE void MTL::StencilDescriptor::setDepthStencilPassOperation(MTL::StencilOperation depthStencilPassOperation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStencilPassOperation_), depthStencilPassOperation); +} + +_MTL_INLINE void MTL::StencilDescriptor::setReadMask(uint32_t readMask) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setReadMask_), readMask); +} + +_MTL_INLINE void MTL::StencilDescriptor::setStencilCompareFunction(MTL::CompareFunction stencilCompareFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilCompareFunction_), stencilCompareFunction); +} + +_MTL_INLINE void MTL::StencilDescriptor::setStencilFailureOperation(MTL::StencilOperation stencilFailureOperation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilFailureOperation_), stencilFailureOperation); +} + +_MTL_INLINE void MTL::StencilDescriptor::setWriteMask(uint32_t writeMask) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setWriteMask_), writeMask); +} + +_MTL_INLINE MTL::CompareFunction MTL::StencilDescriptor::stencilCompareFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stencilCompareFunction)); +} + +_MTL_INLINE MTL::StencilOperation MTL::StencilDescriptor::stencilFailureOperation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stencilFailureOperation)); +} + +_MTL_INLINE uint32_t MTL::StencilDescriptor::writeMask() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(writeMask)); +} + +_MTL_INLINE MTL::DepthStencilDescriptor* MTL::DepthStencilDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLDepthStencilDescriptor)); +} + +_MTL_INLINE MTL::StencilDescriptor* MTL::DepthStencilDescriptor::backFaceStencil() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(backFaceStencil)); +} + +_MTL_INLINE MTL::CompareFunction MTL::DepthStencilDescriptor::depthCompareFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(depthCompareFunction)); +} + +_MTL_INLINE bool MTL::DepthStencilDescriptor::depthWriteEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isDepthWriteEnabled)); +} + +_MTL_INLINE MTL::StencilDescriptor* MTL::DepthStencilDescriptor::frontFaceStencil() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(frontFaceStencil)); +} + +_MTL_INLINE MTL::DepthStencilDescriptor* MTL::DepthStencilDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE bool MTL::DepthStencilDescriptor::isDepthWriteEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isDepthWriteEnabled)); +} + +_MTL_INLINE NS::String* MTL::DepthStencilDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::DepthStencilDescriptor::setBackFaceStencil(const MTL::StencilDescriptor* backFaceStencil) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBackFaceStencil_), backFaceStencil); +} + +_MTL_INLINE void MTL::DepthStencilDescriptor::setDepthCompareFunction(MTL::CompareFunction depthCompareFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthCompareFunction_), depthCompareFunction); +} + +_MTL_INLINE void MTL::DepthStencilDescriptor::setDepthWriteEnabled(bool depthWriteEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthWriteEnabled_), depthWriteEnabled); +} + +_MTL_INLINE void MTL::DepthStencilDescriptor::setFrontFaceStencil(const MTL::StencilDescriptor* frontFaceStencil) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFrontFaceStencil_), frontFaceStencil); +} + +_MTL_INLINE void MTL::DepthStencilDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE MTL::Device* MTL::DepthStencilState::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE MTL::ResourceID MTL::DepthStencilState::gpuResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuResourceID)); +} + +_MTL_INLINE NS::String* MTL::DepthStencilState::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLDevice.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLDevice.hpp new file mode 100644 index 00000000..0e867397 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLDevice.hpp @@ -0,0 +1,1493 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLDevice.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4Counters.hpp" +#include "MTLArgument.hpp" +#include "MTLDataType.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPixelFormat.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" +#include "MTLTexture.hpp" +#include "MTLTypes.hpp" +#include +#include +#include + +#include +#include +#include + +namespace MTL +{ +class AccelerationStructure; +class AccelerationStructureDescriptor; +class Architecture; +class ArgumentDescriptor; +class ArgumentEncoder; +class BinaryArchive; +class BinaryArchiveDescriptor; +class Buffer; +class BufferBinding; +class CommandQueue; +class CommandQueueDescriptor; +class CompileOptions; +class ComputePipelineDescriptor; +class ComputePipelineReflection; +class ComputePipelineState; +class CounterSampleBuffer; +class CounterSampleBufferDescriptor; +class DepthStencilDescriptor; +class DepthStencilState; +class Device; +class DynamicLibrary; +class Event; +class Fence; +class Function; +class FunctionConstantValues; +class FunctionHandle; +class Heap; +class HeapDescriptor; +class IOCommandQueue; +class IOCommandQueueDescriptor; +class IOFileHandle; +class IndirectCommandBuffer; +class IndirectCommandBufferDescriptor; +class Library; +class LogState; +class LogStateDescriptor; +class MeshRenderPipelineDescriptor; +class RasterizationRateMap; +class RasterizationRateMapDescriptor; +struct Region; +class RenderPipelineDescriptor; +class RenderPipelineReflection; +class RenderPipelineState; +class ResidencySet; +class ResidencySetDescriptor; +class ResourceViewPoolDescriptor; +struct SamplePosition; +class SamplerDescriptor; +class SamplerState; +class SharedEvent; +class SharedEventHandle; +class SharedTextureHandle; +class StitchedLibraryDescriptor; +class Tensor; +class TensorDescriptor; +class Texture; +class TextureDescriptor; +class TextureViewPool; +class TileRenderPipelineDescriptor; + +} +namespace MTL4 +{ +class Archive; +class ArgumentTable; +class ArgumentTableDescriptor; +class BinaryFunction; +class CommandAllocator; +class CommandAllocatorDescriptor; +class CommandBuffer; +class CommandQueue; +class CommandQueueDescriptor; +class Compiler; +class CompilerDescriptor; +class CounterHeap; +class CounterHeapDescriptor; +class PipelineDataSetSerializer; +class PipelineDataSetSerializerDescriptor; + +} +namespace MTL +{ +_MTL_ENUM(NS::Integer, IOCompressionMethod) { + IOCompressionMethodZlib = 0, + IOCompressionMethodLZFSE = 1, + IOCompressionMethodLZ4 = 2, + IOCompressionMethodLZMA = 3, + IOCompressionMethodLZBitmap = 4, +}; + +_MTL_ENUM(NS::UInteger, FeatureSet) { + FeatureSet_iOS_GPUFamily1_v1 = 0, + FeatureSet_iOS_GPUFamily2_v1 = 1, + FeatureSet_iOS_GPUFamily1_v2 = 2, + FeatureSet_iOS_GPUFamily2_v2 = 3, + FeatureSet_iOS_GPUFamily3_v1 = 4, + FeatureSet_iOS_GPUFamily1_v3 = 5, + FeatureSet_iOS_GPUFamily2_v3 = 6, + FeatureSet_iOS_GPUFamily3_v2 = 7, + FeatureSet_iOS_GPUFamily1_v4 = 8, + FeatureSet_iOS_GPUFamily2_v4 = 9, + FeatureSet_iOS_GPUFamily3_v3 = 10, + FeatureSet_iOS_GPUFamily4_v1 = 11, + FeatureSet_iOS_GPUFamily1_v5 = 12, + FeatureSet_iOS_GPUFamily2_v5 = 13, + FeatureSet_iOS_GPUFamily3_v4 = 14, + FeatureSet_iOS_GPUFamily4_v2 = 15, + FeatureSet_iOS_GPUFamily5_v1 = 16, + FeatureSet_macOS_GPUFamily1_v1 = 10000, + FeatureSet_OSX_GPUFamily1_v1 = 10000, + FeatureSet_macOS_GPUFamily1_v2 = 10001, + FeatureSet_OSX_GPUFamily1_v2 = 10001, + FeatureSet_macOS_ReadWriteTextureTier2 = 10002, + FeatureSet_OSX_ReadWriteTextureTier2 = 10002, + FeatureSet_macOS_GPUFamily1_v3 = 10003, + FeatureSet_macOS_GPUFamily1_v4 = 10004, + FeatureSet_macOS_GPUFamily2_v1 = 10005, + FeatureSet_watchOS_GPUFamily1_v1 = 20000, + FeatureSet_WatchOS_GPUFamily1_v1 = 20000, + FeatureSet_watchOS_GPUFamily2_v1 = 20001, + FeatureSet_WatchOS_GPUFamily2_v1 = 20001, + FeatureSet_tvOS_GPUFamily1_v1 = 30000, + FeatureSet_TVOS_GPUFamily1_v1 = 30000, + FeatureSet_tvOS_GPUFamily1_v2 = 30001, + FeatureSet_tvOS_GPUFamily1_v3 = 30002, + FeatureSet_tvOS_GPUFamily2_v1 = 30003, + FeatureSet_tvOS_GPUFamily1_v4 = 30004, + FeatureSet_tvOS_GPUFamily2_v2 = 30005, +}; + +_MTL_ENUM(NS::Integer, GPUFamily) { + GPUFamilyApple1 = 1001, + GPUFamilyApple2 = 1002, + GPUFamilyApple3 = 1003, + GPUFamilyApple4 = 1004, + GPUFamilyApple5 = 1005, + GPUFamilyApple6 = 1006, + GPUFamilyApple7 = 1007, + GPUFamilyApple8 = 1008, + GPUFamilyApple9 = 1009, + GPUFamilyApple10 = 1010, + GPUFamilyMac1 = 2001, + GPUFamilyMac2 = 2002, + GPUFamilyCommon1 = 3001, + GPUFamilyCommon2 = 3002, + GPUFamilyCommon3 = 3003, + GPUFamilyMacCatalyst1 = 4001, + GPUFamilyMacCatalyst2 = 4002, + GPUFamilyMetal3 = 5001, + GPUFamilyMetal4 = 5002, +}; + +_MTL_ENUM(NS::UInteger, DeviceLocation) { + DeviceLocationBuiltIn = 0, + DeviceLocationSlot = 1, + DeviceLocationExternal = 2, + DeviceLocationUnspecified = NS::UIntegerMax, +}; + +_MTL_ENUM(NS::UInteger, ReadWriteTextureTier) { + ReadWriteTextureTierNone = 0, + ReadWriteTextureTier1 = 1, + ReadWriteTextureTier2 = 2, +}; + +_MTL_ENUM(NS::UInteger, ArgumentBuffersTier) { + ArgumentBuffersTier1 = 0, + ArgumentBuffersTier2 = 1, +}; + +_MTL_ENUM(NS::UInteger, SparseTextureRegionAlignmentMode) { + SparseTextureRegionAlignmentModeOutward = 0, + SparseTextureRegionAlignmentModeInward = 1, +}; + +_MTL_ENUM(NS::UInteger, CounterSamplingPoint) { + CounterSamplingPointAtStageBoundary = 0, + CounterSamplingPointAtDrawBoundary = 1, + CounterSamplingPointAtDispatchBoundary = 2, + CounterSamplingPointAtTileDispatchBoundary = 3, + CounterSamplingPointAtBlitBoundary = 4, +}; + +_MTL_OPTIONS(NS::UInteger, PipelineOption) { + PipelineOptionNone = 0, + PipelineOptionArgumentInfo = 1, + PipelineOptionBindingInfo = 1, + PipelineOptionBufferTypeInfo = 1 << 1, + PipelineOptionFailOnBinaryArchiveMiss = 1 << 2, +}; + +using DeviceNotificationName = NS::String*; +using DeviceNotificationHandlerBlock = void (^)(MTL::Device* pDevice, MTL::DeviceNotificationName notifyName); +using DeviceNotificationHandlerFunction = std::function; +using AutoreleasedComputePipelineReflection = MTL::ComputePipelineReflection*; +using AutoreleasedRenderPipelineReflection = MTL::RenderPipelineReflection*; +using NewLibraryCompletionHandler = void (^)(MTL::Library*, NS::Error*); +using NewLibraryCompletionHandlerFunction = std::function; +using NewRenderPipelineStateCompletionHandler = void (^)(MTL::RenderPipelineState*, NS::Error*); +using NewRenderPipelineStateCompletionHandlerFunction = std::function; +using NewRenderPipelineStateWithReflectionCompletionHandler = void (^)(MTL::RenderPipelineState*, MTL::RenderPipelineReflection*, NS::Error*); +using NewRenderPipelineStateWithReflectionCompletionHandlerFunction = std::function; +using NewComputePipelineStateCompletionHandler = void (^)(MTL::ComputePipelineState*, NS::Error*); +using NewComputePipelineStateCompletionHandlerFunction = std::function; +using NewComputePipelineStateWithReflectionCompletionHandler = void (^)(MTL::ComputePipelineState*, MTL::ComputePipelineReflection*, NS::Error*); +using NewComputePipelineStateWithReflectionCompletionHandlerFunction = std::function; +using Timestamp = std::uint64_t; + +_MTL_CONST(DeviceNotificationName, DeviceWasAddedNotification); +_MTL_CONST(DeviceNotificationName, DeviceRemovalRequestedNotification); +_MTL_CONST(DeviceNotificationName, DeviceWasRemovedNotification); +_MTL_CONST(NS::ErrorUserInfoKey, CommandBufferEncoderInfoErrorKey); +Device* CreateSystemDefaultDevice(); +NS::Array* CopyAllDevices(); +NS::Array* CopyAllDevicesWithObserver(NS::Object** pOutObserver, MTL::DeviceNotificationHandlerBlock handler); +NS::Array* CopyAllDevicesWithObserver(NS::Object** pOutObserver, const MTL::DeviceNotificationHandlerFunction& handler); +void RemoveDeviceObserver(const NS::Object* pObserver); +struct AccelerationStructureSizes +{ + NS::UInteger accelerationStructureSize; + NS::UInteger buildScratchBufferSize; + NS::UInteger refitScratchBufferSize; +} _MTL_PACKED; + +struct SizeAndAlign +{ + NS::UInteger size; + NS::UInteger align; +} _MTL_PACKED; + +class ArgumentDescriptor : public NS::Copying +{ +public: + BindingAccess access() const; + + static ArgumentDescriptor* alloc(); + + static ArgumentDescriptor* argumentDescriptor(); + + NS::UInteger arrayLength() const; + + NS::UInteger constantBlockAlignment() const; + + DataType dataType() const; + + NS::UInteger index() const; + + ArgumentDescriptor* init(); + + void setAccess(MTL::BindingAccess access); + + void setArrayLength(NS::UInteger arrayLength); + + void setConstantBlockAlignment(NS::UInteger constantBlockAlignment); + + void setDataType(MTL::DataType dataType); + + void setIndex(NS::UInteger index); + + void setTextureType(MTL::TextureType textureType); + TextureType textureType() const; +}; +class Architecture : public NS::Copying +{ +public: + static Architecture* alloc(); + + Architecture* init(); + + NS::String* name() const; +}; +class Device : public NS::Referencing +{ +public: + AccelerationStructureSizes accelerationStructureSizes(const MTL::AccelerationStructureDescriptor* descriptor); + + Architecture* architecture() const; + + bool areBarycentricCoordsSupported() const; + + bool areProgrammableSamplePositionsSupported() const; + + bool areRasterOrderGroupsSupported() const; + + ArgumentBuffersTier argumentBuffersSupport() const; + + [[deprecated("please use areBarycentricCoordsSupported instead")]] + bool barycentricCoordsSupported() const; + + void convertSparsePixelRegions(const MTL::Region* pixelRegions, MTL::Region* tileRegions, MTL::Size tileSize, MTL::SparseTextureRegionAlignmentMode mode, NS::UInteger numRegions); + + void convertSparseTileRegions(const MTL::Region* tileRegions, MTL::Region* pixelRegions, MTL::Size tileSize, NS::UInteger numRegions); + + NS::Array* counterSets() const; + + NS::UInteger currentAllocatedSize() const; + + [[deprecated("please use isDepth24Stencil8PixelFormatSupported instead")]] + bool depth24Stencil8PixelFormatSupported() const; + + FunctionHandle* functionHandle(const MTL::Function* function); + FunctionHandle* functionHandle(const MTL4::BinaryFunction* function); + + void getDefaultSamplePositions(MTL::SamplePosition* positions, NS::UInteger count); + + bool hasUnifiedMemory() const; + + [[deprecated("please use isHeadless instead")]] + bool headless() const; + + SizeAndAlign heapAccelerationStructureSizeAndAlign(NS::UInteger size); + SizeAndAlign heapAccelerationStructureSizeAndAlign(const MTL::AccelerationStructureDescriptor* descriptor); + + SizeAndAlign heapBufferSizeAndAlign(NS::UInteger length, MTL::ResourceOptions options); + + SizeAndAlign heapTextureSizeAndAlign(const MTL::TextureDescriptor* desc); + + bool isDepth24Stencil8PixelFormatSupported() const; + + bool isHeadless() const; + + bool isLowPower() const; + + bool isRemovable() const; + + DeviceLocation location() const; + NS::UInteger locationNumber() const; + + [[deprecated("please use isLowPower instead")]] + bool lowPower() const; + + NS::UInteger maxArgumentBufferSamplerCount() const; + + NS::UInteger maxBufferLength() const; + + NS::UInteger maxThreadgroupMemoryLength() const; + + Size maxThreadsPerThreadgroup() const; + + uint64_t maxTransferRate() const; + + NS::UInteger maximumConcurrentCompilationTaskCount() const; + + NS::UInteger minimumLinearTextureAlignmentForPixelFormat(MTL::PixelFormat format); + + NS::UInteger minimumTextureBufferAlignmentForPixelFormat(MTL::PixelFormat format); + + NS::String* name() const; + + AccelerationStructure* newAccelerationStructure(NS::UInteger size); + AccelerationStructure* newAccelerationStructure(const MTL::AccelerationStructureDescriptor* descriptor); + + MTL4::Archive* newArchive(const NS::URL* url, NS::Error** error); + + ArgumentEncoder* newArgumentEncoder(const NS::Array* arguments); + ArgumentEncoder* newArgumentEncoder(const MTL::BufferBinding* bufferBinding); + + MTL4::ArgumentTable* newArgumentTable(const MTL4::ArgumentTableDescriptor* descriptor, NS::Error** error); + + BinaryArchive* newBinaryArchive(const MTL::BinaryArchiveDescriptor* descriptor, NS::Error** error); + + Buffer* newBuffer(NS::UInteger length, MTL::ResourceOptions options); + Buffer* newBuffer(const void* pointer, NS::UInteger length, MTL::ResourceOptions options); + Buffer* newBuffer(const void* pointer, NS::UInteger length, MTL::ResourceOptions options, void (^deallocator)(void*, NS::UInteger)); + Buffer* newBuffer(NS::UInteger length, MTL::ResourceOptions options, MTL::SparsePageSize placementSparsePageSize); + + MTL4::CommandAllocator* newCommandAllocator(); + MTL4::CommandAllocator* newCommandAllocator(const MTL4::CommandAllocatorDescriptor* descriptor, NS::Error** error); + + MTL4::CommandBuffer* newCommandBuffer(); + + CommandQueue* newCommandQueue(); + CommandQueue* newCommandQueue(NS::UInteger maxCommandBufferCount); + CommandQueue* newCommandQueue(const MTL::CommandQueueDescriptor* descriptor); + + MTL4::Compiler* newCompiler(const MTL4::CompilerDescriptor* descriptor, NS::Error** error); + + ComputePipelineState* newComputePipelineState(const MTL::Function* computeFunction, NS::Error** error); + ComputePipelineState* newComputePipelineState(const MTL::Function* computeFunction, MTL::PipelineOption options, const MTL::AutoreleasedComputePipelineReflection* reflection, NS::Error** error); + void newComputePipelineState(const MTL::Function* computeFunction, const MTL::NewComputePipelineStateCompletionHandler completionHandler); + void newComputePipelineState(const MTL::Function* computeFunction, MTL::PipelineOption options, const MTL::NewComputePipelineStateWithReflectionCompletionHandler completionHandler); + ComputePipelineState* newComputePipelineState(const MTL::ComputePipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::AutoreleasedComputePipelineReflection* reflection, NS::Error** error); + void newComputePipelineState(const MTL::ComputePipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::NewComputePipelineStateWithReflectionCompletionHandler completionHandler); + void newComputePipelineState(const MTL::Function* pFunction, const MTL::NewComputePipelineStateCompletionHandlerFunction& completionHandler); + void newComputePipelineState(const MTL::Function* pFunction, MTL::PipelineOption options, const MTL::NewComputePipelineStateWithReflectionCompletionHandlerFunction& completionHandler); + void newComputePipelineState(const MTL::ComputePipelineDescriptor* pDescriptor, MTL::PipelineOption options, const MTL::NewComputePipelineStateWithReflectionCompletionHandlerFunction& completionHandler); + + MTL4::CounterHeap* newCounterHeap(const MTL4::CounterHeapDescriptor* descriptor, NS::Error** error); + + CounterSampleBuffer* newCounterSampleBuffer(const MTL::CounterSampleBufferDescriptor* descriptor, NS::Error** error); + + Library* newDefaultLibrary(); + Library* newDefaultLibrary(const NS::Bundle* bundle, NS::Error** error); + + DepthStencilState* newDepthStencilState(const MTL::DepthStencilDescriptor* descriptor); + + DynamicLibrary* newDynamicLibrary(const MTL::Library* library, NS::Error** error); + DynamicLibrary* newDynamicLibrary(const NS::URL* url, NS::Error** error); + + Event* newEvent(); + + Fence* newFence(); + + Heap* newHeap(const MTL::HeapDescriptor* descriptor); + + IOCommandQueue* newIOCommandQueue(const MTL::IOCommandQueueDescriptor* descriptor, NS::Error** error); + + IOFileHandle* newIOFileHandle(const NS::URL* url, NS::Error** error); + IOFileHandle* newIOFileHandle(const NS::URL* url, MTL::IOCompressionMethod compressionMethod, NS::Error** error); + + IOFileHandle* newIOHandle(const NS::URL* url, NS::Error** error); + IOFileHandle* newIOHandle(const NS::URL* url, MTL::IOCompressionMethod compressionMethod, NS::Error** error); + + IndirectCommandBuffer* newIndirectCommandBuffer(const MTL::IndirectCommandBufferDescriptor* descriptor, NS::UInteger maxCount, MTL::ResourceOptions options); + + Library* newLibrary(const NS::String* filepath, NS::Error** error); + Library* newLibrary(const NS::URL* url, NS::Error** error); + Library* newLibrary(const dispatch_data_t data, NS::Error** error); + Library* newLibrary(const NS::String* source, const MTL::CompileOptions* options, NS::Error** error); + void newLibrary(const NS::String* source, const MTL::CompileOptions* options, const MTL::NewLibraryCompletionHandler completionHandler); + Library* newLibrary(const MTL::StitchedLibraryDescriptor* descriptor, NS::Error** error); + void newLibrary(const MTL::StitchedLibraryDescriptor* descriptor, const MTL::NewLibraryCompletionHandler completionHandler); + void newLibrary(const NS::String* pSource, const MTL::CompileOptions* pOptions, const MTL::NewLibraryCompletionHandlerFunction& completionHandler); + void newLibrary(const MTL::StitchedLibraryDescriptor* pDescriptor, const MTL::NewLibraryCompletionHandlerFunction& completionHandler); + + LogState* newLogState(const MTL::LogStateDescriptor* descriptor, NS::Error** error); + + MTL4::CommandQueue* newMTL4CommandQueue(); + MTL4::CommandQueue* newMTL4CommandQueue(const MTL4::CommandQueueDescriptor* descriptor, NS::Error** error); + + MTL4::PipelineDataSetSerializer* newPipelineDataSetSerializer(const MTL4::PipelineDataSetSerializerDescriptor* descriptor); + + RasterizationRateMap* newRasterizationRateMap(const MTL::RasterizationRateMapDescriptor* descriptor); + + RenderPipelineState* newRenderPipelineState(const MTL::RenderPipelineDescriptor* descriptor, NS::Error** error); + RenderPipelineState* newRenderPipelineState(const MTL::RenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::AutoreleasedRenderPipelineReflection* reflection, NS::Error** error); + void newRenderPipelineState(const MTL::RenderPipelineDescriptor* descriptor, const MTL::NewRenderPipelineStateCompletionHandler completionHandler); + void newRenderPipelineState(const MTL::RenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::NewRenderPipelineStateWithReflectionCompletionHandler completionHandler); + RenderPipelineState* newRenderPipelineState(const MTL::TileRenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::AutoreleasedRenderPipelineReflection* reflection, NS::Error** error); + void newRenderPipelineState(const MTL::TileRenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::NewRenderPipelineStateWithReflectionCompletionHandler completionHandler); + RenderPipelineState* newRenderPipelineState(const MTL::MeshRenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::AutoreleasedRenderPipelineReflection* reflection, NS::Error** error); + void newRenderPipelineState(const MTL::MeshRenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::NewRenderPipelineStateWithReflectionCompletionHandler completionHandler); + void newRenderPipelineState(const MTL::RenderPipelineDescriptor* pDescriptor, const MTL::NewRenderPipelineStateCompletionHandlerFunction& completionHandler); + void newRenderPipelineState(const MTL::RenderPipelineDescriptor* pDescriptor, MTL::PipelineOption options, const MTL::NewRenderPipelineStateWithReflectionCompletionHandlerFunction& completionHandler); + void newRenderPipelineState(const MTL::TileRenderPipelineDescriptor* pDescriptor, MTL::PipelineOption options, const MTL::NewRenderPipelineStateWithReflectionCompletionHandlerFunction& completionHandler); + + ResidencySet* newResidencySet(const MTL::ResidencySetDescriptor* desc, NS::Error** error); + + SamplerState* newSamplerState(const MTL::SamplerDescriptor* descriptor); + + SharedEvent* newSharedEvent(); + SharedEvent* newSharedEvent(const MTL::SharedEventHandle* sharedEventHandle); + + Texture* newSharedTexture(const MTL::TextureDescriptor* descriptor); + Texture* newSharedTexture(const MTL::SharedTextureHandle* sharedHandle); + + Tensor* newTensor(const MTL::TensorDescriptor* descriptor, NS::Error** error); + + Texture* newTexture(const MTL::TextureDescriptor* descriptor); + Texture* newTexture(const MTL::TextureDescriptor* descriptor, const IOSurfaceRef iosurface, NS::UInteger plane); + TextureViewPool* newTextureViewPool(const MTL::ResourceViewPoolDescriptor* descriptor, NS::Error** error); + + uint32_t peerCount() const; + + uint64_t peerGroupID() const; + + uint32_t peerIndex() const; + + [[deprecated("please use areProgrammableSamplePositionsSupported instead")]] + bool programmableSamplePositionsSupported() const; + + uint64_t queryTimestampFrequency(); + + [[deprecated("please use areRasterOrderGroupsSupported instead")]] + bool rasterOrderGroupsSupported() const; + + ReadWriteTextureTier readWriteTextureSupport() const; + + uint64_t recommendedMaxWorkingSetSize() const; + + uint64_t registryID() const; + + [[deprecated("please use isRemovable instead")]] + bool removable() const; + + void sampleTimestamps(MTL::Timestamp* cpuTimestamp, MTL::Timestamp* gpuTimestamp); + + void setShouldMaximizeConcurrentCompilation(bool shouldMaximizeConcurrentCompilation); + bool shouldMaximizeConcurrentCompilation() const; + + NS::UInteger sizeOfCounterHeapEntry(MTL4::CounterHeapType type); + + Size sparseTileSize(MTL::TextureType textureType, MTL::PixelFormat pixelFormat, NS::UInteger sampleCount); + Size sparseTileSize(MTL::TextureType textureType, MTL::PixelFormat pixelFormat, NS::UInteger sampleCount, MTL::SparsePageSize sparsePageSize); + NS::UInteger sparseTileSizeInBytes() const; + NS::UInteger sparseTileSizeInBytes(MTL::SparsePageSize sparsePageSize); + + bool supports32BitFloatFiltering() const; + + bool supports32BitMSAA() const; + + bool supportsBCTextureCompression() const; + + bool supportsCounterSampling(MTL::CounterSamplingPoint samplingPoint); + + bool supportsDynamicLibraries() const; + + bool supportsFamily(MTL::GPUFamily gpuFamily); + + bool supportsFeatureSet(MTL::FeatureSet featureSet); + + bool supportsFunctionPointers() const; + bool supportsFunctionPointersFromRender() const; + + bool supportsPrimitiveMotionBlur() const; + + bool supportsPullModelInterpolation() const; + + bool supportsQueryTextureLOD() const; + + bool supportsRasterizationRateMap(NS::UInteger layerCount); + + bool supportsRaytracing() const; + bool supportsRaytracingFromRender() const; + + bool supportsRenderDynamicLibraries() const; + + bool supportsShaderBarycentricCoordinates() const; + + bool supportsTextureSampleCount(NS::UInteger sampleCount); + + bool supportsVertexAmplificationCount(NS::UInteger count); + + SizeAndAlign tensorSizeAndAlign(const MTL::TensorDescriptor* descriptor); +}; + +} + +#if defined(MTL_PRIVATE_IMPLEMENTATION) +extern "C" MTL::Device* MTLCreateSystemDefaultDevice(); +extern "C" NS::Array* MTLCopyAllDevices(); +extern "C" NS::Array* MTLCopyAllDevicesWithObserver(NS::Object**, MTL::DeviceNotificationHandlerBlock); +extern "C" void MTLRemoveDeviceObserver(const NS::Object*); +_MTL_PRIVATE_DEF_WEAK_CONST(MTL::DeviceNotificationName, DeviceWasAddedNotification); +_MTL_PRIVATE_DEF_WEAK_CONST(MTL::DeviceNotificationName, DeviceRemovalRequestedNotification); +_MTL_PRIVATE_DEF_WEAK_CONST(MTL::DeviceNotificationName, DeviceWasRemovedNotification); +_MTL_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, CommandBufferEncoderInfoErrorKey); +_NS_EXPORT MTL::Device* MTL::CreateSystemDefaultDevice() +{ + return ::MTLCreateSystemDefaultDevice(); +} + +_NS_EXPORT NS::Array* MTL::CopyAllDevices() +{ +#if (__IPHONE_OS_VERSION_MIN_REQUIRED >= 180000) || (__MAC_OS_X_VERSION_MIN_REQUIRED >= 101100) + return ::MTLCopyAllDevices(); +#else + return nullptr; +#endif +} + +_NS_EXPORT NS::Array* MTL::CopyAllDevicesWithObserver(NS::Object** pOutObserver, MTL::DeviceNotificationHandlerBlock handler) +{ +#if TARGET_OS_OSX + return ::MTLCopyAllDevicesWithObserver(pOutObserver, handler); +#else + (void)pOutObserver; + (void)handler; + return nullptr; +#endif // TARGET_OS_OSX +} + +_NS_EXPORT NS::Array* MTL::CopyAllDevicesWithObserver(NS::Object** pOutObserver, const MTL::DeviceNotificationHandlerFunction& handler) +{ + __block DeviceNotificationHandlerFunction function = handler; + return CopyAllDevicesWithObserver(pOutObserver, ^(Device* pDevice, DeviceNotificationName pNotificationName) { function(pDevice, pNotificationName); }); +} + +_NS_EXPORT void MTL::RemoveDeviceObserver(const NS::Object* pObserver) +{ + (void)pObserver; +#if TARGET_OS_OSX + ::MTLRemoveDeviceObserver(pObserver); +#endif // TARGET_OS_OSX +} + +#endif // MTL_PRIVATE_IMPLEMENTATION + +_MTL_INLINE MTL::BindingAccess MTL::ArgumentDescriptor::access() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(access)); +} + +_MTL_INLINE MTL::ArgumentDescriptor* MTL::ArgumentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLArgumentDescriptor)); +} + +_MTL_INLINE MTL::ArgumentDescriptor* MTL::ArgumentDescriptor::argumentDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLArgumentDescriptor), _MTL_PRIVATE_SEL(argumentDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::ArgumentDescriptor::arrayLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(arrayLength)); +} + +_MTL_INLINE NS::UInteger MTL::ArgumentDescriptor::constantBlockAlignment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(constantBlockAlignment)); +} + +_MTL_INLINE MTL::DataType MTL::ArgumentDescriptor::dataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dataType)); +} + +_MTL_INLINE NS::UInteger MTL::ArgumentDescriptor::index() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(index)); +} + +_MTL_INLINE MTL::ArgumentDescriptor* MTL::ArgumentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::ArgumentDescriptor::setAccess(MTL::BindingAccess access) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAccess_), access); +} + +_MTL_INLINE void MTL::ArgumentDescriptor::setArrayLength(NS::UInteger arrayLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setArrayLength_), arrayLength); +} + +_MTL_INLINE void MTL::ArgumentDescriptor::setConstantBlockAlignment(NS::UInteger constantBlockAlignment) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setConstantBlockAlignment_), constantBlockAlignment); +} + +_MTL_INLINE void MTL::ArgumentDescriptor::setDataType(MTL::DataType dataType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDataType_), dataType); +} + +_MTL_INLINE void MTL::ArgumentDescriptor::setIndex(NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndex_), index); +} + +_MTL_INLINE void MTL::ArgumentDescriptor::setTextureType(MTL::TextureType textureType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTextureType_), textureType); +} + +_MTL_INLINE MTL::TextureType MTL::ArgumentDescriptor::textureType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(textureType)); +} + +_MTL_INLINE MTL::Architecture* MTL::Architecture::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLArchitecture)); +} + +_MTL_INLINE MTL::Architecture* MTL::Architecture::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::Architecture::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE MTL::AccelerationStructureSizes MTL::Device::accelerationStructureSizes(const MTL::AccelerationStructureDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(accelerationStructureSizesWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::Architecture* MTL::Device::architecture() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(architecture)); +} + +_MTL_INLINE bool MTL::Device::areBarycentricCoordsSupported() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(areBarycentricCoordsSupported)); +} + +_MTL_INLINE bool MTL::Device::areProgrammableSamplePositionsSupported() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(areProgrammableSamplePositionsSupported)); +} + +_MTL_INLINE bool MTL::Device::areRasterOrderGroupsSupported() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(areRasterOrderGroupsSupported)); +} + +_MTL_INLINE MTL::ArgumentBuffersTier MTL::Device::argumentBuffersSupport() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(argumentBuffersSupport)); +} + +_MTL_INLINE bool MTL::Device::barycentricCoordsSupported() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(areBarycentricCoordsSupported)); +} + +_MTL_INLINE void MTL::Device::convertSparsePixelRegions(const MTL::Region* pixelRegions, MTL::Region* tileRegions, MTL::Size tileSize, MTL::SparseTextureRegionAlignmentMode mode, NS::UInteger numRegions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(convertSparsePixelRegions_toTileRegions_withTileSize_alignmentMode_numRegions_), pixelRegions, tileRegions, tileSize, mode, numRegions); +} + +_MTL_INLINE void MTL::Device::convertSparseTileRegions(const MTL::Region* tileRegions, MTL::Region* pixelRegions, MTL::Size tileSize, NS::UInteger numRegions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(convertSparseTileRegions_toPixelRegions_withTileSize_numRegions_), tileRegions, pixelRegions, tileSize, numRegions); +} + +_MTL_INLINE NS::Array* MTL::Device::counterSets() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(counterSets)); +} + +_MTL_INLINE NS::UInteger MTL::Device::currentAllocatedSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(currentAllocatedSize)); +} + +_MTL_INLINE bool MTL::Device::depth24Stencil8PixelFormatSupported() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(isDepth24Stencil8PixelFormatSupported)); +} + +_MTL_INLINE MTL::FunctionHandle* MTL::Device::functionHandle(const MTL::Function* function) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionHandleWithFunction_), function); +} + +_MTL_INLINE MTL::FunctionHandle* MTL::Device::functionHandle(const MTL4::BinaryFunction* function) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionHandleWithBinaryFunction_), function); +} + +_MTL_INLINE void MTL::Device::getDefaultSamplePositions(MTL::SamplePosition* positions, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(getDefaultSamplePositions_count_), positions, count); +} + +_MTL_INLINE bool MTL::Device::hasUnifiedMemory() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(hasUnifiedMemory)); +} + +_MTL_INLINE bool MTL::Device::headless() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isHeadless)); +} + +_MTL_INLINE MTL::SizeAndAlign MTL::Device::heapAccelerationStructureSizeAndAlign(NS::UInteger size) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(heapAccelerationStructureSizeAndAlignWithSize_), size); +} + +_MTL_INLINE MTL::SizeAndAlign MTL::Device::heapAccelerationStructureSizeAndAlign(const MTL::AccelerationStructureDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(heapAccelerationStructureSizeAndAlignWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::SizeAndAlign MTL::Device::heapBufferSizeAndAlign(NS::UInteger length, MTL::ResourceOptions options) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(heapBufferSizeAndAlignWithLength_options_), length, options); +} + +_MTL_INLINE MTL::SizeAndAlign MTL::Device::heapTextureSizeAndAlign(const MTL::TextureDescriptor* desc) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(heapTextureSizeAndAlignWithDescriptor_), desc); +} + +_MTL_INLINE bool MTL::Device::isDepth24Stencil8PixelFormatSupported() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(isDepth24Stencil8PixelFormatSupported)); +} + +_MTL_INLINE bool MTL::Device::isHeadless() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isHeadless)); +} + +_MTL_INLINE bool MTL::Device::isLowPower() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isLowPower)); +} + +_MTL_INLINE bool MTL::Device::isRemovable() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isRemovable)); +} + +_MTL_INLINE MTL::DeviceLocation MTL::Device::location() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(location)); +} + +_MTL_INLINE NS::UInteger MTL::Device::locationNumber() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(locationNumber)); +} + +_MTL_INLINE bool MTL::Device::lowPower() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isLowPower)); +} + +_MTL_INLINE NS::UInteger MTL::Device::maxArgumentBufferSamplerCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxArgumentBufferSamplerCount)); +} + +_MTL_INLINE NS::UInteger MTL::Device::maxBufferLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxBufferLength)); +} + +_MTL_INLINE NS::UInteger MTL::Device::maxThreadgroupMemoryLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxThreadgroupMemoryLength)); +} + +_MTL_INLINE MTL::Size MTL::Device::maxThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxThreadsPerThreadgroup)); +} + +_MTL_INLINE uint64_t MTL::Device::maxTransferRate() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTransferRate)); +} + +_MTL_INLINE NS::UInteger MTL::Device::maximumConcurrentCompilationTaskCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maximumConcurrentCompilationTaskCount)); +} + +_MTL_INLINE NS::UInteger MTL::Device::minimumLinearTextureAlignmentForPixelFormat(MTL::PixelFormat format) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(minimumLinearTextureAlignmentForPixelFormat_), format); +} + +_MTL_INLINE NS::UInteger MTL::Device::minimumTextureBufferAlignmentForPixelFormat(MTL::PixelFormat format) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(minimumTextureBufferAlignmentForPixelFormat_), format); +} + +_MTL_INLINE NS::String* MTL::Device::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE MTL::AccelerationStructure* MTL::Device::newAccelerationStructure(NS::UInteger size) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newAccelerationStructureWithSize_), size); +} + +_MTL_INLINE MTL::AccelerationStructure* MTL::Device::newAccelerationStructure(const MTL::AccelerationStructureDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newAccelerationStructureWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL4::Archive* MTL::Device::newArchive(const NS::URL* url, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newArchiveWithURL_error_), url, error); +} + +_MTL_INLINE MTL::ArgumentEncoder* MTL::Device::newArgumentEncoder(const NS::Array* arguments) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newArgumentEncoderWithArguments_), arguments); +} + +_MTL_INLINE MTL::ArgumentEncoder* MTL::Device::newArgumentEncoder(const MTL::BufferBinding* bufferBinding) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newArgumentEncoderWithBufferBinding_), bufferBinding); +} + +_MTL_INLINE MTL4::ArgumentTable* MTL::Device::newArgumentTable(const MTL4::ArgumentTableDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newArgumentTableWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::BinaryArchive* MTL::Device::newBinaryArchive(const MTL::BinaryArchiveDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newBinaryArchiveWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::Buffer* MTL::Device::newBuffer(NS::UInteger length, MTL::ResourceOptions options) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newBufferWithLength_options_), length, options); +} + +_MTL_INLINE MTL::Buffer* MTL::Device::newBuffer(const void* pointer, NS::UInteger length, MTL::ResourceOptions options) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newBufferWithBytes_length_options_), pointer, length, options); +} + +_MTL_INLINE MTL::Buffer* MTL::Device::newBuffer(const void* pointer, NS::UInteger length, MTL::ResourceOptions options, void (^deallocator)(void*, NS::UInteger)) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newBufferWithBytesNoCopy_length_options_deallocator_), pointer, length, options, deallocator); +} + +_MTL_INLINE MTL::Buffer* MTL::Device::newBuffer(NS::UInteger length, MTL::ResourceOptions options, MTL::SparsePageSize placementSparsePageSize) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newBufferWithLength_options_placementSparsePageSize_), length, options, placementSparsePageSize); +} + +_MTL_INLINE MTL4::CommandAllocator* MTL::Device::newCommandAllocator() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCommandAllocator)); +} + +_MTL_INLINE MTL4::CommandAllocator* MTL::Device::newCommandAllocator(const MTL4::CommandAllocatorDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCommandAllocatorWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL4::CommandBuffer* MTL::Device::newCommandBuffer() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCommandBuffer)); +} + +_MTL_INLINE MTL::CommandQueue* MTL::Device::newCommandQueue() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCommandQueue)); +} + +_MTL_INLINE MTL::CommandQueue* MTL::Device::newCommandQueue(NS::UInteger maxCommandBufferCount) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCommandQueueWithMaxCommandBufferCount_), maxCommandBufferCount); +} + +_MTL_INLINE MTL::CommandQueue* MTL::Device::newCommandQueue(const MTL::CommandQueueDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCommandQueueWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL4::Compiler* MTL::Device::newCompiler(const MTL4::CompilerDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCompilerWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::ComputePipelineState* MTL::Device::newComputePipelineState(const MTL::Function* computeFunction, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithFunction_error_), computeFunction, error); +} + +_MTL_INLINE MTL::ComputePipelineState* MTL::Device::newComputePipelineState(const MTL::Function* computeFunction, MTL::PipelineOption options, const MTL::AutoreleasedComputePipelineReflection* reflection, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithFunction_options_reflection_error_), computeFunction, options, reflection, error); +} + +_MTL_INLINE void MTL::Device::newComputePipelineState(const MTL::Function* computeFunction, const MTL::NewComputePipelineStateCompletionHandler completionHandler) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithFunction_completionHandler_), computeFunction, completionHandler); +} + +_MTL_INLINE void MTL::Device::newComputePipelineState(const MTL::Function* computeFunction, MTL::PipelineOption options, const MTL::NewComputePipelineStateWithReflectionCompletionHandler completionHandler) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithFunction_options_completionHandler_), computeFunction, options, completionHandler); +} + +_MTL_INLINE MTL::ComputePipelineState* MTL::Device::newComputePipelineState(const MTL::ComputePipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::AutoreleasedComputePipelineReflection* reflection, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithDescriptor_options_reflection_error_), descriptor, options, reflection, error); +} + +_MTL_INLINE void MTL::Device::newComputePipelineState(const MTL::ComputePipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::NewComputePipelineStateWithReflectionCompletionHandler completionHandler) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithDescriptor_options_completionHandler_), descriptor, options, completionHandler); +} + +_MTL_INLINE void MTL::Device::newComputePipelineState(const MTL::Function* pFunction, const MTL::NewComputePipelineStateCompletionHandlerFunction& completionHandler) +{ + __block MTL::NewComputePipelineStateCompletionHandlerFunction blockCompletionHandler = completionHandler; + newComputePipelineState(pFunction, ^(MTL::ComputePipelineState* pPipelineState, NS::Error* pError) { blockCompletionHandler(pPipelineState, pError); }); +} + +_MTL_INLINE void MTL::Device::newComputePipelineState(const MTL::Function* pFunction, MTL::PipelineOption options, const MTL::NewComputePipelineStateWithReflectionCompletionHandlerFunction& completionHandler) +{ + __block MTL::NewComputePipelineStateWithReflectionCompletionHandlerFunction blockCompletionHandler = completionHandler; + newComputePipelineState(pFunction, options, ^(MTL::ComputePipelineState* pPipelineState, MTL::ComputePipelineReflection* pReflection, NS::Error* pError) { blockCompletionHandler(pPipelineState, pReflection, pError); }); +} + +_MTL_INLINE void MTL::Device::newComputePipelineState(const MTL::ComputePipelineDescriptor* pDescriptor, MTL::PipelineOption options, const MTL::NewComputePipelineStateWithReflectionCompletionHandlerFunction& completionHandler) +{ + __block NewComputePipelineStateWithReflectionCompletionHandlerFunction blockCompletionHandler = completionHandler; + newComputePipelineState(pDescriptor, options, ^(ComputePipelineState* pPipelineState, ComputePipelineReflection* pReflection, NS::Error* pError) { blockCompletionHandler(pPipelineState, pReflection, pError); }); +} + +_MTL_INLINE MTL4::CounterHeap* MTL::Device::newCounterHeap(const MTL4::CounterHeapDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCounterHeapWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::CounterSampleBuffer* MTL::Device::newCounterSampleBuffer(const MTL::CounterSampleBufferDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCounterSampleBufferWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::Library* MTL::Device::newDefaultLibrary() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newDefaultLibrary)); +} + +_MTL_INLINE MTL::Library* MTL::Device::newDefaultLibrary(const NS::Bundle* bundle, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newDefaultLibraryWithBundle_error_), bundle, error); +} + +_MTL_INLINE MTL::DepthStencilState* MTL::Device::newDepthStencilState(const MTL::DepthStencilDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newDepthStencilStateWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::DynamicLibrary* MTL::Device::newDynamicLibrary(const MTL::Library* library, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newDynamicLibrary_error_), library, error); +} + +_MTL_INLINE MTL::DynamicLibrary* MTL::Device::newDynamicLibrary(const NS::URL* url, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newDynamicLibraryWithURL_error_), url, error); +} + +_MTL_INLINE MTL::Event* MTL::Device::newEvent() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newEvent)); +} + +_MTL_INLINE MTL::Fence* MTL::Device::newFence() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newFence)); +} + +_MTL_INLINE MTL::Heap* MTL::Device::newHeap(const MTL::HeapDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newHeapWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::IOCommandQueue* MTL::Device::newIOCommandQueue(const MTL::IOCommandQueueDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newIOCommandQueueWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::IOFileHandle* MTL::Device::newIOFileHandle(const NS::URL* url, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newIOFileHandleWithURL_error_), url, error); +} + +_MTL_INLINE MTL::IOFileHandle* MTL::Device::newIOFileHandle(const NS::URL* url, MTL::IOCompressionMethod compressionMethod, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newIOFileHandleWithURL_compressionMethod_error_), url, compressionMethod, error); +} + +_MTL_INLINE MTL::IOFileHandle* MTL::Device::newIOHandle(const NS::URL* url, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newIOHandleWithURL_error_), url, error); +} + +_MTL_INLINE MTL::IOFileHandle* MTL::Device::newIOHandle(const NS::URL* url, MTL::IOCompressionMethod compressionMethod, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newIOHandleWithURL_compressionMethod_error_), url, compressionMethod, error); +} + +_MTL_INLINE MTL::IndirectCommandBuffer* MTL::Device::newIndirectCommandBuffer(const MTL::IndirectCommandBufferDescriptor* descriptor, NS::UInteger maxCount, MTL::ResourceOptions options) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newIndirectCommandBufferWithDescriptor_maxCommandCount_options_), descriptor, maxCount, options); +} + +_MTL_INLINE MTL::Library* MTL::Device::newLibrary(const NS::String* filepath, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newLibraryWithFile_error_), filepath, error); +} + +_MTL_INLINE MTL::Library* MTL::Device::newLibrary(const NS::URL* url, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newLibraryWithURL_error_), url, error); +} + +_MTL_INLINE MTL::Library* MTL::Device::newLibrary(const dispatch_data_t data, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newLibraryWithData_error_), data, error); +} + +_MTL_INLINE MTL::Library* MTL::Device::newLibrary(const NS::String* source, const MTL::CompileOptions* options, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newLibraryWithSource_options_error_), source, options, error); +} + +_MTL_INLINE void MTL::Device::newLibrary(const NS::String* source, const MTL::CompileOptions* options, const MTL::NewLibraryCompletionHandler completionHandler) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newLibraryWithSource_options_completionHandler_), source, options, completionHandler); +} + +_MTL_INLINE MTL::Library* MTL::Device::newLibrary(const MTL::StitchedLibraryDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newLibraryWithStitchedDescriptor_error_), descriptor, error); +} + +_MTL_INLINE void MTL::Device::newLibrary(const MTL::StitchedLibraryDescriptor* descriptor, const MTL::NewLibraryCompletionHandler completionHandler) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newLibraryWithStitchedDescriptor_completionHandler_), descriptor, completionHandler); +} + +_MTL_INLINE void MTL::Device::newLibrary(const NS::String* pSource, const MTL::CompileOptions* pOptions, const MTL::NewLibraryCompletionHandlerFunction& completionHandler) +{ + __block MTL::NewLibraryCompletionHandlerFunction blockCompletionHandler = completionHandler; + newLibrary(pSource, pOptions, ^(MTL::Library* pLibrary, NS::Error* pError) { blockCompletionHandler(pLibrary, pError); }); +} + +_MTL_INLINE void MTL::Device::newLibrary(const MTL::StitchedLibraryDescriptor* pDescriptor, const MTL::NewLibraryCompletionHandlerFunction& completionHandler) +{ + __block MTL::NewLibraryCompletionHandlerFunction blockCompletionHandler = completionHandler; + newLibrary(pDescriptor, ^(MTL::Library* pLibrary, NS::Error* pError) { blockCompletionHandler(pLibrary, pError); }); +} + +_MTL_INLINE MTL::LogState* MTL::Device::newLogState(const MTL::LogStateDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newLogStateWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL4::CommandQueue* MTL::Device::newMTL4CommandQueue() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newMTL4CommandQueue)); +} + +_MTL_INLINE MTL4::CommandQueue* MTL::Device::newMTL4CommandQueue(const MTL4::CommandQueueDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newMTL4CommandQueueWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL4::PipelineDataSetSerializer* MTL::Device::newPipelineDataSetSerializer(const MTL4::PipelineDataSetSerializerDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newPipelineDataSetSerializerWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::RasterizationRateMap* MTL::Device::newRasterizationRateMap(const MTL::RasterizationRateMapDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRasterizationRateMapWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL::Device::newRenderPipelineState(const MTL::RenderPipelineDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL::Device::newRenderPipelineState(const MTL::RenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::AutoreleasedRenderPipelineReflection* reflection, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithDescriptor_options_reflection_error_), descriptor, options, reflection, error); +} + +_MTL_INLINE void MTL::Device::newRenderPipelineState(const MTL::RenderPipelineDescriptor* descriptor, const MTL::NewRenderPipelineStateCompletionHandler completionHandler) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithDescriptor_completionHandler_), descriptor, completionHandler); +} + +_MTL_INLINE void MTL::Device::newRenderPipelineState(const MTL::RenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::NewRenderPipelineStateWithReflectionCompletionHandler completionHandler) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithDescriptor_options_completionHandler_), descriptor, options, completionHandler); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL::Device::newRenderPipelineState(const MTL::TileRenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::AutoreleasedRenderPipelineReflection* reflection, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithTileDescriptor_options_reflection_error_), descriptor, options, reflection, error); +} + +_MTL_INLINE void MTL::Device::newRenderPipelineState(const MTL::TileRenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::NewRenderPipelineStateWithReflectionCompletionHandler completionHandler) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithTileDescriptor_options_completionHandler_), descriptor, options, completionHandler); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL::Device::newRenderPipelineState(const MTL::MeshRenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::AutoreleasedRenderPipelineReflection* reflection, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithMeshDescriptor_options_reflection_error_), descriptor, options, reflection, error); +} + +_MTL_INLINE void MTL::Device::newRenderPipelineState(const MTL::MeshRenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::NewRenderPipelineStateWithReflectionCompletionHandler completionHandler) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithMeshDescriptor_options_completionHandler_), descriptor, options, completionHandler); +} + +_MTL_INLINE void MTL::Device::newRenderPipelineState(const MTL::RenderPipelineDescriptor* pDescriptor, const MTL::NewRenderPipelineStateCompletionHandlerFunction& completionHandler) +{ + __block MTL::NewRenderPipelineStateCompletionHandlerFunction blockCompletionHandler = completionHandler; + newRenderPipelineState(pDescriptor, ^(MTL::RenderPipelineState* pPipelineState, NS::Error* pError) { blockCompletionHandler(pPipelineState, pError); }); +} + +_MTL_INLINE void MTL::Device::newRenderPipelineState(const MTL::RenderPipelineDescriptor* pDescriptor, MTL::PipelineOption options, const MTL::NewRenderPipelineStateWithReflectionCompletionHandlerFunction& completionHandler) +{ + __block MTL::NewRenderPipelineStateWithReflectionCompletionHandlerFunction blockCompletionHandler = completionHandler; + newRenderPipelineState(pDescriptor, options, ^(MTL::RenderPipelineState* pPipelineState, MTL::RenderPipelineReflection* pReflection, NS::Error* pError) { blockCompletionHandler(pPipelineState, pReflection, pError); }); +} + +_MTL_INLINE void MTL::Device::newRenderPipelineState(const MTL::TileRenderPipelineDescriptor* pDescriptor, MTL::PipelineOption options, const MTL::NewRenderPipelineStateWithReflectionCompletionHandlerFunction& completionHandler) +{ + __block MTL::NewRenderPipelineStateWithReflectionCompletionHandlerFunction blockCompletionHandler = completionHandler; + newRenderPipelineState(pDescriptor, options, ^(MTL::RenderPipelineState* pPipelineState, MTL::RenderPipelineReflection* pReflection, NS::Error* pError) { blockCompletionHandler(pPipelineState, pReflection, pError); }); +} + +_MTL_INLINE MTL::ResidencySet* MTL::Device::newResidencySet(const MTL::ResidencySetDescriptor* desc, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newResidencySetWithDescriptor_error_), desc, error); +} + +_MTL_INLINE MTL::SamplerState* MTL::Device::newSamplerState(const MTL::SamplerDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newSamplerStateWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::SharedEvent* MTL::Device::newSharedEvent() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newSharedEvent)); +} + +_MTL_INLINE MTL::SharedEvent* MTL::Device::newSharedEvent(const MTL::SharedEventHandle* sharedEventHandle) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newSharedEventWithHandle_), sharedEventHandle); +} + +_MTL_INLINE MTL::Texture* MTL::Device::newSharedTexture(const MTL::TextureDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newSharedTextureWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::Texture* MTL::Device::newSharedTexture(const MTL::SharedTextureHandle* sharedHandle) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newSharedTextureWithHandle_), sharedHandle); +} + +_MTL_INLINE MTL::Tensor* MTL::Device::newTensor(const MTL::TensorDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTensorWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::Texture* MTL::Device::newTexture(const MTL::TextureDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTextureWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::Texture* MTL::Device::newTexture(const MTL::TextureDescriptor* descriptor, const IOSurfaceRef iosurface, NS::UInteger plane) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTextureWithDescriptor_iosurface_plane_), descriptor, iosurface, plane); +} + +_MTL_INLINE MTL::TextureViewPool* MTL::Device::newTextureViewPool(const MTL::ResourceViewPoolDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTextureViewPoolWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE uint32_t MTL::Device::peerCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(peerCount)); +} + +_MTL_INLINE uint64_t MTL::Device::peerGroupID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(peerGroupID)); +} + +_MTL_INLINE uint32_t MTL::Device::peerIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(peerIndex)); +} + +_MTL_INLINE bool MTL::Device::programmableSamplePositionsSupported() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(areProgrammableSamplePositionsSupported)); +} + +_MTL_INLINE uint64_t MTL::Device::queryTimestampFrequency() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(queryTimestampFrequency)); +} + +_MTL_INLINE bool MTL::Device::rasterOrderGroupsSupported() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(areRasterOrderGroupsSupported)); +} + +_MTL_INLINE MTL::ReadWriteTextureTier MTL::Device::readWriteTextureSupport() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(readWriteTextureSupport)); +} + +_MTL_INLINE uint64_t MTL::Device::recommendedMaxWorkingSetSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(recommendedMaxWorkingSetSize)); +} + +_MTL_INLINE uint64_t MTL::Device::registryID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(registryID)); +} + +_MTL_INLINE bool MTL::Device::removable() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isRemovable)); +} + +_MTL_INLINE void MTL::Device::sampleTimestamps(MTL::Timestamp* cpuTimestamp, MTL::Timestamp* gpuTimestamp) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleTimestamps_gpuTimestamp_), cpuTimestamp, gpuTimestamp); +} + +_MTL_INLINE void MTL::Device::setShouldMaximizeConcurrentCompilation(bool shouldMaximizeConcurrentCompilation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setShouldMaximizeConcurrentCompilation_), shouldMaximizeConcurrentCompilation); +} + +_MTL_INLINE bool MTL::Device::shouldMaximizeConcurrentCompilation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shouldMaximizeConcurrentCompilation)); +} + +_MTL_INLINE NS::UInteger MTL::Device::sizeOfCounterHeapEntry(MTL4::CounterHeapType type) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sizeOfCounterHeapEntry_), type); +} + +_MTL_INLINE MTL::Size MTL::Device::sparseTileSize(MTL::TextureType textureType, MTL::PixelFormat pixelFormat, NS::UInteger sampleCount) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sparseTileSizeWithTextureType_pixelFormat_sampleCount_), textureType, pixelFormat, sampleCount); +} + +_MTL_INLINE MTL::Size MTL::Device::sparseTileSize(MTL::TextureType textureType, MTL::PixelFormat pixelFormat, NS::UInteger sampleCount, MTL::SparsePageSize sparsePageSize) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sparseTileSizeWithTextureType_pixelFormat_sampleCount_sparsePageSize_), textureType, pixelFormat, sampleCount, sparsePageSize); +} + +_MTL_INLINE NS::UInteger MTL::Device::sparseTileSizeInBytes() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sparseTileSizeInBytes)); +} + +_MTL_INLINE NS::UInteger MTL::Device::sparseTileSizeInBytes(MTL::SparsePageSize sparsePageSize) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sparseTileSizeInBytesForSparsePageSize_), sparsePageSize); +} + +_MTL_INLINE bool MTL::Device::supports32BitFloatFiltering() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supports32BitFloatFiltering)); +} + +_MTL_INLINE bool MTL::Device::supports32BitMSAA() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supports32BitMSAA)); +} + +_MTL_INLINE bool MTL::Device::supportsBCTextureCompression() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsBCTextureCompression)); +} + +_MTL_INLINE bool MTL::Device::supportsCounterSampling(MTL::CounterSamplingPoint samplingPoint) +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsCounterSampling_), samplingPoint); +} + +_MTL_INLINE bool MTL::Device::supportsDynamicLibraries() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsDynamicLibraries)); +} + +_MTL_INLINE bool MTL::Device::supportsFamily(MTL::GPUFamily gpuFamily) +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsFamily_), gpuFamily); +} + +_MTL_INLINE bool MTL::Device::supportsFeatureSet(MTL::FeatureSet featureSet) +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsFeatureSet_), featureSet); +} + +_MTL_INLINE bool MTL::Device::supportsFunctionPointers() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsFunctionPointers)); +} + +_MTL_INLINE bool MTL::Device::supportsFunctionPointersFromRender() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsFunctionPointersFromRender)); +} + +_MTL_INLINE bool MTL::Device::supportsPrimitiveMotionBlur() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsPrimitiveMotionBlur)); +} + +_MTL_INLINE bool MTL::Device::supportsPullModelInterpolation() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsPullModelInterpolation)); +} + +_MTL_INLINE bool MTL::Device::supportsQueryTextureLOD() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsQueryTextureLOD)); +} + +_MTL_INLINE bool MTL::Device::supportsRasterizationRateMap(NS::UInteger layerCount) +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsRasterizationRateMapWithLayerCount_), layerCount); +} + +_MTL_INLINE bool MTL::Device::supportsRaytracing() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsRaytracing)); +} + +_MTL_INLINE bool MTL::Device::supportsRaytracingFromRender() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsRaytracingFromRender)); +} + +_MTL_INLINE bool MTL::Device::supportsRenderDynamicLibraries() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsRenderDynamicLibraries)); +} + +_MTL_INLINE bool MTL::Device::supportsShaderBarycentricCoordinates() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsShaderBarycentricCoordinates)); +} + +_MTL_INLINE bool MTL::Device::supportsTextureSampleCount(NS::UInteger sampleCount) +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsTextureSampleCount_), sampleCount); +} + +_MTL_INLINE bool MTL::Device::supportsVertexAmplificationCount(NS::UInteger count) +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsVertexAmplificationCount_), count); +} + +_MTL_INLINE MTL::SizeAndAlign MTL::Device::tensorSizeAndAlign(const MTL::TensorDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tensorSizeAndAlignWithDescriptor_), descriptor); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLDrawable.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLDrawable.hpp new file mode 100644 index 00000000..fad4feda --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLDrawable.hpp @@ -0,0 +1,90 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLDrawable.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include + +#include +#include + +namespace MTL +{ +class Drawable; + +using DrawablePresentedHandler = void (^)(MTL::Drawable*); +using DrawablePresentedHandlerFunction = std::function; + +class Drawable : public NS::Referencing +{ +public: + void addPresentedHandler(const MTL::DrawablePresentedHandler block); + void addPresentedHandler(const MTL::DrawablePresentedHandlerFunction& function); + + NS::UInteger drawableID() const; + + void present(); + void presentAfterMinimumDuration(CFTimeInterval duration); + + void presentAtTime(CFTimeInterval presentationTime); + + CFTimeInterval presentedTime() const; +}; + +} +_MTL_INLINE void MTL::Drawable::addPresentedHandler(const MTL::DrawablePresentedHandler block) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addPresentedHandler_), block); +} + +_MTL_INLINE void MTL::Drawable::addPresentedHandler(const MTL::DrawablePresentedHandlerFunction& function) +{ + __block DrawablePresentedHandlerFunction blockFunction = function; + addPresentedHandler(^(Drawable* pDrawable) { blockFunction(pDrawable); }); +} + +_MTL_INLINE NS::UInteger MTL::Drawable::drawableID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(drawableID)); +} + +_MTL_INLINE void MTL::Drawable::present() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(present)); +} + +_MTL_INLINE void MTL::Drawable::presentAfterMinimumDuration(CFTimeInterval duration) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(presentAfterMinimumDuration_), duration); +} + +_MTL_INLINE void MTL::Drawable::presentAtTime(CFTimeInterval presentationTime) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(presentAtTime_), presentationTime); +} + +_MTL_INLINE CFTimeInterval MTL::Drawable::presentedTime() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(presentedTime)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLDynamicLibrary.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLDynamicLibrary.hpp new file mode 100644 index 00000000..0726acc1 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLDynamicLibrary.hpp @@ -0,0 +1,78 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLDynamicLibrary.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class Device; +_MTL_ENUM(NS::UInteger, DynamicLibraryError) { + DynamicLibraryErrorNone = 0, + DynamicLibraryErrorInvalidFile = 1, + DynamicLibraryErrorCompilationFailure = 2, + DynamicLibraryErrorUnresolvedInstallName = 3, + DynamicLibraryErrorDependencyLoadFailure = 4, + DynamicLibraryErrorUnsupported = 5, +}; + +class DynamicLibrary : public NS::Referencing +{ +public: + Device* device() const; + + NS::String* installName() const; + + NS::String* label() const; + + bool serializeToURL(const NS::URL* url, NS::Error** error); + + void setLabel(const NS::String* label); +}; + +} +_MTL_INLINE MTL::Device* MTL::DynamicLibrary::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL::DynamicLibrary::installName() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(installName)); +} + +_MTL_INLINE NS::String* MTL::DynamicLibrary::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE bool MTL::DynamicLibrary::serializeToURL(const NS::URL* url, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(serializeToURL_error_), url, error); +} + +_MTL_INLINE void MTL::DynamicLibrary::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLEvent.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLEvent.hpp new file mode 100644 index 00000000..d06b9693 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLEvent.hpp @@ -0,0 +1,170 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLEvent.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include +#include + +#include +#include + +namespace MTL +{ +class Device; +class SharedEvent; +class SharedEventHandle; +class SharedEventListener; + +using SharedEventNotificationBlock = void (^)(SharedEvent* pEvent, std::uint64_t value); +using SharedEventNotificationFunction = std::function; + +class Event : public NS::Referencing +{ +public: + Device* device() const; + + NS::String* label() const; + void setLabel(const NS::String* label); +}; +class SharedEventListener : public NS::Referencing +{ +public: + static SharedEventListener* alloc(); + + dispatch_queue_t dispatchQueue() const; + + SharedEventListener* init(); + SharedEventListener* init(const dispatch_queue_t dispatchQueue); + + static SharedEventListener* sharedListener(); +}; +class SharedEvent : public NS::Referencing +{ +public: + SharedEventHandle* newSharedEventHandle(); + + void notifyListener(const MTL::SharedEventListener* listener, uint64_t value, const MTL::SharedEventNotificationBlock block); + void notifyListener(const MTL::SharedEventListener* listener, uint64_t value, const MTL::SharedEventNotificationFunction& function); + + void setSignaledValue(uint64_t signaledValue); + uint64_t signaledValue() const; + bool waitUntilSignaledValue(uint64_t value, uint64_t milliseconds); +}; +class SharedEventHandle : public NS::SecureCoding +{ +public: + static SharedEventHandle* alloc(); + + SharedEventHandle* init(); + + NS::String* label() const; +}; + +} +_MTL_INLINE MTL::Device* MTL::Event::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL::Event::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::Event::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE MTL::SharedEventListener* MTL::SharedEventListener::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLSharedEventListener)); +} + +_MTL_INLINE dispatch_queue_t MTL::SharedEventListener::dispatchQueue() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchQueue)); +} + +_MTL_INLINE MTL::SharedEventListener* MTL::SharedEventListener::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::SharedEventListener* MTL::SharedEventListener::init(const dispatch_queue_t dispatchQueue) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(initWithDispatchQueue_), dispatchQueue); +} + +_MTL_INLINE MTL::SharedEventListener* MTL::SharedEventListener::sharedListener() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLSharedEventListener), _MTL_PRIVATE_SEL(sharedListener)); +} + +_MTL_INLINE MTL::SharedEventHandle* MTL::SharedEvent::newSharedEventHandle() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newSharedEventHandle)); +} + +_MTL_INLINE void MTL::SharedEvent::notifyListener(const MTL::SharedEventListener* listener, uint64_t value, const MTL::SharedEventNotificationBlock block) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(notifyListener_atValue_block_), listener, value, block); +} + +_MTL_INLINE void MTL::SharedEvent::notifyListener(const MTL::SharedEventListener* listener, uint64_t value, const MTL::SharedEventNotificationFunction& function) +{ + __block MTL::SharedEventNotificationFunction callback = function; + notifyListener(listener, value, ^void(SharedEvent* pEvent, std::uint64_t innerValue) { callback(pEvent, innerValue); }); +} + +_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue); +} + +_MTL_INLINE uint64_t MTL::SharedEvent::signaledValue() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(signaledValue)); +} + +_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t value, uint64_t milliseconds) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), value, milliseconds); +} + +_MTL_INLINE MTL::SharedEventHandle* MTL::SharedEventHandle::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLSharedEventHandle)); +} + +_MTL_INLINE MTL::SharedEventHandle* MTL::SharedEventHandle::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::SharedEventHandle::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLFence.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLFence.hpp new file mode 100644 index 00000000..f31df4ce --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLFence.hpp @@ -0,0 +1,55 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLFence.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class Device; + +class Fence : public NS::Referencing +{ +public: + Device* device() const; + + NS::String* label() const; + void setLabel(const NS::String* label); +}; + +} +_MTL_INLINE MTL::Device* MTL::Fence::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL::Fence::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::Fence::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLFunctionConstantValues.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLFunctionConstantValues.hpp new file mode 100644 index 00000000..dce89d15 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLFunctionConstantValues.hpp @@ -0,0 +1,76 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLFunctionConstantValues.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDataType.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class FunctionConstantValues; + +class FunctionConstantValues : public NS::Copying +{ +public: + static FunctionConstantValues* alloc(); + + FunctionConstantValues* init(); + + void reset(); + + void setConstantValue(const void* value, MTL::DataType type, NS::UInteger index); + void setConstantValue(const void* value, MTL::DataType type, const NS::String* name); + void setConstantValues(const void* values, MTL::DataType type, NS::Range range); +}; + +} +_MTL_INLINE MTL::FunctionConstantValues* MTL::FunctionConstantValues::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLFunctionConstantValues)); +} + +_MTL_INLINE MTL::FunctionConstantValues* MTL::FunctionConstantValues::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::FunctionConstantValues::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL::FunctionConstantValues::setConstantValue(const void* value, MTL::DataType type, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setConstantValue_type_atIndex_), value, type, index); +} + +_MTL_INLINE void MTL::FunctionConstantValues::setConstantValue(const void* value, MTL::DataType type, const NS::String* name) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setConstantValue_type_withName_), value, type, name); +} + +_MTL_INLINE void MTL::FunctionConstantValues::setConstantValues(const void* values, MTL::DataType type, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setConstantValues_type_withRange_), values, type, range); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLFunctionDescriptor.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLFunctionDescriptor.hpp new file mode 100644 index 00000000..aa296b5c --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLFunctionDescriptor.hpp @@ -0,0 +1,153 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLFunctionDescriptor.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class FunctionConstantValues; +class FunctionDescriptor; +class IntersectionFunctionDescriptor; + +_MTL_OPTIONS(NS::UInteger, FunctionOptions) { + FunctionOptionNone = 0, + FunctionOptionCompileToBinary = 1, + FunctionOptionStoreFunctionInMetalPipelinesScript = 1 << 1, + FunctionOptionStoreFunctionInMetalScript = 1 << 1, + FunctionOptionFailOnBinaryArchiveMiss = 1 << 2, + FunctionOptionPipelineIndependent = 1 << 3, +}; + +class FunctionDescriptor : public NS::Copying +{ +public: + static FunctionDescriptor* alloc(); + + NS::Array* binaryArchives() const; + + FunctionConstantValues* constantValues() const; + + static FunctionDescriptor* functionDescriptor(); + + FunctionDescriptor* init(); + + NS::String* name() const; + + FunctionOptions options() const; + + void setBinaryArchives(const NS::Array* binaryArchives); + + void setConstantValues(const MTL::FunctionConstantValues* constantValues); + + void setName(const NS::String* name); + + void setOptions(MTL::FunctionOptions options); + + void setSpecializedName(const NS::String* specializedName); + NS::String* specializedName() const; +}; +class IntersectionFunctionDescriptor : public NS::Copying +{ +public: + static IntersectionFunctionDescriptor* alloc(); + + IntersectionFunctionDescriptor* init(); +}; + +} +_MTL_INLINE MTL::FunctionDescriptor* MTL::FunctionDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLFunctionDescriptor)); +} + +_MTL_INLINE NS::Array* MTL::FunctionDescriptor::binaryArchives() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(binaryArchives)); +} + +_MTL_INLINE MTL::FunctionConstantValues* MTL::FunctionDescriptor::constantValues() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(constantValues)); +} + +_MTL_INLINE MTL::FunctionDescriptor* MTL::FunctionDescriptor::functionDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLFunctionDescriptor), _MTL_PRIVATE_SEL(functionDescriptor)); +} + +_MTL_INLINE MTL::FunctionDescriptor* MTL::FunctionDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::FunctionDescriptor::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE MTL::FunctionOptions MTL::FunctionDescriptor::options() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(options)); +} + +_MTL_INLINE void MTL::FunctionDescriptor::setBinaryArchives(const NS::Array* binaryArchives) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBinaryArchives_), binaryArchives); +} + +_MTL_INLINE void MTL::FunctionDescriptor::setConstantValues(const MTL::FunctionConstantValues* constantValues) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setConstantValues_), constantValues); +} + +_MTL_INLINE void MTL::FunctionDescriptor::setName(const NS::String* name) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setName_), name); +} + +_MTL_INLINE void MTL::FunctionDescriptor::setOptions(MTL::FunctionOptions options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOptions_), options); +} + +_MTL_INLINE void MTL::FunctionDescriptor::setSpecializedName(const NS::String* specializedName) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSpecializedName_), specializedName); +} + +_MTL_INLINE NS::String* MTL::FunctionDescriptor::specializedName() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(specializedName)); +} + +_MTL_INLINE MTL::IntersectionFunctionDescriptor* MTL::IntersectionFunctionDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLIntersectionFunctionDescriptor)); +} + +_MTL_INLINE MTL::IntersectionFunctionDescriptor* MTL::IntersectionFunctionDescriptor::init() +{ + return NS::Object::init(); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLFunctionHandle.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLFunctionHandle.hpp new file mode 100644 index 00000000..7a3ff95d --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLFunctionHandle.hpp @@ -0,0 +1,65 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLFunctionHandle.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLLibrary.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class Device; + +class FunctionHandle : public NS::Referencing +{ +public: + Device* device() const; + + FunctionType functionType() const; + + ResourceID gpuResourceID() const; + + NS::String* name() const; +}; + +} +_MTL_INLINE MTL::Device* MTL::FunctionHandle::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE MTL::FunctionType MTL::FunctionHandle::functionType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionType)); +} + +_MTL_INLINE MTL::ResourceID MTL::FunctionHandle::gpuResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuResourceID)); +} + +_MTL_INLINE NS::String* MTL::FunctionHandle::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLFunctionLog.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLFunctionLog.hpp new file mode 100644 index 00000000..454e6058 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLFunctionLog.hpp @@ -0,0 +1,101 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLFunctionLog.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class Function; +class FunctionLogDebugLocation; +_MTL_ENUM(NS::UInteger, FunctionLogType) { + FunctionLogTypeValidation = 0, +}; + +class LogContainer : public NS::Referencing +{ +}; +class FunctionLogDebugLocation : public NS::Referencing +{ +public: + NS::URL* URL() const; + + NS::UInteger column() const; + + NS::String* functionName() const; + + NS::UInteger line() const; +}; +class FunctionLog : public NS::Referencing +{ +public: + FunctionLogDebugLocation* debugLocation() const; + + NS::String* encoderLabel() const; + + Function* function() const; + + FunctionLogType type() const; +}; + +} +_MTL_INLINE NS::URL* MTL::FunctionLogDebugLocation::URL() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(URL)); +} + +_MTL_INLINE NS::UInteger MTL::FunctionLogDebugLocation::column() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(column)); +} + +_MTL_INLINE NS::String* MTL::FunctionLogDebugLocation::functionName() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionName)); +} + +_MTL_INLINE NS::UInteger MTL::FunctionLogDebugLocation::line() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(line)); +} + +_MTL_INLINE MTL::FunctionLogDebugLocation* MTL::FunctionLog::debugLocation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(debugLocation)); +} + +_MTL_INLINE NS::String* MTL::FunctionLog::encoderLabel() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(encoderLabel)); +} + +_MTL_INLINE MTL::Function* MTL::FunctionLog::function() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(function)); +} + +_MTL_INLINE MTL::FunctionLogType MTL::FunctionLog::type() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(type)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLFunctionStitching.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLFunctionStitching.hpp new file mode 100644 index 00000000..8dd5fd29 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLFunctionStitching.hpp @@ -0,0 +1,319 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLFunctionStitching.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class FunctionStitchingAttributeAlwaysInline; +class FunctionStitchingFunctionNode; +class FunctionStitchingGraph; +class FunctionStitchingInputNode; +class StitchedLibraryDescriptor; + +_MTL_OPTIONS(NS::UInteger, StitchedLibraryOptions) { + StitchedLibraryOptionNone = 0, + StitchedLibraryOptionFailOnBinaryArchiveMiss = 1, + StitchedLibraryOptionStoreLibraryInMetalPipelinesScript = 1 << 1, +}; + +class FunctionStitchingAttribute : public NS::Referencing +{ +}; +class FunctionStitchingAttributeAlwaysInline : public NS::Referencing +{ +public: + static FunctionStitchingAttributeAlwaysInline* alloc(); + + FunctionStitchingAttributeAlwaysInline* init(); +}; +class FunctionStitchingNode : public NS::Copying +{ +}; +class FunctionStitchingInputNode : public NS::Referencing +{ +public: + static FunctionStitchingInputNode* alloc(); + + NS::UInteger argumentIndex() const; + + FunctionStitchingInputNode* init(); + FunctionStitchingInputNode* init(NS::UInteger argument); + + void setArgumentIndex(NS::UInteger argumentIndex); +}; +class FunctionStitchingFunctionNode : public NS::Referencing +{ +public: + static FunctionStitchingFunctionNode* alloc(); + + NS::Array* arguments() const; + + NS::Array* controlDependencies() const; + + FunctionStitchingFunctionNode* init(); + FunctionStitchingFunctionNode* init(const NS::String* name, const NS::Array* arguments, const NS::Array* controlDependencies); + + NS::String* name() const; + + void setArguments(const NS::Array* arguments); + + void setControlDependencies(const NS::Array* controlDependencies); + + void setName(const NS::String* name); +}; +class FunctionStitchingGraph : public NS::Copying +{ +public: + static FunctionStitchingGraph* alloc(); + + NS::Array* attributes() const; + + NS::String* functionName() const; + + FunctionStitchingGraph* init(); + FunctionStitchingGraph* init(const NS::String* functionName, const NS::Array* nodes, const MTL::FunctionStitchingFunctionNode* outputNode, const NS::Array* attributes); + + NS::Array* nodes() const; + + FunctionStitchingFunctionNode* outputNode() const; + + void setAttributes(const NS::Array* attributes); + + void setFunctionName(const NS::String* functionName); + + void setNodes(const NS::Array* nodes); + + void setOutputNode(const MTL::FunctionStitchingFunctionNode* outputNode); +}; +class StitchedLibraryDescriptor : public NS::Copying +{ +public: + static StitchedLibraryDescriptor* alloc(); + + NS::Array* binaryArchives() const; + + NS::Array* functionGraphs() const; + + NS::Array* functions() const; + + StitchedLibraryDescriptor* init(); + + StitchedLibraryOptions options() const; + + void setBinaryArchives(const NS::Array* binaryArchives); + + void setFunctionGraphs(const NS::Array* functionGraphs); + + void setFunctions(const NS::Array* functions); + + void setOptions(MTL::StitchedLibraryOptions options); +}; + +} +_MTL_INLINE MTL::FunctionStitchingAttributeAlwaysInline* MTL::FunctionStitchingAttributeAlwaysInline::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLFunctionStitchingAttributeAlwaysInline)); +} + +_MTL_INLINE MTL::FunctionStitchingAttributeAlwaysInline* MTL::FunctionStitchingAttributeAlwaysInline::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::FunctionStitchingInputNode* MTL::FunctionStitchingInputNode::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLFunctionStitchingInputNode)); +} + +_MTL_INLINE NS::UInteger MTL::FunctionStitchingInputNode::argumentIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(argumentIndex)); +} + +_MTL_INLINE MTL::FunctionStitchingInputNode* MTL::FunctionStitchingInputNode::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::FunctionStitchingInputNode* MTL::FunctionStitchingInputNode::init(NS::UInteger argument) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(initWithArgumentIndex_), argument); +} + +_MTL_INLINE void MTL::FunctionStitchingInputNode::setArgumentIndex(NS::UInteger argumentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setArgumentIndex_), argumentIndex); +} + +_MTL_INLINE MTL::FunctionStitchingFunctionNode* MTL::FunctionStitchingFunctionNode::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLFunctionStitchingFunctionNode)); +} + +_MTL_INLINE NS::Array* MTL::FunctionStitchingFunctionNode::arguments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(arguments)); +} + +_MTL_INLINE NS::Array* MTL::FunctionStitchingFunctionNode::controlDependencies() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlDependencies)); +} + +_MTL_INLINE MTL::FunctionStitchingFunctionNode* MTL::FunctionStitchingFunctionNode::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::FunctionStitchingFunctionNode* MTL::FunctionStitchingFunctionNode::init(const NS::String* name, const NS::Array* arguments, const NS::Array* controlDependencies) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(initWithName_arguments_controlDependencies_), name, arguments, controlDependencies); +} + +_MTL_INLINE NS::String* MTL::FunctionStitchingFunctionNode::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE void MTL::FunctionStitchingFunctionNode::setArguments(const NS::Array* arguments) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setArguments_), arguments); +} + +_MTL_INLINE void MTL::FunctionStitchingFunctionNode::setControlDependencies(const NS::Array* controlDependencies) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlDependencies_), controlDependencies); +} + +_MTL_INLINE void MTL::FunctionStitchingFunctionNode::setName(const NS::String* name) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setName_), name); +} + +_MTL_INLINE MTL::FunctionStitchingGraph* MTL::FunctionStitchingGraph::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLFunctionStitchingGraph)); +} + +_MTL_INLINE NS::Array* MTL::FunctionStitchingGraph::attributes() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(attributes)); +} + +_MTL_INLINE NS::String* MTL::FunctionStitchingGraph::functionName() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionName)); +} + +_MTL_INLINE MTL::FunctionStitchingGraph* MTL::FunctionStitchingGraph::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::FunctionStitchingGraph* MTL::FunctionStitchingGraph::init(const NS::String* functionName, const NS::Array* nodes, const MTL::FunctionStitchingFunctionNode* outputNode, const NS::Array* attributes) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(initWithFunctionName_nodes_outputNode_attributes_), functionName, nodes, outputNode, attributes); +} + +_MTL_INLINE NS::Array* MTL::FunctionStitchingGraph::nodes() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(nodes)); +} + +_MTL_INLINE MTL::FunctionStitchingFunctionNode* MTL::FunctionStitchingGraph::outputNode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(outputNode)); +} + +_MTL_INLINE void MTL::FunctionStitchingGraph::setAttributes(const NS::Array* attributes) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAttributes_), attributes); +} + +_MTL_INLINE void MTL::FunctionStitchingGraph::setFunctionName(const NS::String* functionName) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctionName_), functionName); +} + +_MTL_INLINE void MTL::FunctionStitchingGraph::setNodes(const NS::Array* nodes) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setNodes_), nodes); +} + +_MTL_INLINE void MTL::FunctionStitchingGraph::setOutputNode(const MTL::FunctionStitchingFunctionNode* outputNode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOutputNode_), outputNode); +} + +_MTL_INLINE MTL::StitchedLibraryDescriptor* MTL::StitchedLibraryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLStitchedLibraryDescriptor)); +} + +_MTL_INLINE NS::Array* MTL::StitchedLibraryDescriptor::binaryArchives() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(binaryArchives)); +} + +_MTL_INLINE NS::Array* MTL::StitchedLibraryDescriptor::functionGraphs() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionGraphs)); +} + +_MTL_INLINE NS::Array* MTL::StitchedLibraryDescriptor::functions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functions)); +} + +_MTL_INLINE MTL::StitchedLibraryDescriptor* MTL::StitchedLibraryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::StitchedLibraryOptions MTL::StitchedLibraryDescriptor::options() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(options)); +} + +_MTL_INLINE void MTL::StitchedLibraryDescriptor::setBinaryArchives(const NS::Array* binaryArchives) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBinaryArchives_), binaryArchives); +} + +_MTL_INLINE void MTL::StitchedLibraryDescriptor::setFunctionGraphs(const NS::Array* functionGraphs) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctionGraphs_), functionGraphs); +} + +_MTL_INLINE void MTL::StitchedLibraryDescriptor::setFunctions(const NS::Array* functions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctions_), functions); +} + +_MTL_INLINE void MTL::StitchedLibraryDescriptor::setOptions(MTL::StitchedLibraryOptions options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOptions_), options); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLGPUAddress.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLGPUAddress.hpp new file mode 100644 index 00000000..fb9d61d5 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLGPUAddress.hpp @@ -0,0 +1,36 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLGPUAddress.hpp +// +// Copyright 2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#ifdef __METAL_VERSION__ + +#include + +#else + +#include + +#endif // __METAL_VERSION__ + +namespace MTL +{ + using GPUAddress = uint64_t; +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLHeaderBridge.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLHeaderBridge.hpp new file mode 100644 index 00000000..6a3a1422 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLHeaderBridge.hpp @@ -0,0 +1,3120 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLHeaderBridge.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once +#include "MTLPrivate.hpp" + +namespace MTL::Private::Class +{ + +_MTL_PRIVATE_DEF_CLS(MTL4AccelerationStructureBoundingBoxGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4AccelerationStructureCurveGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4AccelerationStructureDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4AccelerationStructureGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4AccelerationStructureMotionBoundingBoxGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4AccelerationStructureMotionCurveGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4AccelerationStructureMotionTriangleGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4AccelerationStructureTriangleGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4ArgumentTableDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4BinaryFunctionDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4CommandAllocatorDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4CommandBufferOptions); +_MTL_PRIVATE_DEF_CLS(MTL4CommandQueueDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4CommitOptions); +_MTL_PRIVATE_DEF_CLS(MTL4CompilerDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4CompilerTaskOptions); +_MTL_PRIVATE_DEF_CLS(MTL4ComputePipelineDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4CounterHeapDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4FunctionDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4IndirectInstanceAccelerationStructureDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4InstanceAccelerationStructureDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4LibraryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4LibraryFunctionDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4MachineLearningPipelineDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4MachineLearningPipelineReflection); +_MTL_PRIVATE_DEF_CLS(MTL4MeshRenderPipelineDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4PipelineDataSetSerializerDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4PipelineDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4PipelineOptions); +_MTL_PRIVATE_DEF_CLS(MTL4PipelineStageDynamicLinkingDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4PrimitiveAccelerationStructureDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4RenderPassDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4RenderPipelineBinaryFunctionsDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4RenderPipelineColorAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4RenderPipelineColorAttachmentDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTL4RenderPipelineDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4RenderPipelineDynamicLinkingDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4SpecializedFunctionDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4StaticLinkingDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4StitchedFunctionDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4TileRenderPipelineDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructureBoundingBoxGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructureCurveGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructureDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructureGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructureMotionBoundingBoxGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructureMotionCurveGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructureMotionTriangleGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructurePassDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructurePassSampleBufferAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructurePassSampleBufferAttachmentDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructureTriangleGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLArchitecture); +_MTL_PRIVATE_DEF_CLS(MTLArgument); +_MTL_PRIVATE_DEF_CLS(MTLArgumentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLArrayType); +_MTL_PRIVATE_DEF_CLS(MTLAttribute); +_MTL_PRIVATE_DEF_CLS(MTLAttributeDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAttributeDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLBinaryArchiveDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLBlitPassDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLBlitPassSampleBufferAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLBlitPassSampleBufferAttachmentDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLBufferLayoutDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLBufferLayoutDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLCaptureDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLCaptureManager); +_MTL_PRIVATE_DEF_CLS(MTLCommandBufferDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLCommandQueueDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLCompileOptions); +_MTL_PRIVATE_DEF_CLS(MTLComputePassDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLComputePassSampleBufferAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLComputePassSampleBufferAttachmentDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLComputePipelineDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLComputePipelineReflection); +_MTL_PRIVATE_DEF_CLS(MTLCounterSampleBufferDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLDepthStencilDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLFunctionConstant); +_MTL_PRIVATE_DEF_CLS(MTLFunctionConstantValues); +_MTL_PRIVATE_DEF_CLS(MTLFunctionDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLFunctionReflection); +_MTL_PRIVATE_DEF_CLS(MTLFunctionStitchingAttributeAlwaysInline); +_MTL_PRIVATE_DEF_CLS(MTLFunctionStitchingFunctionNode); +_MTL_PRIVATE_DEF_CLS(MTLFunctionStitchingGraph); +_MTL_PRIVATE_DEF_CLS(MTLFunctionStitchingInputNode); +_MTL_PRIVATE_DEF_CLS(MTLHeapDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLIOCommandQueueDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLIndirectCommandBufferDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLIndirectInstanceAccelerationStructureDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLInstanceAccelerationStructureDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLIntersectionFunctionDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLIntersectionFunctionTableDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLLinkedFunctions); +_MTL_PRIVATE_DEF_CLS(MTLLogStateDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLLogicalToPhysicalColorAttachmentMap); +_MTL_PRIVATE_DEF_CLS(MTLMeshRenderPipelineDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLMotionKeyframeData); +_MTL_PRIVATE_DEF_CLS(MTLPipelineBufferDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLPipelineBufferDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLPointerType); +_MTL_PRIVATE_DEF_CLS(MTLPrimitiveAccelerationStructureDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRasterizationRateLayerArray); +_MTL_PRIVATE_DEF_CLS(MTLRasterizationRateLayerDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRasterizationRateMapDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRasterizationRateSampleArray); +_MTL_PRIVATE_DEF_CLS(MTLRenderPassAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRenderPassColorAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRenderPassColorAttachmentDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLRenderPassDepthAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRenderPassDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRenderPassSampleBufferAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRenderPassSampleBufferAttachmentDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLRenderPassStencilAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRenderPipelineColorAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRenderPipelineColorAttachmentDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLRenderPipelineDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRenderPipelineFunctionsDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRenderPipelineReflection); +_MTL_PRIVATE_DEF_CLS(MTLResidencySetDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLResourceStatePassDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLResourceStatePassSampleBufferAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLResourceStatePassSampleBufferAttachmentDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLResourceViewPoolDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLSamplerDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLSharedEventHandle); +_MTL_PRIVATE_DEF_CLS(MTLSharedEventListener); +_MTL_PRIVATE_DEF_CLS(MTLSharedTextureHandle); +_MTL_PRIVATE_DEF_CLS(MTLStageInputOutputDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLStencilDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLStitchedLibraryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLStructMember); +_MTL_PRIVATE_DEF_CLS(MTLStructType); +_MTL_PRIVATE_DEF_CLS(MTLTensorDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLTensorExtents); +_MTL_PRIVATE_DEF_CLS(MTLTensorReferenceType); +_MTL_PRIVATE_DEF_CLS(MTLTextureDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLTextureReferenceType); +_MTL_PRIVATE_DEF_CLS(MTLTextureViewDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLTileRenderPipelineColorAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLTileRenderPipelineColorAttachmentDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLTileRenderPipelineDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLType); +_MTL_PRIVATE_DEF_CLS(MTLVertexAttribute); +_MTL_PRIVATE_DEF_CLS(MTLVertexAttributeDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLVertexAttributeDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLVertexBufferLayoutDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLVertexBufferLayoutDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLVertexDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLVisibleFunctionTableDescriptor); + +} + +namespace MTL::Private::Protocol +{ + +_MTL_PRIVATE_DEF_PRO(MTL4Archive); +_MTL_PRIVATE_DEF_PRO(MTL4ArgumentTable); +_MTL_PRIVATE_DEF_PRO(MTL4BinaryFunction); +_MTL_PRIVATE_DEF_PRO(MTL4CommandAllocator); +_MTL_PRIVATE_DEF_PRO(MTL4CommandBuffer); +_MTL_PRIVATE_DEF_PRO(MTL4CommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTL4CommandQueue); +_MTL_PRIVATE_DEF_PRO(MTL4CommitFeedback); +_MTL_PRIVATE_DEF_PRO(MTL4Compiler); +_MTL_PRIVATE_DEF_PRO(MTL4CompilerTask); +_MTL_PRIVATE_DEF_PRO(MTL4ComputeCommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTL4CounterHeap); +_MTL_PRIVATE_DEF_PRO(MTL4MachineLearningCommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTL4MachineLearningPipelineState); +_MTL_PRIVATE_DEF_PRO(MTL4PipelineDataSetSerializer); +_MTL_PRIVATE_DEF_PRO(MTL4RenderCommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTLAccelerationStructure); +_MTL_PRIVATE_DEF_PRO(MTLAccelerationStructureCommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTLAllocation); +_MTL_PRIVATE_DEF_PRO(MTLArgumentEncoder); +_MTL_PRIVATE_DEF_PRO(MTLBinaryArchive); +_MTL_PRIVATE_DEF_PRO(MTLBinding); +_MTL_PRIVATE_DEF_PRO(MTLBlitCommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTLBuffer); +_MTL_PRIVATE_DEF_PRO(MTLBufferBinding); +_MTL_PRIVATE_DEF_PRO(MTLCommandBuffer); +_MTL_PRIVATE_DEF_PRO(MTLCommandBufferEncoderInfo); +_MTL_PRIVATE_DEF_PRO(MTLCommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTLCommandQueue); +_MTL_PRIVATE_DEF_PRO(MTLComputeCommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTLComputePipelineState); +_MTL_PRIVATE_DEF_PRO(MTLCounter); +_MTL_PRIVATE_DEF_PRO(MTLCounterSampleBuffer); +_MTL_PRIVATE_DEF_PRO(MTLCounterSet); +_MTL_PRIVATE_DEF_PRO(MTLDepthStencilState); +_MTL_PRIVATE_DEF_PRO(MTLDevice); +_MTL_PRIVATE_DEF_PRO(MTLDrawable); +_MTL_PRIVATE_DEF_PRO(MTLDynamicLibrary); +_MTL_PRIVATE_DEF_PRO(MTLEvent); +_MTL_PRIVATE_DEF_PRO(MTLFence); +_MTL_PRIVATE_DEF_PRO(MTLFunction); +_MTL_PRIVATE_DEF_PRO(MTLFunctionHandle); +_MTL_PRIVATE_DEF_PRO(MTLFunctionLog); +_MTL_PRIVATE_DEF_PRO(MTLFunctionLogDebugLocation); +_MTL_PRIVATE_DEF_PRO(MTLFunctionStitchingAttribute); +_MTL_PRIVATE_DEF_PRO(MTLFunctionStitchingNode); +_MTL_PRIVATE_DEF_PRO(MTLHeap); +_MTL_PRIVATE_DEF_PRO(MTLIOCommandBuffer); +_MTL_PRIVATE_DEF_PRO(MTLIOCommandQueue); +_MTL_PRIVATE_DEF_PRO(MTLIOFileHandle); +_MTL_PRIVATE_DEF_PRO(MTLIOScratchBuffer); +_MTL_PRIVATE_DEF_PRO(MTLIOScratchBufferAllocator); +_MTL_PRIVATE_DEF_PRO(MTLIndirectCommandBuffer); +_MTL_PRIVATE_DEF_PRO(MTLIndirectComputeCommand); +_MTL_PRIVATE_DEF_PRO(MTLIndirectRenderCommand); +_MTL_PRIVATE_DEF_PRO(MTLIntersectionFunctionTable); +_MTL_PRIVATE_DEF_PRO(MTLLibrary); +_MTL_PRIVATE_DEF_PRO(MTLLogContainer); +_MTL_PRIVATE_DEF_PRO(MTLLogState); +_MTL_PRIVATE_DEF_PRO(MTLObjectPayloadBinding); +_MTL_PRIVATE_DEF_PRO(MTLParallelRenderCommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTLRasterizationRateMap); +_MTL_PRIVATE_DEF_PRO(MTLRenderCommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTLRenderPipelineState); +_MTL_PRIVATE_DEF_PRO(MTLResidencySet); +_MTL_PRIVATE_DEF_PRO(MTLResource); +_MTL_PRIVATE_DEF_PRO(MTLResourceStateCommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTLResourceViewPool); +_MTL_PRIVATE_DEF_PRO(MTLSamplerState); +_MTL_PRIVATE_DEF_PRO(MTLSharedEvent); +_MTL_PRIVATE_DEF_PRO(MTLTensor); +_MTL_PRIVATE_DEF_PRO(MTLTensorBinding); +_MTL_PRIVATE_DEF_PRO(MTLTexture); +_MTL_PRIVATE_DEF_PRO(MTLTextureBinding); +_MTL_PRIVATE_DEF_PRO(MTLTextureViewPool); +_MTL_PRIVATE_DEF_PRO(MTLThreadgroupBinding); +_MTL_PRIVATE_DEF_PRO(MTLVisibleFunctionTable); + +} + +namespace MTL::Private::Selector +{ + +_MTL_PRIVATE_DEF_SEL(GPUEndTime, + "GPUEndTime"); +_MTL_PRIVATE_DEF_SEL(GPUStartTime, + "GPUStartTime"); +_MTL_PRIVATE_DEF_SEL(URL, + "URL"); +_MTL_PRIVATE_DEF_SEL(accelerationStructureCommandEncoder, + "accelerationStructureCommandEncoder"); +_MTL_PRIVATE_DEF_SEL(accelerationStructureCommandEncoderWithDescriptor_, + "accelerationStructureCommandEncoderWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(accelerationStructurePassDescriptor, + "accelerationStructurePassDescriptor"); +_MTL_PRIVATE_DEF_SEL(accelerationStructureSizesWithDescriptor_, + "accelerationStructureSizesWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(access, + "access"); +_MTL_PRIVATE_DEF_SEL(addAllocation_, + "addAllocation:"); +_MTL_PRIVATE_DEF_SEL(addAllocations_count_, + "addAllocations:count:"); +_MTL_PRIVATE_DEF_SEL(addBarrier, + "addBarrier"); +_MTL_PRIVATE_DEF_SEL(addCompletedHandler_, + "addCompletedHandler:"); +_MTL_PRIVATE_DEF_SEL(addComputePipelineFunctionsWithDescriptor_error_, + "addComputePipelineFunctionsWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(addDebugMarker_range_, + "addDebugMarker:range:"); +_MTL_PRIVATE_DEF_SEL(addFeedbackHandler_, + "addFeedbackHandler:"); +_MTL_PRIVATE_DEF_SEL(addFunctionWithDescriptor_library_error_, + "addFunctionWithDescriptor:library:error:"); +_MTL_PRIVATE_DEF_SEL(addLibraryWithDescriptor_error_, + "addLibraryWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(addLogHandler_, + "addLogHandler:"); +_MTL_PRIVATE_DEF_SEL(addMeshRenderPipelineFunctionsWithDescriptor_error_, + "addMeshRenderPipelineFunctionsWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(addPresentedHandler_, + "addPresentedHandler:"); +_MTL_PRIVATE_DEF_SEL(addRenderPipelineFunctionsWithDescriptor_error_, + "addRenderPipelineFunctionsWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(addResidencySet_, + "addResidencySet:"); +_MTL_PRIVATE_DEF_SEL(addResidencySets_count_, + "addResidencySets:count:"); +_MTL_PRIVATE_DEF_SEL(addScheduledHandler_, + "addScheduledHandler:"); +_MTL_PRIVATE_DEF_SEL(addTileRenderPipelineFunctionsWithDescriptor_error_, + "addTileRenderPipelineFunctionsWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(alignment, + "alignment"); +_MTL_PRIVATE_DEF_SEL(allAllocations, + "allAllocations"); +_MTL_PRIVATE_DEF_SEL(allocatedSize, + "allocatedSize"); +_MTL_PRIVATE_DEF_SEL(allocationCount, + "allocationCount"); +_MTL_PRIVATE_DEF_SEL(allowDuplicateIntersectionFunctionInvocation, + "allowDuplicateIntersectionFunctionInvocation"); +_MTL_PRIVATE_DEF_SEL(allowGPUOptimizedContents, + "allowGPUOptimizedContents"); +_MTL_PRIVATE_DEF_SEL(allowReferencingUndefinedSymbols, + "allowReferencingUndefinedSymbols"); +_MTL_PRIVATE_DEF_SEL(alphaBlendOperation, + "alphaBlendOperation"); +_MTL_PRIVATE_DEF_SEL(alphaToCoverageState, + "alphaToCoverageState"); +_MTL_PRIVATE_DEF_SEL(alphaToOneState, + "alphaToOneState"); +_MTL_PRIVATE_DEF_SEL(architecture, + "architecture"); +_MTL_PRIVATE_DEF_SEL(areBarycentricCoordsSupported, + "areBarycentricCoordsSupported"); +_MTL_PRIVATE_DEF_SEL(areProgrammableSamplePositionsSupported, + "areProgrammableSamplePositionsSupported"); +_MTL_PRIVATE_DEF_SEL(areRasterOrderGroupsSupported, + "areRasterOrderGroupsSupported"); +_MTL_PRIVATE_DEF_SEL(argumentBuffersSupport, + "argumentBuffersSupport"); +_MTL_PRIVATE_DEF_SEL(argumentDescriptor, + "argumentDescriptor"); +_MTL_PRIVATE_DEF_SEL(argumentIndex, + "argumentIndex"); +_MTL_PRIVATE_DEF_SEL(argumentIndexStride, + "argumentIndexStride"); +_MTL_PRIVATE_DEF_SEL(arguments, + "arguments"); +_MTL_PRIVATE_DEF_SEL(arrayLength, + "arrayLength"); +_MTL_PRIVATE_DEF_SEL(arrayType, + "arrayType"); +_MTL_PRIVATE_DEF_SEL(attributeIndex, + "attributeIndex"); +_MTL_PRIVATE_DEF_SEL(attributeType, + "attributeType"); +_MTL_PRIVATE_DEF_SEL(attributes, + "attributes"); +_MTL_PRIVATE_DEF_SEL(backFaceStencil, + "backFaceStencil"); +_MTL_PRIVATE_DEF_SEL(barrierAfterEncoderStages_beforeEncoderStages_visibilityOptions_, + "barrierAfterEncoderStages:beforeEncoderStages:visibilityOptions:"); +_MTL_PRIVATE_DEF_SEL(barrierAfterQueueStages_beforeStages_, + "barrierAfterQueueStages:beforeStages:"); +_MTL_PRIVATE_DEF_SEL(barrierAfterQueueStages_beforeStages_visibilityOptions_, + "barrierAfterQueueStages:beforeStages:visibilityOptions:"); +_MTL_PRIVATE_DEF_SEL(barrierAfterStages_beforeQueueStages_visibilityOptions_, + "barrierAfterStages:beforeQueueStages:visibilityOptions:"); +_MTL_PRIVATE_DEF_SEL(baseResourceID, + "baseResourceID"); +_MTL_PRIVATE_DEF_SEL(beginCommandBufferWithAllocator_, + "beginCommandBufferWithAllocator:"); +_MTL_PRIVATE_DEF_SEL(beginCommandBufferWithAllocator_options_, + "beginCommandBufferWithAllocator:options:"); +_MTL_PRIVATE_DEF_SEL(binaryArchives, + "binaryArchives"); +_MTL_PRIVATE_DEF_SEL(binaryFunctions, + "binaryFunctions"); +_MTL_PRIVATE_DEF_SEL(binaryLinkedFunctions, + "binaryLinkedFunctions"); +_MTL_PRIVATE_DEF_SEL(bindings, + "bindings"); +_MTL_PRIVATE_DEF_SEL(blendingState, + "blendingState"); +_MTL_PRIVATE_DEF_SEL(blitCommandEncoder, + "blitCommandEncoder"); +_MTL_PRIVATE_DEF_SEL(blitCommandEncoderWithDescriptor_, + "blitCommandEncoderWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(blitPassDescriptor, + "blitPassDescriptor"); +_MTL_PRIVATE_DEF_SEL(borderColor, + "borderColor"); +_MTL_PRIVATE_DEF_SEL(boundingBoxBuffer, + "boundingBoxBuffer"); +_MTL_PRIVATE_DEF_SEL(boundingBoxBufferOffset, + "boundingBoxBufferOffset"); +_MTL_PRIVATE_DEF_SEL(boundingBoxBuffers, + "boundingBoxBuffers"); +_MTL_PRIVATE_DEF_SEL(boundingBoxCount, + "boundingBoxCount"); +_MTL_PRIVATE_DEF_SEL(boundingBoxStride, + "boundingBoxStride"); +_MTL_PRIVATE_DEF_SEL(buffer, + "buffer"); +_MTL_PRIVATE_DEF_SEL(bufferAlignment, + "bufferAlignment"); +_MTL_PRIVATE_DEF_SEL(bufferBytesPerRow, + "bufferBytesPerRow"); +_MTL_PRIVATE_DEF_SEL(bufferDataSize, + "bufferDataSize"); +_MTL_PRIVATE_DEF_SEL(bufferDataType, + "bufferDataType"); +_MTL_PRIVATE_DEF_SEL(bufferIndex, + "bufferIndex"); +_MTL_PRIVATE_DEF_SEL(bufferOffset, + "bufferOffset"); +_MTL_PRIVATE_DEF_SEL(bufferPointerType, + "bufferPointerType"); +_MTL_PRIVATE_DEF_SEL(bufferSize, + "bufferSize"); +_MTL_PRIVATE_DEF_SEL(bufferStructType, + "bufferStructType"); +_MTL_PRIVATE_DEF_SEL(buffers, + "buffers"); +_MTL_PRIVATE_DEF_SEL(buildAccelerationStructure_descriptor_scratchBuffer_, + "buildAccelerationStructure:descriptor:scratchBuffer:"); +_MTL_PRIVATE_DEF_SEL(buildAccelerationStructure_descriptor_scratchBuffer_scratchBufferOffset_, + "buildAccelerationStructure:descriptor:scratchBuffer:scratchBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(captureObject, + "captureObject"); +_MTL_PRIVATE_DEF_SEL(clearBarrier, + "clearBarrier"); +_MTL_PRIVATE_DEF_SEL(clearColor, + "clearColor"); +_MTL_PRIVATE_DEF_SEL(clearDepth, + "clearDepth"); +_MTL_PRIVATE_DEF_SEL(clearStencil, + "clearStencil"); +_MTL_PRIVATE_DEF_SEL(colorAttachmentMappingState, + "colorAttachmentMappingState"); +_MTL_PRIVATE_DEF_SEL(colorAttachments, + "colorAttachments"); +_MTL_PRIVATE_DEF_SEL(column, + "column"); +_MTL_PRIVATE_DEF_SEL(commandBuffer, + "commandBuffer"); +_MTL_PRIVATE_DEF_SEL(commandBufferWithDescriptor_, + "commandBufferWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(commandBufferWithUnretainedReferences, + "commandBufferWithUnretainedReferences"); +_MTL_PRIVATE_DEF_SEL(commandQueue, + "commandQueue"); +_MTL_PRIVATE_DEF_SEL(commandTypes, + "commandTypes"); +_MTL_PRIVATE_DEF_SEL(commit, + "commit"); +_MTL_PRIVATE_DEF_SEL(commit_count_, + "commit:count:"); +_MTL_PRIVATE_DEF_SEL(commit_count_options_, + "commit:count:options:"); +_MTL_PRIVATE_DEF_SEL(compareFunction, + "compareFunction"); +_MTL_PRIVATE_DEF_SEL(compileSymbolVisibility, + "compileSymbolVisibility"); +_MTL_PRIVATE_DEF_SEL(compiler, + "compiler"); +_MTL_PRIVATE_DEF_SEL(compressionType, + "compressionType"); +_MTL_PRIVATE_DEF_SEL(computeCommandEncoder, + "computeCommandEncoder"); +_MTL_PRIVATE_DEF_SEL(computeCommandEncoderWithDescriptor_, + "computeCommandEncoderWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(computeCommandEncoderWithDispatchType_, + "computeCommandEncoderWithDispatchType:"); +_MTL_PRIVATE_DEF_SEL(computeFunction, + "computeFunction"); +_MTL_PRIVATE_DEF_SEL(computeFunctionDescriptor, + "computeFunctionDescriptor"); +_MTL_PRIVATE_DEF_SEL(computePassDescriptor, + "computePassDescriptor"); +_MTL_PRIVATE_DEF_SEL(concurrentDispatchThreadgroups_threadsPerThreadgroup_, + "concurrentDispatchThreadgroups:threadsPerThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(concurrentDispatchThreads_threadsPerThreadgroup_, + "concurrentDispatchThreads:threadsPerThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(configuration, + "configuration"); +_MTL_PRIVATE_DEF_SEL(constantBlockAlignment, + "constantBlockAlignment"); +_MTL_PRIVATE_DEF_SEL(constantDataAtIndex_, + "constantDataAtIndex:"); +_MTL_PRIVATE_DEF_SEL(constantValues, + "constantValues"); +_MTL_PRIVATE_DEF_SEL(containsAllocation_, + "containsAllocation:"); +_MTL_PRIVATE_DEF_SEL(contents, + "contents"); +_MTL_PRIVATE_DEF_SEL(controlDependencies, + "controlDependencies"); +_MTL_PRIVATE_DEF_SEL(controlPointBuffer, + "controlPointBuffer"); +_MTL_PRIVATE_DEF_SEL(controlPointBufferOffset, + "controlPointBufferOffset"); +_MTL_PRIVATE_DEF_SEL(controlPointBuffers, + "controlPointBuffers"); +_MTL_PRIVATE_DEF_SEL(controlPointCount, + "controlPointCount"); +_MTL_PRIVATE_DEF_SEL(controlPointFormat, + "controlPointFormat"); +_MTL_PRIVATE_DEF_SEL(controlPointStride, + "controlPointStride"); +_MTL_PRIVATE_DEF_SEL(convertSparsePixelRegions_toTileRegions_withTileSize_alignmentMode_numRegions_, + "convertSparsePixelRegions:toTileRegions:withTileSize:alignmentMode:numRegions:"); +_MTL_PRIVATE_DEF_SEL(convertSparseTileRegions_toPixelRegions_withTileSize_numRegions_, + "convertSparseTileRegions:toPixelRegions:withTileSize:numRegions:"); +_MTL_PRIVATE_DEF_SEL(copyAccelerationStructure_toAccelerationStructure_, + "copyAccelerationStructure:toAccelerationStructure:"); +_MTL_PRIVATE_DEF_SEL(copyAndCompactAccelerationStructure_toAccelerationStructure_, + "copyAndCompactAccelerationStructure:toAccelerationStructure:"); +_MTL_PRIVATE_DEF_SEL(copyBufferMappingsFromBuffer_toBuffer_operations_count_, + "copyBufferMappingsFromBuffer:toBuffer:operations:count:"); +_MTL_PRIVATE_DEF_SEL(copyFromBuffer_sourceOffset_sourceBytesPerRow_sourceBytesPerImage_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_, + "copyFromBuffer:sourceOffset:sourceBytesPerRow:sourceBytesPerImage:sourceSize:toTexture:destinationSlice:destinationLevel:destinationOrigin:"); +_MTL_PRIVATE_DEF_SEL(copyFromBuffer_sourceOffset_sourceBytesPerRow_sourceBytesPerImage_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_options_, + "copyFromBuffer:sourceOffset:sourceBytesPerRow:sourceBytesPerImage:sourceSize:toTexture:destinationSlice:destinationLevel:destinationOrigin:options:"); +_MTL_PRIVATE_DEF_SEL(copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_, + "copyFromBuffer:sourceOffset:toBuffer:destinationOffset:size:"); +_MTL_PRIVATE_DEF_SEL(copyFromTensor_sourceOrigin_sourceDimensions_toTensor_destinationOrigin_destinationDimensions_, + "copyFromTensor:sourceOrigin:sourceDimensions:toTensor:destinationOrigin:destinationDimensions:"); +_MTL_PRIVATE_DEF_SEL(copyFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toBuffer_destinationOffset_destinationBytesPerRow_destinationBytesPerImage_, + "copyFromTexture:sourceSlice:sourceLevel:sourceOrigin:sourceSize:toBuffer:destinationOffset:destinationBytesPerRow:destinationBytesPerImage:"); +_MTL_PRIVATE_DEF_SEL(copyFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toBuffer_destinationOffset_destinationBytesPerRow_destinationBytesPerImage_options_, + "copyFromTexture:sourceSlice:sourceLevel:sourceOrigin:sourceSize:toBuffer:destinationOffset:destinationBytesPerRow:destinationBytesPerImage:options:"); +_MTL_PRIVATE_DEF_SEL(copyFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_, + "copyFromTexture:sourceSlice:sourceLevel:sourceOrigin:sourceSize:toTexture:destinationSlice:destinationLevel:destinationOrigin:"); +_MTL_PRIVATE_DEF_SEL(copyFromTexture_sourceSlice_sourceLevel_toTexture_destinationSlice_destinationLevel_sliceCount_levelCount_, + "copyFromTexture:sourceSlice:sourceLevel:toTexture:destinationSlice:destinationLevel:sliceCount:levelCount:"); +_MTL_PRIVATE_DEF_SEL(copyFromTexture_toTexture_, + "copyFromTexture:toTexture:"); +_MTL_PRIVATE_DEF_SEL(copyIndirectCommandBuffer_sourceRange_destination_destinationIndex_, + "copyIndirectCommandBuffer:sourceRange:destination:destinationIndex:"); +_MTL_PRIVATE_DEF_SEL(copyParameterDataToBuffer_offset_, + "copyParameterDataToBuffer:offset:"); +_MTL_PRIVATE_DEF_SEL(copyResourceViewsFromPool_sourceRange_destinationIndex_, + "copyResourceViewsFromPool:sourceRange:destinationIndex:"); +_MTL_PRIVATE_DEF_SEL(copyStatusToBuffer_offset_, + "copyStatusToBuffer:offset:"); +_MTL_PRIVATE_DEF_SEL(copyTextureMappingsFromTexture_toTexture_operations_count_, + "copyTextureMappingsFromTexture:toTexture:operations:count:"); +_MTL_PRIVATE_DEF_SEL(count, + "count"); +_MTL_PRIVATE_DEF_SEL(counterSet, + "counterSet"); +_MTL_PRIVATE_DEF_SEL(counterSets, + "counterSets"); +_MTL_PRIVATE_DEF_SEL(counters, + "counters"); +_MTL_PRIVATE_DEF_SEL(cpuCacheMode, + "cpuCacheMode"); +_MTL_PRIVATE_DEF_SEL(currentAllocatedSize, + "currentAllocatedSize"); +_MTL_PRIVATE_DEF_SEL(curveBasis, + "curveBasis"); +_MTL_PRIVATE_DEF_SEL(curveEndCaps, + "curveEndCaps"); +_MTL_PRIVATE_DEF_SEL(curveType, + "curveType"); +_MTL_PRIVATE_DEF_SEL(data, + "data"); +_MTL_PRIVATE_DEF_SEL(dataSize, + "dataSize"); +_MTL_PRIVATE_DEF_SEL(dataType, + "dataType"); +_MTL_PRIVATE_DEF_SEL(dealloc, + "dealloc"); +_MTL_PRIVATE_DEF_SEL(debugLocation, + "debugLocation"); +_MTL_PRIVATE_DEF_SEL(debugSignposts, + "debugSignposts"); +_MTL_PRIVATE_DEF_SEL(defaultCaptureScope, + "defaultCaptureScope"); +_MTL_PRIVATE_DEF_SEL(defaultRasterSampleCount, + "defaultRasterSampleCount"); +_MTL_PRIVATE_DEF_SEL(depth, + "depth"); +_MTL_PRIVATE_DEF_SEL(depthAttachment, + "depthAttachment"); +_MTL_PRIVATE_DEF_SEL(depthAttachmentPixelFormat, + "depthAttachmentPixelFormat"); +_MTL_PRIVATE_DEF_SEL(depthCompareFunction, + "depthCompareFunction"); +_MTL_PRIVATE_DEF_SEL(depthFailureOperation, + "depthFailureOperation"); +_MTL_PRIVATE_DEF_SEL(depthPlane, + "depthPlane"); +_MTL_PRIVATE_DEF_SEL(depthResolveFilter, + "depthResolveFilter"); +_MTL_PRIVATE_DEF_SEL(depthStencilPassOperation, + "depthStencilPassOperation"); +_MTL_PRIVATE_DEF_SEL(descriptor, + "descriptor"); +_MTL_PRIVATE_DEF_SEL(destination, + "destination"); +_MTL_PRIVATE_DEF_SEL(destinationAlphaBlendFactor, + "destinationAlphaBlendFactor"); +_MTL_PRIVATE_DEF_SEL(destinationRGBBlendFactor, + "destinationRGBBlendFactor"); +_MTL_PRIVATE_DEF_SEL(device, + "device"); +_MTL_PRIVATE_DEF_SEL(didModifyRange_, + "didModifyRange:"); +_MTL_PRIVATE_DEF_SEL(dimensions, + "dimensions"); +_MTL_PRIVATE_DEF_SEL(dispatchNetworkWithIntermediatesHeap_, + "dispatchNetworkWithIntermediatesHeap:"); +_MTL_PRIVATE_DEF_SEL(dispatchQueue, + "dispatchQueue"); +_MTL_PRIVATE_DEF_SEL(dispatchThreadgroups_threadsPerThreadgroup_, + "dispatchThreadgroups:threadsPerThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(dispatchThreadgroupsWithIndirectBuffer_indirectBufferOffset_threadsPerThreadgroup_, + "dispatchThreadgroupsWithIndirectBuffer:indirectBufferOffset:threadsPerThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(dispatchThreadgroupsWithIndirectBuffer_threadsPerThreadgroup_, + "dispatchThreadgroupsWithIndirectBuffer:threadsPerThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(dispatchThreads_threadsPerThreadgroup_, + "dispatchThreads:threadsPerThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(dispatchThreadsPerTile_, + "dispatchThreadsPerTile:"); +_MTL_PRIVATE_DEF_SEL(dispatchThreadsWithIndirectBuffer_, + "dispatchThreadsWithIndirectBuffer:"); +_MTL_PRIVATE_DEF_SEL(dispatchType, + "dispatchType"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPatches_patchIndexBuffer_patchIndexBufferOffset_controlPointIndexBuffer_controlPointIndexBufferOffset_indirectBuffer_indirectBufferOffset_, + "drawIndexedPatches:patchIndexBuffer:patchIndexBufferOffset:controlPointIndexBuffer:controlPointIndexBufferOffset:indirectBuffer:indirectBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPatches_patchStart_patchCount_patchIndexBuffer_patchIndexBufferOffset_controlPointIndexBuffer_controlPointIndexBufferOffset_instanceCount_baseInstance_, + "drawIndexedPatches:patchStart:patchCount:patchIndexBuffer:patchIndexBufferOffset:controlPointIndexBuffer:controlPointIndexBufferOffset:instanceCount:baseInstance:"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPatches_patchStart_patchCount_patchIndexBuffer_patchIndexBufferOffset_controlPointIndexBuffer_controlPointIndexBufferOffset_instanceCount_baseInstance_tessellationFactorBuffer_tessellationFactorBufferOffset_tessellationFactorBufferInstanceStride_, + "drawIndexedPatches:patchStart:patchCount:patchIndexBuffer:patchIndexBufferOffset:controlPointIndexBuffer:controlPointIndexBufferOffset:instanceCount:baseInstance:tessellationFactorBuffer:tessellationFactorBufferOffset:tessellationFactorBufferInstanceStride:"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferLength_, + "drawIndexedPrimitives:indexCount:indexType:indexBuffer:indexBufferLength:"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferLength_instanceCount_, + "drawIndexedPrimitives:indexCount:indexType:indexBuffer:indexBufferLength:instanceCount:"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferLength_instanceCount_baseVertex_baseInstance_, + "drawIndexedPrimitives:indexCount:indexType:indexBuffer:indexBufferLength:instanceCount:baseVertex:baseInstance:"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferOffset_, + "drawIndexedPrimitives:indexCount:indexType:indexBuffer:indexBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferOffset_instanceCount_, + "drawIndexedPrimitives:indexCount:indexType:indexBuffer:indexBufferOffset:instanceCount:"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferOffset_instanceCount_baseVertex_baseInstance_, + "drawIndexedPrimitives:indexCount:indexType:indexBuffer:indexBufferOffset:instanceCount:baseVertex:baseInstance:"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPrimitives_indexType_indexBuffer_indexBufferLength_indirectBuffer_, + "drawIndexedPrimitives:indexType:indexBuffer:indexBufferLength:indirectBuffer:"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPrimitives_indexType_indexBuffer_indexBufferOffset_indirectBuffer_indirectBufferOffset_, + "drawIndexedPrimitives:indexType:indexBuffer:indexBufferOffset:indirectBuffer:indirectBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(drawMeshThreadgroups_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_, + "drawMeshThreadgroups:threadsPerObjectThreadgroup:threadsPerMeshThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(drawMeshThreadgroupsWithIndirectBuffer_indirectBufferOffset_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_, + "drawMeshThreadgroupsWithIndirectBuffer:indirectBufferOffset:threadsPerObjectThreadgroup:threadsPerMeshThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(drawMeshThreadgroupsWithIndirectBuffer_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_, + "drawMeshThreadgroupsWithIndirectBuffer:threadsPerObjectThreadgroup:threadsPerMeshThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(drawMeshThreads_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_, + "drawMeshThreads:threadsPerObjectThreadgroup:threadsPerMeshThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(drawPatches_patchIndexBuffer_patchIndexBufferOffset_indirectBuffer_indirectBufferOffset_, + "drawPatches:patchIndexBuffer:patchIndexBufferOffset:indirectBuffer:indirectBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(drawPatches_patchStart_patchCount_patchIndexBuffer_patchIndexBufferOffset_instanceCount_baseInstance_, + "drawPatches:patchStart:patchCount:patchIndexBuffer:patchIndexBufferOffset:instanceCount:baseInstance:"); +_MTL_PRIVATE_DEF_SEL(drawPatches_patchStart_patchCount_patchIndexBuffer_patchIndexBufferOffset_instanceCount_baseInstance_tessellationFactorBuffer_tessellationFactorBufferOffset_tessellationFactorBufferInstanceStride_, + "drawPatches:patchStart:patchCount:patchIndexBuffer:patchIndexBufferOffset:instanceCount:baseInstance:tessellationFactorBuffer:tessellationFactorBufferOffset:tessellationFactorBufferInstanceStride:"); +_MTL_PRIVATE_DEF_SEL(drawPrimitives_indirectBuffer_, + "drawPrimitives:indirectBuffer:"); +_MTL_PRIVATE_DEF_SEL(drawPrimitives_indirectBuffer_indirectBufferOffset_, + "drawPrimitives:indirectBuffer:indirectBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(drawPrimitives_vertexStart_vertexCount_, + "drawPrimitives:vertexStart:vertexCount:"); +_MTL_PRIVATE_DEF_SEL(drawPrimitives_vertexStart_vertexCount_instanceCount_, + "drawPrimitives:vertexStart:vertexCount:instanceCount:"); +_MTL_PRIVATE_DEF_SEL(drawPrimitives_vertexStart_vertexCount_instanceCount_baseInstance_, + "drawPrimitives:vertexStart:vertexCount:instanceCount:baseInstance:"); +_MTL_PRIVATE_DEF_SEL(drawableID, + "drawableID"); +_MTL_PRIVATE_DEF_SEL(elementArrayType, + "elementArrayType"); +_MTL_PRIVATE_DEF_SEL(elementIsArgumentBuffer, + "elementIsArgumentBuffer"); +_MTL_PRIVATE_DEF_SEL(elementPointerType, + "elementPointerType"); +_MTL_PRIVATE_DEF_SEL(elementStructType, + "elementStructType"); +_MTL_PRIVATE_DEF_SEL(elementTensorReferenceType, + "elementTensorReferenceType"); +_MTL_PRIVATE_DEF_SEL(elementTextureReferenceType, + "elementTextureReferenceType"); +_MTL_PRIVATE_DEF_SEL(elementType, + "elementType"); +_MTL_PRIVATE_DEF_SEL(enableLogging, + "enableLogging"); +_MTL_PRIVATE_DEF_SEL(encodeSignalEvent_value_, + "encodeSignalEvent:value:"); +_MTL_PRIVATE_DEF_SEL(encodeWaitForEvent_value_, + "encodeWaitForEvent:value:"); +_MTL_PRIVATE_DEF_SEL(encodedLength, + "encodedLength"); +_MTL_PRIVATE_DEF_SEL(encoderLabel, + "encoderLabel"); +_MTL_PRIVATE_DEF_SEL(endCommandBuffer, + "endCommandBuffer"); +_MTL_PRIVATE_DEF_SEL(endEncoding, + "endEncoding"); +_MTL_PRIVATE_DEF_SEL(endOfEncoderSampleIndex, + "endOfEncoderSampleIndex"); +_MTL_PRIVATE_DEF_SEL(endOfFragmentSampleIndex, + "endOfFragmentSampleIndex"); +_MTL_PRIVATE_DEF_SEL(endOfVertexSampleIndex, + "endOfVertexSampleIndex"); +_MTL_PRIVATE_DEF_SEL(endResidency, + "endResidency"); +_MTL_PRIVATE_DEF_SEL(enqueue, + "enqueue"); +_MTL_PRIVATE_DEF_SEL(enqueueBarrier, + "enqueueBarrier"); +_MTL_PRIVATE_DEF_SEL(error, + "error"); +_MTL_PRIVATE_DEF_SEL(errorOptions, + "errorOptions"); +_MTL_PRIVATE_DEF_SEL(errorState, + "errorState"); +_MTL_PRIVATE_DEF_SEL(executeCommandsInBuffer_indirectBuffer_, + "executeCommandsInBuffer:indirectBuffer:"); +_MTL_PRIVATE_DEF_SEL(executeCommandsInBuffer_indirectBuffer_indirectBufferOffset_, + "executeCommandsInBuffer:indirectBuffer:indirectBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(executeCommandsInBuffer_withRange_, + "executeCommandsInBuffer:withRange:"); +_MTL_PRIVATE_DEF_SEL(extentAtDimensionIndex_, + "extentAtDimensionIndex:"); +_MTL_PRIVATE_DEF_SEL(fastMathEnabled, + "fastMathEnabled"); +_MTL_PRIVATE_DEF_SEL(feedbackQueue, + "feedbackQueue"); +_MTL_PRIVATE_DEF_SEL(fillBuffer_range_value_, + "fillBuffer:range:value:"); +_MTL_PRIVATE_DEF_SEL(firstMipmapInTail, + "firstMipmapInTail"); +_MTL_PRIVATE_DEF_SEL(format, + "format"); +_MTL_PRIVATE_DEF_SEL(fragmentAdditionalBinaryFunctions, + "fragmentAdditionalBinaryFunctions"); +_MTL_PRIVATE_DEF_SEL(fragmentArguments, + "fragmentArguments"); +_MTL_PRIVATE_DEF_SEL(fragmentBindings, + "fragmentBindings"); +_MTL_PRIVATE_DEF_SEL(fragmentBuffers, + "fragmentBuffers"); +_MTL_PRIVATE_DEF_SEL(fragmentFunction, + "fragmentFunction"); +_MTL_PRIVATE_DEF_SEL(fragmentFunctionDescriptor, + "fragmentFunctionDescriptor"); +_MTL_PRIVATE_DEF_SEL(fragmentLinkedFunctions, + "fragmentLinkedFunctions"); +_MTL_PRIVATE_DEF_SEL(fragmentLinkingDescriptor, + "fragmentLinkingDescriptor"); +_MTL_PRIVATE_DEF_SEL(fragmentPreloadedLibraries, + "fragmentPreloadedLibraries"); +_MTL_PRIVATE_DEF_SEL(fragmentStaticLinkingDescriptor, + "fragmentStaticLinkingDescriptor"); +_MTL_PRIVATE_DEF_SEL(frontFaceStencil, + "frontFaceStencil"); +_MTL_PRIVATE_DEF_SEL(function, + "function"); +_MTL_PRIVATE_DEF_SEL(functionConstantsDictionary, + "functionConstantsDictionary"); +_MTL_PRIVATE_DEF_SEL(functionCount, + "functionCount"); +_MTL_PRIVATE_DEF_SEL(functionDescriptor, + "functionDescriptor"); +_MTL_PRIVATE_DEF_SEL(functionDescriptors, + "functionDescriptors"); +_MTL_PRIVATE_DEF_SEL(functionGraph, + "functionGraph"); +_MTL_PRIVATE_DEF_SEL(functionGraphs, + "functionGraphs"); +_MTL_PRIVATE_DEF_SEL(functionHandleWithBinaryFunction_, + "functionHandleWithBinaryFunction:"); +_MTL_PRIVATE_DEF_SEL(functionHandleWithBinaryFunction_stage_, + "functionHandleWithBinaryFunction:stage:"); +_MTL_PRIVATE_DEF_SEL(functionHandleWithFunction_, + "functionHandleWithFunction:"); +_MTL_PRIVATE_DEF_SEL(functionHandleWithFunction_stage_, + "functionHandleWithFunction:stage:"); +_MTL_PRIVATE_DEF_SEL(functionHandleWithName_, + "functionHandleWithName:"); +_MTL_PRIVATE_DEF_SEL(functionHandleWithName_stage_, + "functionHandleWithName:stage:"); +_MTL_PRIVATE_DEF_SEL(functionName, + "functionName"); +_MTL_PRIVATE_DEF_SEL(functionNames, + "functionNames"); +_MTL_PRIVATE_DEF_SEL(functionType, + "functionType"); +_MTL_PRIVATE_DEF_SEL(functions, + "functions"); +_MTL_PRIVATE_DEF_SEL(generateMipmapsForTexture_, + "generateMipmapsForTexture:"); +_MTL_PRIVATE_DEF_SEL(geometryDescriptors, + "geometryDescriptors"); +_MTL_PRIVATE_DEF_SEL(getBytes_bytesPerRow_bytesPerImage_fromRegion_mipmapLevel_slice_, + "getBytes:bytesPerRow:bytesPerImage:fromRegion:mipmapLevel:slice:"); +_MTL_PRIVATE_DEF_SEL(getBytes_bytesPerRow_fromRegion_mipmapLevel_, + "getBytes:bytesPerRow:fromRegion:mipmapLevel:"); +_MTL_PRIVATE_DEF_SEL(getBytes_strides_fromSliceOrigin_sliceDimensions_, + "getBytes:strides:fromSliceOrigin:sliceDimensions:"); +_MTL_PRIVATE_DEF_SEL(getDefaultSamplePositions_count_, + "getDefaultSamplePositions:count:"); +_MTL_PRIVATE_DEF_SEL(getPhysicalIndexForLogicalIndex_, + "getPhysicalIndexForLogicalIndex:"); +_MTL_PRIVATE_DEF_SEL(getSamplePositions_count_, + "getSamplePositions:count:"); +_MTL_PRIVATE_DEF_SEL(getTextureAccessCounters_region_mipLevel_slice_resetCounters_countersBuffer_countersBufferOffset_, + "getTextureAccessCounters:region:mipLevel:slice:resetCounters:countersBuffer:countersBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(gpuAddress, + "gpuAddress"); +_MTL_PRIVATE_DEF_SEL(gpuResourceID, + "gpuResourceID"); +_MTL_PRIVATE_DEF_SEL(groups, + "groups"); +_MTL_PRIVATE_DEF_SEL(hasUnifiedMemory, + "hasUnifiedMemory"); +_MTL_PRIVATE_DEF_SEL(hazardTrackingMode, + "hazardTrackingMode"); +_MTL_PRIVATE_DEF_SEL(heap, + "heap"); +_MTL_PRIVATE_DEF_SEL(heapAccelerationStructureSizeAndAlignWithDescriptor_, + "heapAccelerationStructureSizeAndAlignWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(heapAccelerationStructureSizeAndAlignWithSize_, + "heapAccelerationStructureSizeAndAlignWithSize:"); +_MTL_PRIVATE_DEF_SEL(heapBufferSizeAndAlignWithLength_options_, + "heapBufferSizeAndAlignWithLength:options:"); +_MTL_PRIVATE_DEF_SEL(heapOffset, + "heapOffset"); +_MTL_PRIVATE_DEF_SEL(heapTextureSizeAndAlignWithDescriptor_, + "heapTextureSizeAndAlignWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(height, + "height"); +_MTL_PRIVATE_DEF_SEL(horizontal, + "horizontal"); +_MTL_PRIVATE_DEF_SEL(horizontalSampleStorage, + "horizontalSampleStorage"); +_MTL_PRIVATE_DEF_SEL(imageblockMemoryLengthForDimensions_, + "imageblockMemoryLengthForDimensions:"); +_MTL_PRIVATE_DEF_SEL(imageblockSampleLength, + "imageblockSampleLength"); +_MTL_PRIVATE_DEF_SEL(index, + "index"); +_MTL_PRIVATE_DEF_SEL(indexBuffer, + "indexBuffer"); +_MTL_PRIVATE_DEF_SEL(indexBufferIndex, + "indexBufferIndex"); +_MTL_PRIVATE_DEF_SEL(indexBufferOffset, + "indexBufferOffset"); +_MTL_PRIVATE_DEF_SEL(indexType, + "indexType"); +_MTL_PRIVATE_DEF_SEL(indirectComputeCommandAtIndex_, + "indirectComputeCommandAtIndex:"); +_MTL_PRIVATE_DEF_SEL(indirectRenderCommandAtIndex_, + "indirectRenderCommandAtIndex:"); +_MTL_PRIVATE_DEF_SEL(inheritBuffers, + "inheritBuffers"); +_MTL_PRIVATE_DEF_SEL(inheritCullMode, + "inheritCullMode"); +_MTL_PRIVATE_DEF_SEL(inheritDepthBias, + "inheritDepthBias"); +_MTL_PRIVATE_DEF_SEL(inheritDepthClipMode, + "inheritDepthClipMode"); +_MTL_PRIVATE_DEF_SEL(inheritDepthStencilState, + "inheritDepthStencilState"); +_MTL_PRIVATE_DEF_SEL(inheritFrontFacingWinding, + "inheritFrontFacingWinding"); +_MTL_PRIVATE_DEF_SEL(inheritPipelineState, + "inheritPipelineState"); +_MTL_PRIVATE_DEF_SEL(inheritTriangleFillMode, + "inheritTriangleFillMode"); +_MTL_PRIVATE_DEF_SEL(init, + "init"); +_MTL_PRIVATE_DEF_SEL(initWithArgumentIndex_, + "initWithArgumentIndex:"); +_MTL_PRIVATE_DEF_SEL(initWithDispatchQueue_, + "initWithDispatchQueue:"); +_MTL_PRIVATE_DEF_SEL(initWithFunctionName_nodes_outputNode_attributes_, + "initWithFunctionName:nodes:outputNode:attributes:"); +_MTL_PRIVATE_DEF_SEL(initWithName_arguments_controlDependencies_, + "initWithName:arguments:controlDependencies:"); +_MTL_PRIVATE_DEF_SEL(initWithRank_values_, + "initWithRank:values:"); +_MTL_PRIVATE_DEF_SEL(initWithSampleCount_, + "initWithSampleCount:"); +_MTL_PRIVATE_DEF_SEL(initWithSampleCount_horizontal_vertical_, + "initWithSampleCount:horizontal:vertical:"); +_MTL_PRIVATE_DEF_SEL(initialCapacity, + "initialCapacity"); +_MTL_PRIVATE_DEF_SEL(initializeBindings, + "initializeBindings"); +_MTL_PRIVATE_DEF_SEL(inputDimensionsAtBufferIndex_, + "inputDimensionsAtBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(inputPrimitiveTopology, + "inputPrimitiveTopology"); +_MTL_PRIVATE_DEF_SEL(insertDebugCaptureBoundary, + "insertDebugCaptureBoundary"); +_MTL_PRIVATE_DEF_SEL(insertDebugSignpost_, + "insertDebugSignpost:"); +_MTL_PRIVATE_DEF_SEL(insertLibraries, + "insertLibraries"); +_MTL_PRIVATE_DEF_SEL(installName, + "installName"); +_MTL_PRIVATE_DEF_SEL(instanceCount, + "instanceCount"); +_MTL_PRIVATE_DEF_SEL(instanceCountBuffer, + "instanceCountBuffer"); +_MTL_PRIVATE_DEF_SEL(instanceCountBufferOffset, + "instanceCountBufferOffset"); +_MTL_PRIVATE_DEF_SEL(instanceDescriptorBuffer, + "instanceDescriptorBuffer"); +_MTL_PRIVATE_DEF_SEL(instanceDescriptorBufferOffset, + "instanceDescriptorBufferOffset"); +_MTL_PRIVATE_DEF_SEL(instanceDescriptorStride, + "instanceDescriptorStride"); +_MTL_PRIVATE_DEF_SEL(instanceDescriptorType, + "instanceDescriptorType"); +_MTL_PRIVATE_DEF_SEL(instanceTransformationMatrixLayout, + "instanceTransformationMatrixLayout"); +_MTL_PRIVATE_DEF_SEL(instancedAccelerationStructures, + "instancedAccelerationStructures"); +_MTL_PRIVATE_DEF_SEL(intermediatesHeapSize, + "intermediatesHeapSize"); +_MTL_PRIVATE_DEF_SEL(intersectionFunctionTableDescriptor, + "intersectionFunctionTableDescriptor"); +_MTL_PRIVATE_DEF_SEL(intersectionFunctionTableOffset, + "intersectionFunctionTableOffset"); +_MTL_PRIVATE_DEF_SEL(invalidateCounterRange_, + "invalidateCounterRange:"); +_MTL_PRIVATE_DEF_SEL(iosurface, + "iosurface"); +_MTL_PRIVATE_DEF_SEL(iosurfacePlane, + "iosurfacePlane"); +_MTL_PRIVATE_DEF_SEL(isActive, + "isActive"); +_MTL_PRIVATE_DEF_SEL(isAliasable, + "isAliasable"); +_MTL_PRIVATE_DEF_SEL(isAlphaToCoverageEnabled, + "isAlphaToCoverageEnabled"); +_MTL_PRIVATE_DEF_SEL(isAlphaToOneEnabled, + "isAlphaToOneEnabled"); +_MTL_PRIVATE_DEF_SEL(isArgument, + "isArgument"); +_MTL_PRIVATE_DEF_SEL(isBlendingEnabled, + "isBlendingEnabled"); +_MTL_PRIVATE_DEF_SEL(isCapturing, + "isCapturing"); +_MTL_PRIVATE_DEF_SEL(isDepth24Stencil8PixelFormatSupported, + "isDepth24Stencil8PixelFormatSupported"); +_MTL_PRIVATE_DEF_SEL(isDepthTexture, + "isDepthTexture"); +_MTL_PRIVATE_DEF_SEL(isDepthWriteEnabled, + "isDepthWriteEnabled"); +_MTL_PRIVATE_DEF_SEL(isFramebufferOnly, + "isFramebufferOnly"); +_MTL_PRIVATE_DEF_SEL(isHeadless, + "isHeadless"); +_MTL_PRIVATE_DEF_SEL(isLowPower, + "isLowPower"); +_MTL_PRIVATE_DEF_SEL(isPatchControlPointData, + "isPatchControlPointData"); +_MTL_PRIVATE_DEF_SEL(isPatchData, + "isPatchData"); +_MTL_PRIVATE_DEF_SEL(isRasterizationEnabled, + "isRasterizationEnabled"); +_MTL_PRIVATE_DEF_SEL(isRemovable, + "isRemovable"); +_MTL_PRIVATE_DEF_SEL(isShareable, + "isShareable"); +_MTL_PRIVATE_DEF_SEL(isSparse, + "isSparse"); +_MTL_PRIVATE_DEF_SEL(isTessellationFactorScaleEnabled, + "isTessellationFactorScaleEnabled"); +_MTL_PRIVATE_DEF_SEL(isUsed, + "isUsed"); +_MTL_PRIVATE_DEF_SEL(kernelEndTime, + "kernelEndTime"); +_MTL_PRIVATE_DEF_SEL(kernelStartTime, + "kernelStartTime"); +_MTL_PRIVATE_DEF_SEL(label, + "label"); +_MTL_PRIVATE_DEF_SEL(languageVersion, + "languageVersion"); +_MTL_PRIVATE_DEF_SEL(layerAtIndex_, + "layerAtIndex:"); +_MTL_PRIVATE_DEF_SEL(layerCount, + "layerCount"); +_MTL_PRIVATE_DEF_SEL(layers, + "layers"); +_MTL_PRIVATE_DEF_SEL(layouts, + "layouts"); +_MTL_PRIVATE_DEF_SEL(length, + "length"); +_MTL_PRIVATE_DEF_SEL(level, + "level"); +_MTL_PRIVATE_DEF_SEL(levelRange, + "levelRange"); +_MTL_PRIVATE_DEF_SEL(libraries, + "libraries"); +_MTL_PRIVATE_DEF_SEL(library, + "library"); +_MTL_PRIVATE_DEF_SEL(libraryType, + "libraryType"); +_MTL_PRIVATE_DEF_SEL(line, + "line"); +_MTL_PRIVATE_DEF_SEL(linkedFunctions, + "linkedFunctions"); +_MTL_PRIVATE_DEF_SEL(loadAction, + "loadAction"); +_MTL_PRIVATE_DEF_SEL(loadBuffer_offset_size_sourceHandle_sourceHandleOffset_, + "loadBuffer:offset:size:sourceHandle:sourceHandleOffset:"); +_MTL_PRIVATE_DEF_SEL(loadBytes_size_sourceHandle_sourceHandleOffset_, + "loadBytes:size:sourceHandle:sourceHandleOffset:"); +_MTL_PRIVATE_DEF_SEL(loadTexture_slice_level_size_sourceBytesPerRow_sourceBytesPerImage_destinationOrigin_sourceHandle_sourceHandleOffset_, + "loadTexture:slice:level:size:sourceBytesPerRow:sourceBytesPerImage:destinationOrigin:sourceHandle:sourceHandleOffset:"); +_MTL_PRIVATE_DEF_SEL(location, + "location"); +_MTL_PRIVATE_DEF_SEL(locationNumber, + "locationNumber"); +_MTL_PRIVATE_DEF_SEL(lodAverage, + "lodAverage"); +_MTL_PRIVATE_DEF_SEL(lodBias, + "lodBias"); +_MTL_PRIVATE_DEF_SEL(lodMaxClamp, + "lodMaxClamp"); +_MTL_PRIVATE_DEF_SEL(lodMinClamp, + "lodMinClamp"); +_MTL_PRIVATE_DEF_SEL(logState, + "logState"); +_MTL_PRIVATE_DEF_SEL(logs, + "logs"); +_MTL_PRIVATE_DEF_SEL(lookupArchives, + "lookupArchives"); +_MTL_PRIVATE_DEF_SEL(machineLearningCommandEncoder, + "machineLearningCommandEncoder"); +_MTL_PRIVATE_DEF_SEL(machineLearningFunctionDescriptor, + "machineLearningFunctionDescriptor"); +_MTL_PRIVATE_DEF_SEL(magFilter, + "magFilter"); +_MTL_PRIVATE_DEF_SEL(makeAliasable, + "makeAliasable"); +_MTL_PRIVATE_DEF_SEL(mapPhysicalToScreenCoordinates_forLayer_, + "mapPhysicalToScreenCoordinates:forLayer:"); +_MTL_PRIVATE_DEF_SEL(mapScreenToPhysicalCoordinates_forLayer_, + "mapScreenToPhysicalCoordinates:forLayer:"); +_MTL_PRIVATE_DEF_SEL(mathFloatingPointFunctions, + "mathFloatingPointFunctions"); +_MTL_PRIVATE_DEF_SEL(mathMode, + "mathMode"); +_MTL_PRIVATE_DEF_SEL(maxAnisotropy, + "maxAnisotropy"); +_MTL_PRIVATE_DEF_SEL(maxArgumentBufferSamplerCount, + "maxArgumentBufferSamplerCount"); +_MTL_PRIVATE_DEF_SEL(maxAvailableSizeWithAlignment_, + "maxAvailableSizeWithAlignment:"); +_MTL_PRIVATE_DEF_SEL(maxBufferBindCount, + "maxBufferBindCount"); +_MTL_PRIVATE_DEF_SEL(maxBufferLength, + "maxBufferLength"); +_MTL_PRIVATE_DEF_SEL(maxCallStackDepth, + "maxCallStackDepth"); +_MTL_PRIVATE_DEF_SEL(maxCommandBufferCount, + "maxCommandBufferCount"); +_MTL_PRIVATE_DEF_SEL(maxCommandsInFlight, + "maxCommandsInFlight"); +_MTL_PRIVATE_DEF_SEL(maxCompatiblePlacementSparsePageSize, + "maxCompatiblePlacementSparsePageSize"); +_MTL_PRIVATE_DEF_SEL(maxFragmentBufferBindCount, + "maxFragmentBufferBindCount"); +_MTL_PRIVATE_DEF_SEL(maxFragmentCallStackDepth, + "maxFragmentCallStackDepth"); +_MTL_PRIVATE_DEF_SEL(maxInstanceCount, + "maxInstanceCount"); +_MTL_PRIVATE_DEF_SEL(maxKernelBufferBindCount, + "maxKernelBufferBindCount"); +_MTL_PRIVATE_DEF_SEL(maxKernelThreadgroupMemoryBindCount, + "maxKernelThreadgroupMemoryBindCount"); +_MTL_PRIVATE_DEF_SEL(maxMeshBufferBindCount, + "maxMeshBufferBindCount"); +_MTL_PRIVATE_DEF_SEL(maxMotionTransformCount, + "maxMotionTransformCount"); +_MTL_PRIVATE_DEF_SEL(maxObjectBufferBindCount, + "maxObjectBufferBindCount"); +_MTL_PRIVATE_DEF_SEL(maxObjectThreadgroupMemoryBindCount, + "maxObjectThreadgroupMemoryBindCount"); +_MTL_PRIVATE_DEF_SEL(maxSampleCount, + "maxSampleCount"); +_MTL_PRIVATE_DEF_SEL(maxSamplerStateBindCount, + "maxSamplerStateBindCount"); +_MTL_PRIVATE_DEF_SEL(maxTessellationFactor, + "maxTessellationFactor"); +_MTL_PRIVATE_DEF_SEL(maxTextureBindCount, + "maxTextureBindCount"); +_MTL_PRIVATE_DEF_SEL(maxThreadgroupMemoryLength, + "maxThreadgroupMemoryLength"); +_MTL_PRIVATE_DEF_SEL(maxThreadsPerThreadgroup, + "maxThreadsPerThreadgroup"); +_MTL_PRIVATE_DEF_SEL(maxTotalThreadgroupsPerMeshGrid, + "maxTotalThreadgroupsPerMeshGrid"); +_MTL_PRIVATE_DEF_SEL(maxTotalThreadsPerMeshThreadgroup, + "maxTotalThreadsPerMeshThreadgroup"); +_MTL_PRIVATE_DEF_SEL(maxTotalThreadsPerObjectThreadgroup, + "maxTotalThreadsPerObjectThreadgroup"); +_MTL_PRIVATE_DEF_SEL(maxTotalThreadsPerThreadgroup, + "maxTotalThreadsPerThreadgroup"); +_MTL_PRIVATE_DEF_SEL(maxTransferRate, + "maxTransferRate"); +_MTL_PRIVATE_DEF_SEL(maxVertexAmplificationCount, + "maxVertexAmplificationCount"); +_MTL_PRIVATE_DEF_SEL(maxVertexBufferBindCount, + "maxVertexBufferBindCount"); +_MTL_PRIVATE_DEF_SEL(maxVertexCallStackDepth, + "maxVertexCallStackDepth"); +_MTL_PRIVATE_DEF_SEL(maximumConcurrentCompilationTaskCount, + "maximumConcurrentCompilationTaskCount"); +_MTL_PRIVATE_DEF_SEL(memberByName_, + "memberByName:"); +_MTL_PRIVATE_DEF_SEL(members, + "members"); +_MTL_PRIVATE_DEF_SEL(memoryBarrierWithResources_count_, + "memoryBarrierWithResources:count:"); +_MTL_PRIVATE_DEF_SEL(memoryBarrierWithResources_count_afterStages_beforeStages_, + "memoryBarrierWithResources:count:afterStages:beforeStages:"); +_MTL_PRIVATE_DEF_SEL(memoryBarrierWithScope_, + "memoryBarrierWithScope:"); +_MTL_PRIVATE_DEF_SEL(memoryBarrierWithScope_afterStages_beforeStages_, + "memoryBarrierWithScope:afterStages:beforeStages:"); +_MTL_PRIVATE_DEF_SEL(meshAdditionalBinaryFunctions, + "meshAdditionalBinaryFunctions"); +_MTL_PRIVATE_DEF_SEL(meshBindings, + "meshBindings"); +_MTL_PRIVATE_DEF_SEL(meshBuffers, + "meshBuffers"); +_MTL_PRIVATE_DEF_SEL(meshFunction, + "meshFunction"); +_MTL_PRIVATE_DEF_SEL(meshFunctionDescriptor, + "meshFunctionDescriptor"); +_MTL_PRIVATE_DEF_SEL(meshLinkedFunctions, + "meshLinkedFunctions"); +_MTL_PRIVATE_DEF_SEL(meshLinkingDescriptor, + "meshLinkingDescriptor"); +_MTL_PRIVATE_DEF_SEL(meshStaticLinkingDescriptor, + "meshStaticLinkingDescriptor"); +_MTL_PRIVATE_DEF_SEL(meshThreadExecutionWidth, + "meshThreadExecutionWidth"); +_MTL_PRIVATE_DEF_SEL(meshThreadgroupSizeIsMultipleOfThreadExecutionWidth, + "meshThreadgroupSizeIsMultipleOfThreadExecutionWidth"); +_MTL_PRIVATE_DEF_SEL(minFilter, + "minFilter"); +_MTL_PRIVATE_DEF_SEL(minimumLinearTextureAlignmentForPixelFormat_, + "minimumLinearTextureAlignmentForPixelFormat:"); +_MTL_PRIVATE_DEF_SEL(minimumTextureBufferAlignmentForPixelFormat_, + "minimumTextureBufferAlignmentForPixelFormat:"); +_MTL_PRIVATE_DEF_SEL(mipFilter, + "mipFilter"); +_MTL_PRIVATE_DEF_SEL(mipmapLevelCount, + "mipmapLevelCount"); +_MTL_PRIVATE_DEF_SEL(motionEndBorderMode, + "motionEndBorderMode"); +_MTL_PRIVATE_DEF_SEL(motionEndTime, + "motionEndTime"); +_MTL_PRIVATE_DEF_SEL(motionKeyframeCount, + "motionKeyframeCount"); +_MTL_PRIVATE_DEF_SEL(motionStartBorderMode, + "motionStartBorderMode"); +_MTL_PRIVATE_DEF_SEL(motionStartTime, + "motionStartTime"); +_MTL_PRIVATE_DEF_SEL(motionTransformBuffer, + "motionTransformBuffer"); +_MTL_PRIVATE_DEF_SEL(motionTransformBufferOffset, + "motionTransformBufferOffset"); +_MTL_PRIVATE_DEF_SEL(motionTransformCount, + "motionTransformCount"); +_MTL_PRIVATE_DEF_SEL(motionTransformCountBuffer, + "motionTransformCountBuffer"); +_MTL_PRIVATE_DEF_SEL(motionTransformCountBufferOffset, + "motionTransformCountBufferOffset"); +_MTL_PRIVATE_DEF_SEL(motionTransformStride, + "motionTransformStride"); +_MTL_PRIVATE_DEF_SEL(motionTransformType, + "motionTransformType"); +_MTL_PRIVATE_DEF_SEL(moveTextureMappingsFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_, + "moveTextureMappingsFromTexture:sourceSlice:sourceLevel:sourceOrigin:sourceSize:toTexture:destinationSlice:destinationLevel:destinationOrigin:"); +_MTL_PRIVATE_DEF_SEL(mutability, + "mutability"); +_MTL_PRIVATE_DEF_SEL(name, + "name"); +_MTL_PRIVATE_DEF_SEL(newAccelerationStructureWithDescriptor_, + "newAccelerationStructureWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newAccelerationStructureWithDescriptor_offset_, + "newAccelerationStructureWithDescriptor:offset:"); +_MTL_PRIVATE_DEF_SEL(newAccelerationStructureWithSize_, + "newAccelerationStructureWithSize:"); +_MTL_PRIVATE_DEF_SEL(newAccelerationStructureWithSize_offset_, + "newAccelerationStructureWithSize:offset:"); +_MTL_PRIVATE_DEF_SEL(newArchiveWithURL_error_, + "newArchiveWithURL:error:"); +_MTL_PRIVATE_DEF_SEL(newArgumentEncoderForBufferAtIndex_, + "newArgumentEncoderForBufferAtIndex:"); +_MTL_PRIVATE_DEF_SEL(newArgumentEncoderWithArguments_, + "newArgumentEncoderWithArguments:"); +_MTL_PRIVATE_DEF_SEL(newArgumentEncoderWithBufferBinding_, + "newArgumentEncoderWithBufferBinding:"); +_MTL_PRIVATE_DEF_SEL(newArgumentEncoderWithBufferIndex_, + "newArgumentEncoderWithBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(newArgumentEncoderWithBufferIndex_reflection_, + "newArgumentEncoderWithBufferIndex:reflection:"); +_MTL_PRIVATE_DEF_SEL(newArgumentTableWithDescriptor_error_, + "newArgumentTableWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newBinaryArchiveWithDescriptor_error_, + "newBinaryArchiveWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newBinaryFunctionWithDescriptor_compilerTaskOptions_completionHandler_, + "newBinaryFunctionWithDescriptor:compilerTaskOptions:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newBinaryFunctionWithDescriptor_compilerTaskOptions_error_, + "newBinaryFunctionWithDescriptor:compilerTaskOptions:error:"); +_MTL_PRIVATE_DEF_SEL(newBinaryFunctionWithDescriptor_error_, + "newBinaryFunctionWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newBufferWithBytes_length_options_, + "newBufferWithBytes:length:options:"); +_MTL_PRIVATE_DEF_SEL(newBufferWithBytesNoCopy_length_options_deallocator_, + "newBufferWithBytesNoCopy:length:options:deallocator:"); +_MTL_PRIVATE_DEF_SEL(newBufferWithLength_options_, + "newBufferWithLength:options:"); +_MTL_PRIVATE_DEF_SEL(newBufferWithLength_options_offset_, + "newBufferWithLength:options:offset:"); +_MTL_PRIVATE_DEF_SEL(newBufferWithLength_options_placementSparsePageSize_, + "newBufferWithLength:options:placementSparsePageSize:"); +_MTL_PRIVATE_DEF_SEL(newCaptureScopeWithCommandQueue_, + "newCaptureScopeWithCommandQueue:"); +_MTL_PRIVATE_DEF_SEL(newCaptureScopeWithDevice_, + "newCaptureScopeWithDevice:"); +_MTL_PRIVATE_DEF_SEL(newCaptureScopeWithMTL4CommandQueue_, + "newCaptureScopeWithMTL4CommandQueue:"); +_MTL_PRIVATE_DEF_SEL(newCommandAllocator, + "newCommandAllocator"); +_MTL_PRIVATE_DEF_SEL(newCommandAllocatorWithDescriptor_error_, + "newCommandAllocatorWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newCommandBuffer, + "newCommandBuffer"); +_MTL_PRIVATE_DEF_SEL(newCommandQueue, + "newCommandQueue"); +_MTL_PRIVATE_DEF_SEL(newCommandQueueWithDescriptor_, + "newCommandQueueWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newCommandQueueWithMaxCommandBufferCount_, + "newCommandQueueWithMaxCommandBufferCount:"); +_MTL_PRIVATE_DEF_SEL(newCompilerWithDescriptor_error_, + "newCompilerWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithAdditionalBinaryFunctions_error_, + "newComputePipelineStateWithAdditionalBinaryFunctions:error:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithBinaryFunctions_error_, + "newComputePipelineStateWithBinaryFunctions:error:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithDescriptor_compilerTaskOptions_completionHandler_, + "newComputePipelineStateWithDescriptor:compilerTaskOptions:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithDescriptor_compilerTaskOptions_error_, + "newComputePipelineStateWithDescriptor:compilerTaskOptions:error:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithDescriptor_dynamicLinkingDescriptor_compilerTaskOptions_completionHandler_, + "newComputePipelineStateWithDescriptor:dynamicLinkingDescriptor:compilerTaskOptions:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithDescriptor_dynamicLinkingDescriptor_compilerTaskOptions_error_, + "newComputePipelineStateWithDescriptor:dynamicLinkingDescriptor:compilerTaskOptions:error:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithDescriptor_dynamicLinkingDescriptor_error_, + "newComputePipelineStateWithDescriptor:dynamicLinkingDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithDescriptor_error_, + "newComputePipelineStateWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithDescriptor_options_completionHandler_, + "newComputePipelineStateWithDescriptor:options:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithDescriptor_options_reflection_error_, + "newComputePipelineStateWithDescriptor:options:reflection:error:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithFunction_completionHandler_, + "newComputePipelineStateWithFunction:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithFunction_error_, + "newComputePipelineStateWithFunction:error:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithFunction_options_completionHandler_, + "newComputePipelineStateWithFunction:options:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithFunction_options_reflection_error_, + "newComputePipelineStateWithFunction:options:reflection:error:"); +_MTL_PRIVATE_DEF_SEL(newCounterHeapWithDescriptor_error_, + "newCounterHeapWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newCounterSampleBufferWithDescriptor_error_, + "newCounterSampleBufferWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newDefaultLibrary, + "newDefaultLibrary"); +_MTL_PRIVATE_DEF_SEL(newDefaultLibraryWithBundle_error_, + "newDefaultLibraryWithBundle:error:"); +_MTL_PRIVATE_DEF_SEL(newDepthStencilStateWithDescriptor_, + "newDepthStencilStateWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newDynamicLibrary_completionHandler_, + "newDynamicLibrary:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newDynamicLibrary_error_, + "newDynamicLibrary:error:"); +_MTL_PRIVATE_DEF_SEL(newDynamicLibraryWithURL_completionHandler_, + "newDynamicLibraryWithURL:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newDynamicLibraryWithURL_error_, + "newDynamicLibraryWithURL:error:"); +_MTL_PRIVATE_DEF_SEL(newEvent, + "newEvent"); +_MTL_PRIVATE_DEF_SEL(newFence, + "newFence"); +_MTL_PRIVATE_DEF_SEL(newFunctionWithDescriptor_completionHandler_, + "newFunctionWithDescriptor:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newFunctionWithDescriptor_error_, + "newFunctionWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newFunctionWithName_, + "newFunctionWithName:"); +_MTL_PRIVATE_DEF_SEL(newFunctionWithName_constantValues_completionHandler_, + "newFunctionWithName:constantValues:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newFunctionWithName_constantValues_error_, + "newFunctionWithName:constantValues:error:"); +_MTL_PRIVATE_DEF_SEL(newHeapWithDescriptor_, + "newHeapWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newIOCommandQueueWithDescriptor_error_, + "newIOCommandQueueWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newIOFileHandleWithURL_compressionMethod_error_, + "newIOFileHandleWithURL:compressionMethod:error:"); +_MTL_PRIVATE_DEF_SEL(newIOFileHandleWithURL_error_, + "newIOFileHandleWithURL:error:"); +_MTL_PRIVATE_DEF_SEL(newIOHandleWithURL_compressionMethod_error_, + "newIOHandleWithURL:compressionMethod:error:"); +_MTL_PRIVATE_DEF_SEL(newIOHandleWithURL_error_, + "newIOHandleWithURL:error:"); +_MTL_PRIVATE_DEF_SEL(newIndirectCommandBufferWithDescriptor_maxCommandCount_options_, + "newIndirectCommandBufferWithDescriptor:maxCommandCount:options:"); +_MTL_PRIVATE_DEF_SEL(newIntersectionFunctionTableWithDescriptor_, + "newIntersectionFunctionTableWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newIntersectionFunctionTableWithDescriptor_stage_, + "newIntersectionFunctionTableWithDescriptor:stage:"); +_MTL_PRIVATE_DEF_SEL(newIntersectionFunctionWithDescriptor_completionHandler_, + "newIntersectionFunctionWithDescriptor:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newIntersectionFunctionWithDescriptor_error_, + "newIntersectionFunctionWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newLibraryWithData_error_, + "newLibraryWithData:error:"); +_MTL_PRIVATE_DEF_SEL(newLibraryWithDescriptor_completionHandler_, + "newLibraryWithDescriptor:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newLibraryWithDescriptor_error_, + "newLibraryWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newLibraryWithFile_error_, + "newLibraryWithFile:error:"); +_MTL_PRIVATE_DEF_SEL(newLibraryWithSource_options_completionHandler_, + "newLibraryWithSource:options:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newLibraryWithSource_options_error_, + "newLibraryWithSource:options:error:"); +_MTL_PRIVATE_DEF_SEL(newLibraryWithStitchedDescriptor_completionHandler_, + "newLibraryWithStitchedDescriptor:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newLibraryWithStitchedDescriptor_error_, + "newLibraryWithStitchedDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newLibraryWithURL_error_, + "newLibraryWithURL:error:"); +_MTL_PRIVATE_DEF_SEL(newLogStateWithDescriptor_error_, + "newLogStateWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newMTL4CommandQueue, + "newMTL4CommandQueue"); +_MTL_PRIVATE_DEF_SEL(newMTL4CommandQueueWithDescriptor_error_, + "newMTL4CommandQueueWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newMachineLearningPipelineStateWithDescriptor_completionHandler_, + "newMachineLearningPipelineStateWithDescriptor:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newMachineLearningPipelineStateWithDescriptor_error_, + "newMachineLearningPipelineStateWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newPipelineDataSetSerializerWithDescriptor_, + "newPipelineDataSetSerializerWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newRasterizationRateMapWithDescriptor_, + "newRasterizationRateMapWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newRemoteBufferViewForDevice_, + "newRemoteBufferViewForDevice:"); +_MTL_PRIVATE_DEF_SEL(newRemoteTextureViewForDevice_, + "newRemoteTextureViewForDevice:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineDescriptorForSpecialization, + "newRenderPipelineDescriptorForSpecialization"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateBySpecializationWithDescriptor_pipeline_completionHandler_, + "newRenderPipelineStateBySpecializationWithDescriptor:pipeline:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateBySpecializationWithDescriptor_pipeline_error_, + "newRenderPipelineStateBySpecializationWithDescriptor:pipeline:error:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithAdditionalBinaryFunctions_error_, + "newRenderPipelineStateWithAdditionalBinaryFunctions:error:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithBinaryFunctions_error_, + "newRenderPipelineStateWithBinaryFunctions:error:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithDescriptor_compilerTaskOptions_completionHandler_, + "newRenderPipelineStateWithDescriptor:compilerTaskOptions:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithDescriptor_compilerTaskOptions_error_, + "newRenderPipelineStateWithDescriptor:compilerTaskOptions:error:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithDescriptor_completionHandler_, + "newRenderPipelineStateWithDescriptor:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithDescriptor_dynamicLinkingDescriptor_compilerTaskOptions_completionHandler_, + "newRenderPipelineStateWithDescriptor:dynamicLinkingDescriptor:compilerTaskOptions:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithDescriptor_dynamicLinkingDescriptor_compilerTaskOptions_error_, + "newRenderPipelineStateWithDescriptor:dynamicLinkingDescriptor:compilerTaskOptions:error:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithDescriptor_dynamicLinkingDescriptor_error_, + "newRenderPipelineStateWithDescriptor:dynamicLinkingDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithDescriptor_error_, + "newRenderPipelineStateWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithDescriptor_options_completionHandler_, + "newRenderPipelineStateWithDescriptor:options:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithDescriptor_options_reflection_error_, + "newRenderPipelineStateWithDescriptor:options:reflection:error:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithMeshDescriptor_options_completionHandler_, + "newRenderPipelineStateWithMeshDescriptor:options:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithMeshDescriptor_options_reflection_error_, + "newRenderPipelineStateWithMeshDescriptor:options:reflection:error:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithTileDescriptor_options_completionHandler_, + "newRenderPipelineStateWithTileDescriptor:options:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithTileDescriptor_options_reflection_error_, + "newRenderPipelineStateWithTileDescriptor:options:reflection:error:"); +_MTL_PRIVATE_DEF_SEL(newResidencySetWithDescriptor_error_, + "newResidencySetWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newSamplerStateWithDescriptor_, + "newSamplerStateWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newScratchBufferWithMinimumSize_, + "newScratchBufferWithMinimumSize:"); +_MTL_PRIVATE_DEF_SEL(newSharedEvent, + "newSharedEvent"); +_MTL_PRIVATE_DEF_SEL(newSharedEventHandle, + "newSharedEventHandle"); +_MTL_PRIVATE_DEF_SEL(newSharedEventWithHandle_, + "newSharedEventWithHandle:"); +_MTL_PRIVATE_DEF_SEL(newSharedTextureHandle, + "newSharedTextureHandle"); +_MTL_PRIVATE_DEF_SEL(newSharedTextureWithDescriptor_, + "newSharedTextureWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newSharedTextureWithHandle_, + "newSharedTextureWithHandle:"); +_MTL_PRIVATE_DEF_SEL(newTensorWithDescriptor_error_, + "newTensorWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newTensorWithDescriptor_offset_error_, + "newTensorWithDescriptor:offset:error:"); +_MTL_PRIVATE_DEF_SEL(newTextureViewPoolWithDescriptor_error_, + "newTextureViewPoolWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newTextureViewWithDescriptor_, + "newTextureViewWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newTextureViewWithPixelFormat_, + "newTextureViewWithPixelFormat:"); +_MTL_PRIVATE_DEF_SEL(newTextureViewWithPixelFormat_textureType_levels_slices_, + "newTextureViewWithPixelFormat:textureType:levels:slices:"); +_MTL_PRIVATE_DEF_SEL(newTextureViewWithPixelFormat_textureType_levels_slices_swizzle_, + "newTextureViewWithPixelFormat:textureType:levels:slices:swizzle:"); +_MTL_PRIVATE_DEF_SEL(newTextureWithDescriptor_, + "newTextureWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newTextureWithDescriptor_iosurface_plane_, + "newTextureWithDescriptor:iosurface:plane:"); +_MTL_PRIVATE_DEF_SEL(newTextureWithDescriptor_offset_, + "newTextureWithDescriptor:offset:"); +_MTL_PRIVATE_DEF_SEL(newTextureWithDescriptor_offset_bytesPerRow_, + "newTextureWithDescriptor:offset:bytesPerRow:"); +_MTL_PRIVATE_DEF_SEL(newVisibleFunctionTableWithDescriptor_, + "newVisibleFunctionTableWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newVisibleFunctionTableWithDescriptor_stage_, + "newVisibleFunctionTableWithDescriptor:stage:"); +_MTL_PRIVATE_DEF_SEL(nodes, + "nodes"); +_MTL_PRIVATE_DEF_SEL(normalizedCoordinates, + "normalizedCoordinates"); +_MTL_PRIVATE_DEF_SEL(notifyListener_atValue_block_, + "notifyListener:atValue:block:"); +_MTL_PRIVATE_DEF_SEL(objectAdditionalBinaryFunctions, + "objectAdditionalBinaryFunctions"); +_MTL_PRIVATE_DEF_SEL(objectAtIndexedSubscript_, + "objectAtIndexedSubscript:"); +_MTL_PRIVATE_DEF_SEL(objectBindings, + "objectBindings"); +_MTL_PRIVATE_DEF_SEL(objectBuffers, + "objectBuffers"); +_MTL_PRIVATE_DEF_SEL(objectFunction, + "objectFunction"); +_MTL_PRIVATE_DEF_SEL(objectFunctionDescriptor, + "objectFunctionDescriptor"); +_MTL_PRIVATE_DEF_SEL(objectLinkedFunctions, + "objectLinkedFunctions"); +_MTL_PRIVATE_DEF_SEL(objectLinkingDescriptor, + "objectLinkingDescriptor"); +_MTL_PRIVATE_DEF_SEL(objectPayloadAlignment, + "objectPayloadAlignment"); +_MTL_PRIVATE_DEF_SEL(objectPayloadDataSize, + "objectPayloadDataSize"); +_MTL_PRIVATE_DEF_SEL(objectStaticLinkingDescriptor, + "objectStaticLinkingDescriptor"); +_MTL_PRIVATE_DEF_SEL(objectThreadExecutionWidth, + "objectThreadExecutionWidth"); +_MTL_PRIVATE_DEF_SEL(objectThreadgroupSizeIsMultipleOfThreadExecutionWidth, + "objectThreadgroupSizeIsMultipleOfThreadExecutionWidth"); +_MTL_PRIVATE_DEF_SEL(offset, + "offset"); +_MTL_PRIVATE_DEF_SEL(opaque, + "opaque"); +_MTL_PRIVATE_DEF_SEL(optimizationLevel, + "optimizationLevel"); +_MTL_PRIVATE_DEF_SEL(optimizeContentsForCPUAccess_, + "optimizeContentsForCPUAccess:"); +_MTL_PRIVATE_DEF_SEL(optimizeContentsForCPUAccess_slice_level_, + "optimizeContentsForCPUAccess:slice:level:"); +_MTL_PRIVATE_DEF_SEL(optimizeContentsForGPUAccess_, + "optimizeContentsForGPUAccess:"); +_MTL_PRIVATE_DEF_SEL(optimizeContentsForGPUAccess_slice_level_, + "optimizeContentsForGPUAccess:slice:level:"); +_MTL_PRIVATE_DEF_SEL(optimizeIndirectCommandBuffer_withRange_, + "optimizeIndirectCommandBuffer:withRange:"); +_MTL_PRIVATE_DEF_SEL(options, + "options"); +_MTL_PRIVATE_DEF_SEL(outputNode, + "outputNode"); +_MTL_PRIVATE_DEF_SEL(outputURL, + "outputURL"); +_MTL_PRIVATE_DEF_SEL(parallelRenderCommandEncoderWithDescriptor_, + "parallelRenderCommandEncoderWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(parameterBufferSizeAndAlign, + "parameterBufferSizeAndAlign"); +_MTL_PRIVATE_DEF_SEL(parentRelativeLevel, + "parentRelativeLevel"); +_MTL_PRIVATE_DEF_SEL(parentRelativeSlice, + "parentRelativeSlice"); +_MTL_PRIVATE_DEF_SEL(parentTexture, + "parentTexture"); +_MTL_PRIVATE_DEF_SEL(patchControlPointCount, + "patchControlPointCount"); +_MTL_PRIVATE_DEF_SEL(patchType, + "patchType"); +_MTL_PRIVATE_DEF_SEL(payloadMemoryLength, + "payloadMemoryLength"); +_MTL_PRIVATE_DEF_SEL(peerCount, + "peerCount"); +_MTL_PRIVATE_DEF_SEL(peerGroupID, + "peerGroupID"); +_MTL_PRIVATE_DEF_SEL(peerIndex, + "peerIndex"); +_MTL_PRIVATE_DEF_SEL(physicalGranularity, + "physicalGranularity"); +_MTL_PRIVATE_DEF_SEL(physicalSizeForLayer_, + "physicalSizeForLayer:"); +_MTL_PRIVATE_DEF_SEL(pipelineDataSetSerializer, + "pipelineDataSetSerializer"); +_MTL_PRIVATE_DEF_SEL(pixelFormat, + "pixelFormat"); +_MTL_PRIVATE_DEF_SEL(placementSparsePageSize, + "placementSparsePageSize"); +_MTL_PRIVATE_DEF_SEL(pointerType, + "pointerType"); +_MTL_PRIVATE_DEF_SEL(popDebugGroup, + "popDebugGroup"); +_MTL_PRIVATE_DEF_SEL(preloadedLibraries, + "preloadedLibraries"); +_MTL_PRIVATE_DEF_SEL(preprocessorMacros, + "preprocessorMacros"); +_MTL_PRIVATE_DEF_SEL(present, + "present"); +_MTL_PRIVATE_DEF_SEL(presentAfterMinimumDuration_, + "presentAfterMinimumDuration:"); +_MTL_PRIVATE_DEF_SEL(presentAtTime_, + "presentAtTime:"); +_MTL_PRIVATE_DEF_SEL(presentDrawable_, + "presentDrawable:"); +_MTL_PRIVATE_DEF_SEL(presentDrawable_afterMinimumDuration_, + "presentDrawable:afterMinimumDuration:"); +_MTL_PRIVATE_DEF_SEL(presentDrawable_atTime_, + "presentDrawable:atTime:"); +_MTL_PRIVATE_DEF_SEL(presentedTime, + "presentedTime"); +_MTL_PRIVATE_DEF_SEL(preserveInvariance, + "preserveInvariance"); +_MTL_PRIVATE_DEF_SEL(primitiveDataBuffer, + "primitiveDataBuffer"); +_MTL_PRIVATE_DEF_SEL(primitiveDataBufferOffset, + "primitiveDataBufferOffset"); +_MTL_PRIVATE_DEF_SEL(primitiveDataElementSize, + "primitiveDataElementSize"); +_MTL_PRIVATE_DEF_SEL(primitiveDataStride, + "primitiveDataStride"); +_MTL_PRIVATE_DEF_SEL(priority, + "priority"); +_MTL_PRIVATE_DEF_SEL(privateFunctionDescriptors, + "privateFunctionDescriptors"); +_MTL_PRIVATE_DEF_SEL(privateFunctions, + "privateFunctions"); +_MTL_PRIVATE_DEF_SEL(pushDebugGroup_, + "pushDebugGroup:"); +_MTL_PRIVATE_DEF_SEL(queryTimestampFrequency, + "queryTimestampFrequency"); +_MTL_PRIVATE_DEF_SEL(rAddressMode, + "rAddressMode"); +_MTL_PRIVATE_DEF_SEL(radiusBuffer, + "radiusBuffer"); +_MTL_PRIVATE_DEF_SEL(radiusBufferOffset, + "radiusBufferOffset"); +_MTL_PRIVATE_DEF_SEL(radiusBuffers, + "radiusBuffers"); +_MTL_PRIVATE_DEF_SEL(radiusFormat, + "radiusFormat"); +_MTL_PRIVATE_DEF_SEL(radiusStride, + "radiusStride"); +_MTL_PRIVATE_DEF_SEL(rank, + "rank"); +_MTL_PRIVATE_DEF_SEL(rasterSampleCount, + "rasterSampleCount"); +_MTL_PRIVATE_DEF_SEL(rasterizationRateMap, + "rasterizationRateMap"); +_MTL_PRIVATE_DEF_SEL(rasterizationRateMapDescriptorWithScreenSize_, + "rasterizationRateMapDescriptorWithScreenSize:"); +_MTL_PRIVATE_DEF_SEL(rasterizationRateMapDescriptorWithScreenSize_layer_, + "rasterizationRateMapDescriptorWithScreenSize:layer:"); +_MTL_PRIVATE_DEF_SEL(rasterizationRateMapDescriptorWithScreenSize_layerCount_layers_, + "rasterizationRateMapDescriptorWithScreenSize:layerCount:layers:"); +_MTL_PRIVATE_DEF_SEL(readMask, + "readMask"); +_MTL_PRIVATE_DEF_SEL(readWriteTextureSupport, + "readWriteTextureSupport"); +_MTL_PRIVATE_DEF_SEL(recommendedMaxWorkingSetSize, + "recommendedMaxWorkingSetSize"); +_MTL_PRIVATE_DEF_SEL(reductionMode, + "reductionMode"); +_MTL_PRIVATE_DEF_SEL(refitAccelerationStructure_descriptor_destination_scratchBuffer_, + "refitAccelerationStructure:descriptor:destination:scratchBuffer:"); +_MTL_PRIVATE_DEF_SEL(refitAccelerationStructure_descriptor_destination_scratchBuffer_options_, + "refitAccelerationStructure:descriptor:destination:scratchBuffer:options:"); +_MTL_PRIVATE_DEF_SEL(refitAccelerationStructure_descriptor_destination_scratchBuffer_scratchBufferOffset_, + "refitAccelerationStructure:descriptor:destination:scratchBuffer:scratchBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(refitAccelerationStructure_descriptor_destination_scratchBuffer_scratchBufferOffset_options_, + "refitAccelerationStructure:descriptor:destination:scratchBuffer:scratchBufferOffset:options:"); +_MTL_PRIVATE_DEF_SEL(reflection, + "reflection"); +_MTL_PRIVATE_DEF_SEL(reflectionForFunctionWithName_, + "reflectionForFunctionWithName:"); +_MTL_PRIVATE_DEF_SEL(registryID, + "registryID"); +_MTL_PRIVATE_DEF_SEL(remoteStorageBuffer, + "remoteStorageBuffer"); +_MTL_PRIVATE_DEF_SEL(remoteStorageTexture, + "remoteStorageTexture"); +_MTL_PRIVATE_DEF_SEL(removeAllAllocations, + "removeAllAllocations"); +_MTL_PRIVATE_DEF_SEL(removeAllDebugMarkers, + "removeAllDebugMarkers"); +_MTL_PRIVATE_DEF_SEL(removeAllocation_, + "removeAllocation:"); +_MTL_PRIVATE_DEF_SEL(removeAllocations_count_, + "removeAllocations:count:"); +_MTL_PRIVATE_DEF_SEL(removeResidencySet_, + "removeResidencySet:"); +_MTL_PRIVATE_DEF_SEL(removeResidencySets_count_, + "removeResidencySets:count:"); +_MTL_PRIVATE_DEF_SEL(renderCommandEncoder, + "renderCommandEncoder"); +_MTL_PRIVATE_DEF_SEL(renderCommandEncoderWithDescriptor_, + "renderCommandEncoderWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(renderCommandEncoderWithDescriptor_options_, + "renderCommandEncoderWithDescriptor:options:"); +_MTL_PRIVATE_DEF_SEL(renderPassDescriptor, + "renderPassDescriptor"); +_MTL_PRIVATE_DEF_SEL(renderTargetArrayLength, + "renderTargetArrayLength"); +_MTL_PRIVATE_DEF_SEL(renderTargetHeight, + "renderTargetHeight"); +_MTL_PRIVATE_DEF_SEL(renderTargetWidth, + "renderTargetWidth"); +_MTL_PRIVATE_DEF_SEL(replaceRegion_mipmapLevel_slice_withBytes_bytesPerRow_bytesPerImage_, + "replaceRegion:mipmapLevel:slice:withBytes:bytesPerRow:bytesPerImage:"); +_MTL_PRIVATE_DEF_SEL(replaceRegion_mipmapLevel_withBytes_bytesPerRow_, + "replaceRegion:mipmapLevel:withBytes:bytesPerRow:"); +_MTL_PRIVATE_DEF_SEL(replaceSliceOrigin_sliceDimensions_withBytes_strides_, + "replaceSliceOrigin:sliceDimensions:withBytes:strides:"); +_MTL_PRIVATE_DEF_SEL(requestResidency, + "requestResidency"); +_MTL_PRIVATE_DEF_SEL(required, + "required"); +_MTL_PRIVATE_DEF_SEL(requiredThreadsPerMeshThreadgroup, + "requiredThreadsPerMeshThreadgroup"); +_MTL_PRIVATE_DEF_SEL(requiredThreadsPerObjectThreadgroup, + "requiredThreadsPerObjectThreadgroup"); +_MTL_PRIVATE_DEF_SEL(requiredThreadsPerThreadgroup, + "requiredThreadsPerThreadgroup"); +_MTL_PRIVATE_DEF_SEL(requiredThreadsPerTileThreadgroup, + "requiredThreadsPerTileThreadgroup"); +_MTL_PRIVATE_DEF_SEL(reset, + "reset"); +_MTL_PRIVATE_DEF_SEL(resetCommandsInBuffer_withRange_, + "resetCommandsInBuffer:withRange:"); +_MTL_PRIVATE_DEF_SEL(resetTextureAccessCounters_region_mipLevel_slice_, + "resetTextureAccessCounters:region:mipLevel:slice:"); +_MTL_PRIVATE_DEF_SEL(resetWithRange_, + "resetWithRange:"); +_MTL_PRIVATE_DEF_SEL(resolveCounterHeap_withRange_intoBuffer_waitFence_updateFence_, + "resolveCounterHeap:withRange:intoBuffer:waitFence:updateFence:"); +_MTL_PRIVATE_DEF_SEL(resolveCounterRange_, + "resolveCounterRange:"); +_MTL_PRIVATE_DEF_SEL(resolveCounters_inRange_destinationBuffer_destinationOffset_, + "resolveCounters:inRange:destinationBuffer:destinationOffset:"); +_MTL_PRIVATE_DEF_SEL(resolveDepthPlane, + "resolveDepthPlane"); +_MTL_PRIVATE_DEF_SEL(resolveLevel, + "resolveLevel"); +_MTL_PRIVATE_DEF_SEL(resolveSlice, + "resolveSlice"); +_MTL_PRIVATE_DEF_SEL(resolveTexture, + "resolveTexture"); +_MTL_PRIVATE_DEF_SEL(resourceOptions, + "resourceOptions"); +_MTL_PRIVATE_DEF_SEL(resourceStateCommandEncoder, + "resourceStateCommandEncoder"); +_MTL_PRIVATE_DEF_SEL(resourceStateCommandEncoderWithDescriptor_, + "resourceStateCommandEncoderWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(resourceStatePassDescriptor, + "resourceStatePassDescriptor"); +_MTL_PRIVATE_DEF_SEL(resourceViewCount, + "resourceViewCount"); +_MTL_PRIVATE_DEF_SEL(retainedReferences, + "retainedReferences"); +_MTL_PRIVATE_DEF_SEL(rgbBlendOperation, + "rgbBlendOperation"); +_MTL_PRIVATE_DEF_SEL(rootResource, + "rootResource"); +_MTL_PRIVATE_DEF_SEL(sAddressMode, + "sAddressMode"); +_MTL_PRIVATE_DEF_SEL(sampleBuffer, + "sampleBuffer"); +_MTL_PRIVATE_DEF_SEL(sampleBufferAttachments, + "sampleBufferAttachments"); +_MTL_PRIVATE_DEF_SEL(sampleCount, + "sampleCount"); +_MTL_PRIVATE_DEF_SEL(sampleCountersInBuffer_atSampleIndex_withBarrier_, + "sampleCountersInBuffer:atSampleIndex:withBarrier:"); +_MTL_PRIVATE_DEF_SEL(sampleTimestamps_gpuTimestamp_, + "sampleTimestamps:gpuTimestamp:"); +_MTL_PRIVATE_DEF_SEL(scratchBufferAllocator, + "scratchBufferAllocator"); +_MTL_PRIVATE_DEF_SEL(screenSize, + "screenSize"); +_MTL_PRIVATE_DEF_SEL(segmentControlPointCount, + "segmentControlPointCount"); +_MTL_PRIVATE_DEF_SEL(segmentCount, + "segmentCount"); +_MTL_PRIVATE_DEF_SEL(serializeAsArchiveAndFlushToURL_error_, + "serializeAsArchiveAndFlushToURL:error:"); +_MTL_PRIVATE_DEF_SEL(serializeAsPipelinesScriptWithError_, + "serializeAsPipelinesScriptWithError:"); +_MTL_PRIVATE_DEF_SEL(serializeToURL_error_, + "serializeToURL:error:"); +_MTL_PRIVATE_DEF_SEL(setAccelerationStructure_atBufferIndex_, + "setAccelerationStructure:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setAccelerationStructure_atIndex_, + "setAccelerationStructure:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setAccess_, + "setAccess:"); +_MTL_PRIVATE_DEF_SEL(setAddress_atIndex_, + "setAddress:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setAddress_attributeStride_atIndex_, + "setAddress:attributeStride:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setAllowDuplicateIntersectionFunctionInvocation_, + "setAllowDuplicateIntersectionFunctionInvocation:"); +_MTL_PRIVATE_DEF_SEL(setAllowGPUOptimizedContents_, + "setAllowGPUOptimizedContents:"); +_MTL_PRIVATE_DEF_SEL(setAllowReferencingUndefinedSymbols_, + "setAllowReferencingUndefinedSymbols:"); +_MTL_PRIVATE_DEF_SEL(setAlphaBlendOperation_, + "setAlphaBlendOperation:"); +_MTL_PRIVATE_DEF_SEL(setAlphaToCoverageEnabled_, + "setAlphaToCoverageEnabled:"); +_MTL_PRIVATE_DEF_SEL(setAlphaToCoverageState_, + "setAlphaToCoverageState:"); +_MTL_PRIVATE_DEF_SEL(setAlphaToOneEnabled_, + "setAlphaToOneEnabled:"); +_MTL_PRIVATE_DEF_SEL(setAlphaToOneState_, + "setAlphaToOneState:"); +_MTL_PRIVATE_DEF_SEL(setArgumentBuffer_offset_, + "setArgumentBuffer:offset:"); +_MTL_PRIVATE_DEF_SEL(setArgumentBuffer_startOffset_arrayElement_, + "setArgumentBuffer:startOffset:arrayElement:"); +_MTL_PRIVATE_DEF_SEL(setArgumentIndex_, + "setArgumentIndex:"); +_MTL_PRIVATE_DEF_SEL(setArgumentTable_, + "setArgumentTable:"); +_MTL_PRIVATE_DEF_SEL(setArgumentTable_atStages_, + "setArgumentTable:atStages:"); +_MTL_PRIVATE_DEF_SEL(setArguments_, + "setArguments:"); +_MTL_PRIVATE_DEF_SEL(setArrayLength_, + "setArrayLength:"); +_MTL_PRIVATE_DEF_SEL(setAttributes_, + "setAttributes:"); +_MTL_PRIVATE_DEF_SEL(setBackFaceStencil_, + "setBackFaceStencil:"); +_MTL_PRIVATE_DEF_SEL(setBarrier, + "setBarrier"); +_MTL_PRIVATE_DEF_SEL(setBinaryArchives_, + "setBinaryArchives:"); +_MTL_PRIVATE_DEF_SEL(setBinaryFunctions_, + "setBinaryFunctions:"); +_MTL_PRIVATE_DEF_SEL(setBinaryLinkedFunctions_, + "setBinaryLinkedFunctions:"); +_MTL_PRIVATE_DEF_SEL(setBlendColorRed_green_blue_alpha_, + "setBlendColorRed:green:blue:alpha:"); +_MTL_PRIVATE_DEF_SEL(setBlendingEnabled_, + "setBlendingEnabled:"); +_MTL_PRIVATE_DEF_SEL(setBlendingState_, + "setBlendingState:"); +_MTL_PRIVATE_DEF_SEL(setBorderColor_, + "setBorderColor:"); +_MTL_PRIVATE_DEF_SEL(setBoundingBoxBuffer_, + "setBoundingBoxBuffer:"); +_MTL_PRIVATE_DEF_SEL(setBoundingBoxBufferOffset_, + "setBoundingBoxBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setBoundingBoxBuffers_, + "setBoundingBoxBuffers:"); +_MTL_PRIVATE_DEF_SEL(setBoundingBoxCount_, + "setBoundingBoxCount:"); +_MTL_PRIVATE_DEF_SEL(setBoundingBoxStride_, + "setBoundingBoxStride:"); +_MTL_PRIVATE_DEF_SEL(setBuffer_, + "setBuffer:"); +_MTL_PRIVATE_DEF_SEL(setBuffer_offset_atIndex_, + "setBuffer:offset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setBuffer_offset_attributeStride_atIndex_, + "setBuffer:offset:attributeStride:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setBufferIndex_, + "setBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setBufferOffset_atIndex_, + "setBufferOffset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setBufferOffset_attributeStride_atIndex_, + "setBufferOffset:attributeStride:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setBufferSize_, + "setBufferSize:"); +_MTL_PRIVATE_DEF_SEL(setBuffers_offsets_attributeStrides_withRange_, + "setBuffers:offsets:attributeStrides:withRange:"); +_MTL_PRIVATE_DEF_SEL(setBuffers_offsets_withRange_, + "setBuffers:offsets:withRange:"); +_MTL_PRIVATE_DEF_SEL(setBytes_length_atIndex_, + "setBytes:length:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setBytes_length_attributeStride_atIndex_, + "setBytes:length:attributeStride:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setCaptureObject_, + "setCaptureObject:"); +_MTL_PRIVATE_DEF_SEL(setClearColor_, + "setClearColor:"); +_MTL_PRIVATE_DEF_SEL(setClearDepth_, + "setClearDepth:"); +_MTL_PRIVATE_DEF_SEL(setClearStencil_, + "setClearStencil:"); +_MTL_PRIVATE_DEF_SEL(setColorAttachmentMap_, + "setColorAttachmentMap:"); +_MTL_PRIVATE_DEF_SEL(setColorAttachmentMappingState_, + "setColorAttachmentMappingState:"); +_MTL_PRIVATE_DEF_SEL(setColorStoreAction_atIndex_, + "setColorStoreAction:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setColorStoreActionOptions_atIndex_, + "setColorStoreActionOptions:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setCommandTypes_, + "setCommandTypes:"); +_MTL_PRIVATE_DEF_SEL(setCompareFunction_, + "setCompareFunction:"); +_MTL_PRIVATE_DEF_SEL(setCompileSymbolVisibility_, + "setCompileSymbolVisibility:"); +_MTL_PRIVATE_DEF_SEL(setCompressionType_, + "setCompressionType:"); +_MTL_PRIVATE_DEF_SEL(setComputeFunction_, + "setComputeFunction:"); +_MTL_PRIVATE_DEF_SEL(setComputeFunctionDescriptor_, + "setComputeFunctionDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setComputePipelineState_, + "setComputePipelineState:"); +_MTL_PRIVATE_DEF_SEL(setComputePipelineState_atIndex_, + "setComputePipelineState:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setComputePipelineStates_withRange_, + "setComputePipelineStates:withRange:"); +_MTL_PRIVATE_DEF_SEL(setConfiguration_, + "setConfiguration:"); +_MTL_PRIVATE_DEF_SEL(setConstantBlockAlignment_, + "setConstantBlockAlignment:"); +_MTL_PRIVATE_DEF_SEL(setConstantValue_type_atIndex_, + "setConstantValue:type:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setConstantValue_type_withName_, + "setConstantValue:type:withName:"); +_MTL_PRIVATE_DEF_SEL(setConstantValues_, + "setConstantValues:"); +_MTL_PRIVATE_DEF_SEL(setConstantValues_type_withRange_, + "setConstantValues:type:withRange:"); +_MTL_PRIVATE_DEF_SEL(setControlDependencies_, + "setControlDependencies:"); +_MTL_PRIVATE_DEF_SEL(setControlPointBuffer_, + "setControlPointBuffer:"); +_MTL_PRIVATE_DEF_SEL(setControlPointBufferOffset_, + "setControlPointBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setControlPointBuffers_, + "setControlPointBuffers:"); +_MTL_PRIVATE_DEF_SEL(setControlPointCount_, + "setControlPointCount:"); +_MTL_PRIVATE_DEF_SEL(setControlPointFormat_, + "setControlPointFormat:"); +_MTL_PRIVATE_DEF_SEL(setControlPointStride_, + "setControlPointStride:"); +_MTL_PRIVATE_DEF_SEL(setCount_, + "setCount:"); +_MTL_PRIVATE_DEF_SEL(setCounterSet_, + "setCounterSet:"); +_MTL_PRIVATE_DEF_SEL(setCpuCacheMode_, + "setCpuCacheMode:"); +_MTL_PRIVATE_DEF_SEL(setCullMode_, + "setCullMode:"); +_MTL_PRIVATE_DEF_SEL(setCurveBasis_, + "setCurveBasis:"); +_MTL_PRIVATE_DEF_SEL(setCurveEndCaps_, + "setCurveEndCaps:"); +_MTL_PRIVATE_DEF_SEL(setCurveType_, + "setCurveType:"); +_MTL_PRIVATE_DEF_SEL(setDataType_, + "setDataType:"); +_MTL_PRIVATE_DEF_SEL(setDefaultCaptureScope_, + "setDefaultCaptureScope:"); +_MTL_PRIVATE_DEF_SEL(setDefaultRasterSampleCount_, + "setDefaultRasterSampleCount:"); +_MTL_PRIVATE_DEF_SEL(setDepth_, + "setDepth:"); +_MTL_PRIVATE_DEF_SEL(setDepthAttachment_, + "setDepthAttachment:"); +_MTL_PRIVATE_DEF_SEL(setDepthAttachmentPixelFormat_, + "setDepthAttachmentPixelFormat:"); +_MTL_PRIVATE_DEF_SEL(setDepthBias_slopeScale_clamp_, + "setDepthBias:slopeScale:clamp:"); +_MTL_PRIVATE_DEF_SEL(setDepthClipMode_, + "setDepthClipMode:"); +_MTL_PRIVATE_DEF_SEL(setDepthCompareFunction_, + "setDepthCompareFunction:"); +_MTL_PRIVATE_DEF_SEL(setDepthFailureOperation_, + "setDepthFailureOperation:"); +_MTL_PRIVATE_DEF_SEL(setDepthPlane_, + "setDepthPlane:"); +_MTL_PRIVATE_DEF_SEL(setDepthResolveFilter_, + "setDepthResolveFilter:"); +_MTL_PRIVATE_DEF_SEL(setDepthStencilPassOperation_, + "setDepthStencilPassOperation:"); +_MTL_PRIVATE_DEF_SEL(setDepthStencilState_, + "setDepthStencilState:"); +_MTL_PRIVATE_DEF_SEL(setDepthStencilState_atIndex_, + "setDepthStencilState:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setDepthStencilStates_withRange_, + "setDepthStencilStates:withRange:"); +_MTL_PRIVATE_DEF_SEL(setDepthStoreAction_, + "setDepthStoreAction:"); +_MTL_PRIVATE_DEF_SEL(setDepthStoreActionOptions_, + "setDepthStoreActionOptions:"); +_MTL_PRIVATE_DEF_SEL(setDepthTestMinBound_maxBound_, + "setDepthTestMinBound:maxBound:"); +_MTL_PRIVATE_DEF_SEL(setDepthWriteEnabled_, + "setDepthWriteEnabled:"); +_MTL_PRIVATE_DEF_SEL(setDestination_, + "setDestination:"); +_MTL_PRIVATE_DEF_SEL(setDestinationAlphaBlendFactor_, + "setDestinationAlphaBlendFactor:"); +_MTL_PRIVATE_DEF_SEL(setDestinationRGBBlendFactor_, + "setDestinationRGBBlendFactor:"); +_MTL_PRIVATE_DEF_SEL(setDimensions_, + "setDimensions:"); +_MTL_PRIVATE_DEF_SEL(setDispatchType_, + "setDispatchType:"); +_MTL_PRIVATE_DEF_SEL(setEnableLogging_, + "setEnableLogging:"); +_MTL_PRIVATE_DEF_SEL(setEndOfEncoderSampleIndex_, + "setEndOfEncoderSampleIndex:"); +_MTL_PRIVATE_DEF_SEL(setEndOfFragmentSampleIndex_, + "setEndOfFragmentSampleIndex:"); +_MTL_PRIVATE_DEF_SEL(setEndOfVertexSampleIndex_, + "setEndOfVertexSampleIndex:"); +_MTL_PRIVATE_DEF_SEL(setErrorOptions_, + "setErrorOptions:"); +_MTL_PRIVATE_DEF_SEL(setFastMathEnabled_, + "setFastMathEnabled:"); +_MTL_PRIVATE_DEF_SEL(setFeedbackQueue_, + "setFeedbackQueue:"); +_MTL_PRIVATE_DEF_SEL(setFormat_, + "setFormat:"); +_MTL_PRIVATE_DEF_SEL(setFragmentAccelerationStructure_atBufferIndex_, + "setFragmentAccelerationStructure:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setFragmentAdditionalBinaryFunctions_, + "setFragmentAdditionalBinaryFunctions:"); +_MTL_PRIVATE_DEF_SEL(setFragmentBuffer_offset_atIndex_, + "setFragmentBuffer:offset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setFragmentBufferOffset_atIndex_, + "setFragmentBufferOffset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setFragmentBuffers_offsets_withRange_, + "setFragmentBuffers:offsets:withRange:"); +_MTL_PRIVATE_DEF_SEL(setFragmentBytes_length_atIndex_, + "setFragmentBytes:length:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setFragmentFunction_, + "setFragmentFunction:"); +_MTL_PRIVATE_DEF_SEL(setFragmentFunctionDescriptor_, + "setFragmentFunctionDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setFragmentIntersectionFunctionTable_atBufferIndex_, + "setFragmentIntersectionFunctionTable:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setFragmentIntersectionFunctionTables_withBufferRange_, + "setFragmentIntersectionFunctionTables:withBufferRange:"); +_MTL_PRIVATE_DEF_SEL(setFragmentLinkedFunctions_, + "setFragmentLinkedFunctions:"); +_MTL_PRIVATE_DEF_SEL(setFragmentPreloadedLibraries_, + "setFragmentPreloadedLibraries:"); +_MTL_PRIVATE_DEF_SEL(setFragmentSamplerState_atIndex_, + "setFragmentSamplerState:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setFragmentSamplerState_lodMinClamp_lodMaxClamp_atIndex_, + "setFragmentSamplerState:lodMinClamp:lodMaxClamp:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setFragmentSamplerStates_lodMinClamps_lodMaxClamps_withRange_, + "setFragmentSamplerStates:lodMinClamps:lodMaxClamps:withRange:"); +_MTL_PRIVATE_DEF_SEL(setFragmentSamplerStates_withRange_, + "setFragmentSamplerStates:withRange:"); +_MTL_PRIVATE_DEF_SEL(setFragmentStaticLinkingDescriptor_, + "setFragmentStaticLinkingDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setFragmentTexture_atIndex_, + "setFragmentTexture:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setFragmentTextures_withRange_, + "setFragmentTextures:withRange:"); +_MTL_PRIVATE_DEF_SEL(setFragmentVisibleFunctionTable_atBufferIndex_, + "setFragmentVisibleFunctionTable:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setFragmentVisibleFunctionTables_withBufferRange_, + "setFragmentVisibleFunctionTables:withBufferRange:"); +_MTL_PRIVATE_DEF_SEL(setFrontFaceStencil_, + "setFrontFaceStencil:"); +_MTL_PRIVATE_DEF_SEL(setFrontFacingWinding_, + "setFrontFacingWinding:"); +_MTL_PRIVATE_DEF_SEL(setFunction_atIndex_, + "setFunction:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setFunctionCount_, + "setFunctionCount:"); +_MTL_PRIVATE_DEF_SEL(setFunctionDescriptor_, + "setFunctionDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setFunctionDescriptors_, + "setFunctionDescriptors:"); +_MTL_PRIVATE_DEF_SEL(setFunctionGraph_, + "setFunctionGraph:"); +_MTL_PRIVATE_DEF_SEL(setFunctionGraphs_, + "setFunctionGraphs:"); +_MTL_PRIVATE_DEF_SEL(setFunctionName_, + "setFunctionName:"); +_MTL_PRIVATE_DEF_SEL(setFunctions_, + "setFunctions:"); +_MTL_PRIVATE_DEF_SEL(setFunctions_withRange_, + "setFunctions:withRange:"); +_MTL_PRIVATE_DEF_SEL(setGeometryDescriptors_, + "setGeometryDescriptors:"); +_MTL_PRIVATE_DEF_SEL(setGroups_, + "setGroups:"); +_MTL_PRIVATE_DEF_SEL(setHazardTrackingMode_, + "setHazardTrackingMode:"); +_MTL_PRIVATE_DEF_SEL(setHeight_, + "setHeight:"); +_MTL_PRIVATE_DEF_SEL(setImageblockSampleLength_, + "setImageblockSampleLength:"); +_MTL_PRIVATE_DEF_SEL(setImageblockWidth_height_, + "setImageblockWidth:height:"); +_MTL_PRIVATE_DEF_SEL(setIndex_, + "setIndex:"); +_MTL_PRIVATE_DEF_SEL(setIndexBuffer_, + "setIndexBuffer:"); +_MTL_PRIVATE_DEF_SEL(setIndexBufferIndex_, + "setIndexBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setIndexBufferOffset_, + "setIndexBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setIndexType_, + "setIndexType:"); +_MTL_PRIVATE_DEF_SEL(setIndirectCommandBuffer_atIndex_, + "setIndirectCommandBuffer:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setIndirectCommandBuffers_withRange_, + "setIndirectCommandBuffers:withRange:"); +_MTL_PRIVATE_DEF_SEL(setInheritBuffers_, + "setInheritBuffers:"); +_MTL_PRIVATE_DEF_SEL(setInheritCullMode_, + "setInheritCullMode:"); +_MTL_PRIVATE_DEF_SEL(setInheritDepthBias_, + "setInheritDepthBias:"); +_MTL_PRIVATE_DEF_SEL(setInheritDepthClipMode_, + "setInheritDepthClipMode:"); +_MTL_PRIVATE_DEF_SEL(setInheritDepthStencilState_, + "setInheritDepthStencilState:"); +_MTL_PRIVATE_DEF_SEL(setInheritFrontFacingWinding_, + "setInheritFrontFacingWinding:"); +_MTL_PRIVATE_DEF_SEL(setInheritPipelineState_, + "setInheritPipelineState:"); +_MTL_PRIVATE_DEF_SEL(setInheritTriangleFillMode_, + "setInheritTriangleFillMode:"); +_MTL_PRIVATE_DEF_SEL(setInitialCapacity_, + "setInitialCapacity:"); +_MTL_PRIVATE_DEF_SEL(setInitializeBindings_, + "setInitializeBindings:"); +_MTL_PRIVATE_DEF_SEL(setInputDimensions_atBufferIndex_, + "setInputDimensions:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setInputDimensions_withRange_, + "setInputDimensions:withRange:"); +_MTL_PRIVATE_DEF_SEL(setInputPrimitiveTopology_, + "setInputPrimitiveTopology:"); +_MTL_PRIVATE_DEF_SEL(setInsertLibraries_, + "setInsertLibraries:"); +_MTL_PRIVATE_DEF_SEL(setInstallName_, + "setInstallName:"); +_MTL_PRIVATE_DEF_SEL(setInstanceCount_, + "setInstanceCount:"); +_MTL_PRIVATE_DEF_SEL(setInstanceCountBuffer_, + "setInstanceCountBuffer:"); +_MTL_PRIVATE_DEF_SEL(setInstanceCountBufferOffset_, + "setInstanceCountBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setInstanceDescriptorBuffer_, + "setInstanceDescriptorBuffer:"); +_MTL_PRIVATE_DEF_SEL(setInstanceDescriptorBufferOffset_, + "setInstanceDescriptorBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setInstanceDescriptorStride_, + "setInstanceDescriptorStride:"); +_MTL_PRIVATE_DEF_SEL(setInstanceDescriptorType_, + "setInstanceDescriptorType:"); +_MTL_PRIVATE_DEF_SEL(setInstanceTransformationMatrixLayout_, + "setInstanceTransformationMatrixLayout:"); +_MTL_PRIVATE_DEF_SEL(setInstancedAccelerationStructures_, + "setInstancedAccelerationStructures:"); +_MTL_PRIVATE_DEF_SEL(setIntersectionFunctionTable_atBufferIndex_, + "setIntersectionFunctionTable:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setIntersectionFunctionTable_atIndex_, + "setIntersectionFunctionTable:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setIntersectionFunctionTableOffset_, + "setIntersectionFunctionTableOffset:"); +_MTL_PRIVATE_DEF_SEL(setIntersectionFunctionTables_withBufferRange_, + "setIntersectionFunctionTables:withBufferRange:"); +_MTL_PRIVATE_DEF_SEL(setIntersectionFunctionTables_withRange_, + "setIntersectionFunctionTables:withRange:"); +_MTL_PRIVATE_DEF_SEL(setKernelBuffer_offset_atIndex_, + "setKernelBuffer:offset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setKernelBuffer_offset_attributeStride_atIndex_, + "setKernelBuffer:offset:attributeStride:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setLabel_, + "setLabel:"); +_MTL_PRIVATE_DEF_SEL(setLanguageVersion_, + "setLanguageVersion:"); +_MTL_PRIVATE_DEF_SEL(setLayer_atIndex_, + "setLayer:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setLevel_, + "setLevel:"); +_MTL_PRIVATE_DEF_SEL(setLevelRange_, + "setLevelRange:"); +_MTL_PRIVATE_DEF_SEL(setLibraries_, + "setLibraries:"); +_MTL_PRIVATE_DEF_SEL(setLibrary_, + "setLibrary:"); +_MTL_PRIVATE_DEF_SEL(setLibraryType_, + "setLibraryType:"); +_MTL_PRIVATE_DEF_SEL(setLinkedFunctions_, + "setLinkedFunctions:"); +_MTL_PRIVATE_DEF_SEL(setLoadAction_, + "setLoadAction:"); +_MTL_PRIVATE_DEF_SEL(setLodAverage_, + "setLodAverage:"); +_MTL_PRIVATE_DEF_SEL(setLodBias_, + "setLodBias:"); +_MTL_PRIVATE_DEF_SEL(setLodMaxClamp_, + "setLodMaxClamp:"); +_MTL_PRIVATE_DEF_SEL(setLodMinClamp_, + "setLodMinClamp:"); +_MTL_PRIVATE_DEF_SEL(setLogState_, + "setLogState:"); +_MTL_PRIVATE_DEF_SEL(setLookupArchives_, + "setLookupArchives:"); +_MTL_PRIVATE_DEF_SEL(setMachineLearningFunctionDescriptor_, + "setMachineLearningFunctionDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setMagFilter_, + "setMagFilter:"); +_MTL_PRIVATE_DEF_SEL(setMathFloatingPointFunctions_, + "setMathFloatingPointFunctions:"); +_MTL_PRIVATE_DEF_SEL(setMathMode_, + "setMathMode:"); +_MTL_PRIVATE_DEF_SEL(setMaxAnisotropy_, + "setMaxAnisotropy:"); +_MTL_PRIVATE_DEF_SEL(setMaxBufferBindCount_, + "setMaxBufferBindCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxCallStackDepth_, + "setMaxCallStackDepth:"); +_MTL_PRIVATE_DEF_SEL(setMaxCommandBufferCount_, + "setMaxCommandBufferCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxCommandsInFlight_, + "setMaxCommandsInFlight:"); +_MTL_PRIVATE_DEF_SEL(setMaxCompatiblePlacementSparsePageSize_, + "setMaxCompatiblePlacementSparsePageSize:"); +_MTL_PRIVATE_DEF_SEL(setMaxFragmentBufferBindCount_, + "setMaxFragmentBufferBindCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxFragmentCallStackDepth_, + "setMaxFragmentCallStackDepth:"); +_MTL_PRIVATE_DEF_SEL(setMaxInstanceCount_, + "setMaxInstanceCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxKernelBufferBindCount_, + "setMaxKernelBufferBindCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxKernelThreadgroupMemoryBindCount_, + "setMaxKernelThreadgroupMemoryBindCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxMeshBufferBindCount_, + "setMaxMeshBufferBindCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxMotionTransformCount_, + "setMaxMotionTransformCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxObjectBufferBindCount_, + "setMaxObjectBufferBindCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxObjectThreadgroupMemoryBindCount_, + "setMaxObjectThreadgroupMemoryBindCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxSamplerStateBindCount_, + "setMaxSamplerStateBindCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxTessellationFactor_, + "setMaxTessellationFactor:"); +_MTL_PRIVATE_DEF_SEL(setMaxTextureBindCount_, + "setMaxTextureBindCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxTotalThreadgroupsPerMeshGrid_, + "setMaxTotalThreadgroupsPerMeshGrid:"); +_MTL_PRIVATE_DEF_SEL(setMaxTotalThreadsPerMeshThreadgroup_, + "setMaxTotalThreadsPerMeshThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(setMaxTotalThreadsPerObjectThreadgroup_, + "setMaxTotalThreadsPerObjectThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(setMaxTotalThreadsPerThreadgroup_, + "setMaxTotalThreadsPerThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(setMaxVertexAmplificationCount_, + "setMaxVertexAmplificationCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxVertexBufferBindCount_, + "setMaxVertexBufferBindCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxVertexCallStackDepth_, + "setMaxVertexCallStackDepth:"); +_MTL_PRIVATE_DEF_SEL(setMeshAdditionalBinaryFunctions_, + "setMeshAdditionalBinaryFunctions:"); +_MTL_PRIVATE_DEF_SEL(setMeshBuffer_offset_atIndex_, + "setMeshBuffer:offset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setMeshBufferOffset_atIndex_, + "setMeshBufferOffset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setMeshBuffers_offsets_withRange_, + "setMeshBuffers:offsets:withRange:"); +_MTL_PRIVATE_DEF_SEL(setMeshBytes_length_atIndex_, + "setMeshBytes:length:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setMeshFunction_, + "setMeshFunction:"); +_MTL_PRIVATE_DEF_SEL(setMeshFunctionDescriptor_, + "setMeshFunctionDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setMeshLinkedFunctions_, + "setMeshLinkedFunctions:"); +_MTL_PRIVATE_DEF_SEL(setMeshSamplerState_atIndex_, + "setMeshSamplerState:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setMeshSamplerState_lodMinClamp_lodMaxClamp_atIndex_, + "setMeshSamplerState:lodMinClamp:lodMaxClamp:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setMeshSamplerStates_lodMinClamps_lodMaxClamps_withRange_, + "setMeshSamplerStates:lodMinClamps:lodMaxClamps:withRange:"); +_MTL_PRIVATE_DEF_SEL(setMeshSamplerStates_withRange_, + "setMeshSamplerStates:withRange:"); +_MTL_PRIVATE_DEF_SEL(setMeshStaticLinkingDescriptor_, + "setMeshStaticLinkingDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setMeshTexture_atIndex_, + "setMeshTexture:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setMeshTextures_withRange_, + "setMeshTextures:withRange:"); +_MTL_PRIVATE_DEF_SEL(setMeshThreadgroupSizeIsMultipleOfThreadExecutionWidth_, + "setMeshThreadgroupSizeIsMultipleOfThreadExecutionWidth:"); +_MTL_PRIVATE_DEF_SEL(setMinFilter_, + "setMinFilter:"); +_MTL_PRIVATE_DEF_SEL(setMipFilter_, + "setMipFilter:"); +_MTL_PRIVATE_DEF_SEL(setMipmapLevelCount_, + "setMipmapLevelCount:"); +_MTL_PRIVATE_DEF_SEL(setMotionEndBorderMode_, + "setMotionEndBorderMode:"); +_MTL_PRIVATE_DEF_SEL(setMotionEndTime_, + "setMotionEndTime:"); +_MTL_PRIVATE_DEF_SEL(setMotionKeyframeCount_, + "setMotionKeyframeCount:"); +_MTL_PRIVATE_DEF_SEL(setMotionStartBorderMode_, + "setMotionStartBorderMode:"); +_MTL_PRIVATE_DEF_SEL(setMotionStartTime_, + "setMotionStartTime:"); +_MTL_PRIVATE_DEF_SEL(setMotionTransformBuffer_, + "setMotionTransformBuffer:"); +_MTL_PRIVATE_DEF_SEL(setMotionTransformBufferOffset_, + "setMotionTransformBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setMotionTransformCount_, + "setMotionTransformCount:"); +_MTL_PRIVATE_DEF_SEL(setMotionTransformCountBuffer_, + "setMotionTransformCountBuffer:"); +_MTL_PRIVATE_DEF_SEL(setMotionTransformCountBufferOffset_, + "setMotionTransformCountBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setMotionTransformStride_, + "setMotionTransformStride:"); +_MTL_PRIVATE_DEF_SEL(setMotionTransformType_, + "setMotionTransformType:"); +_MTL_PRIVATE_DEF_SEL(setMutability_, + "setMutability:"); +_MTL_PRIVATE_DEF_SEL(setName_, + "setName:"); +_MTL_PRIVATE_DEF_SEL(setNodes_, + "setNodes:"); +_MTL_PRIVATE_DEF_SEL(setNormalizedCoordinates_, + "setNormalizedCoordinates:"); +_MTL_PRIVATE_DEF_SEL(setObject_atIndexedSubscript_, + "setObject:atIndexedSubscript:"); +_MTL_PRIVATE_DEF_SEL(setObjectAdditionalBinaryFunctions_, + "setObjectAdditionalBinaryFunctions:"); +_MTL_PRIVATE_DEF_SEL(setObjectBuffer_offset_atIndex_, + "setObjectBuffer:offset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setObjectBufferOffset_atIndex_, + "setObjectBufferOffset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setObjectBuffers_offsets_withRange_, + "setObjectBuffers:offsets:withRange:"); +_MTL_PRIVATE_DEF_SEL(setObjectBytes_length_atIndex_, + "setObjectBytes:length:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setObjectFunction_, + "setObjectFunction:"); +_MTL_PRIVATE_DEF_SEL(setObjectFunctionDescriptor_, + "setObjectFunctionDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setObjectLinkedFunctions_, + "setObjectLinkedFunctions:"); +_MTL_PRIVATE_DEF_SEL(setObjectSamplerState_atIndex_, + "setObjectSamplerState:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setObjectSamplerState_lodMinClamp_lodMaxClamp_atIndex_, + "setObjectSamplerState:lodMinClamp:lodMaxClamp:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setObjectSamplerStates_lodMinClamps_lodMaxClamps_withRange_, + "setObjectSamplerStates:lodMinClamps:lodMaxClamps:withRange:"); +_MTL_PRIVATE_DEF_SEL(setObjectSamplerStates_withRange_, + "setObjectSamplerStates:withRange:"); +_MTL_PRIVATE_DEF_SEL(setObjectStaticLinkingDescriptor_, + "setObjectStaticLinkingDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setObjectTexture_atIndex_, + "setObjectTexture:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setObjectTextures_withRange_, + "setObjectTextures:withRange:"); +_MTL_PRIVATE_DEF_SEL(setObjectThreadgroupMemoryLength_atIndex_, + "setObjectThreadgroupMemoryLength:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setObjectThreadgroupSizeIsMultipleOfThreadExecutionWidth_, + "setObjectThreadgroupSizeIsMultipleOfThreadExecutionWidth:"); +_MTL_PRIVATE_DEF_SEL(setOffset_, + "setOffset:"); +_MTL_PRIVATE_DEF_SEL(setOpaque_, + "setOpaque:"); +_MTL_PRIVATE_DEF_SEL(setOpaqueCurveIntersectionFunctionWithSignature_atIndex_, + "setOpaqueCurveIntersectionFunctionWithSignature:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setOpaqueCurveIntersectionFunctionWithSignature_withRange_, + "setOpaqueCurveIntersectionFunctionWithSignature:withRange:"); +_MTL_PRIVATE_DEF_SEL(setOpaqueTriangleIntersectionFunctionWithSignature_atIndex_, + "setOpaqueTriangleIntersectionFunctionWithSignature:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setOpaqueTriangleIntersectionFunctionWithSignature_withRange_, + "setOpaqueTriangleIntersectionFunctionWithSignature:withRange:"); +_MTL_PRIVATE_DEF_SEL(setOptimizationLevel_, + "setOptimizationLevel:"); +_MTL_PRIVATE_DEF_SEL(setOptions_, + "setOptions:"); +_MTL_PRIVATE_DEF_SEL(setOutputNode_, + "setOutputNode:"); +_MTL_PRIVATE_DEF_SEL(setOutputURL_, + "setOutputURL:"); +_MTL_PRIVATE_DEF_SEL(setOwnerWithIdentity_, + "setOwnerWithIdentity:"); +_MTL_PRIVATE_DEF_SEL(setPayloadMemoryLength_, + "setPayloadMemoryLength:"); +_MTL_PRIVATE_DEF_SEL(setPhysicalIndex_forLogicalIndex_, + "setPhysicalIndex:forLogicalIndex:"); +_MTL_PRIVATE_DEF_SEL(setPipelineDataSetSerializer_, + "setPipelineDataSetSerializer:"); +_MTL_PRIVATE_DEF_SEL(setPipelineState_, + "setPipelineState:"); +_MTL_PRIVATE_DEF_SEL(setPixelFormat_, + "setPixelFormat:"); +_MTL_PRIVATE_DEF_SEL(setPlacementSparsePageSize_, + "setPlacementSparsePageSize:"); +_MTL_PRIVATE_DEF_SEL(setPreloadedLibraries_, + "setPreloadedLibraries:"); +_MTL_PRIVATE_DEF_SEL(setPreprocessorMacros_, + "setPreprocessorMacros:"); +_MTL_PRIVATE_DEF_SEL(setPreserveInvariance_, + "setPreserveInvariance:"); +_MTL_PRIVATE_DEF_SEL(setPrimitiveDataBuffer_, + "setPrimitiveDataBuffer:"); +_MTL_PRIVATE_DEF_SEL(setPrimitiveDataBufferOffset_, + "setPrimitiveDataBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setPrimitiveDataElementSize_, + "setPrimitiveDataElementSize:"); +_MTL_PRIVATE_DEF_SEL(setPrimitiveDataStride_, + "setPrimitiveDataStride:"); +_MTL_PRIVATE_DEF_SEL(setPriority_, + "setPriority:"); +_MTL_PRIVATE_DEF_SEL(setPrivateFunctionDescriptors_, + "setPrivateFunctionDescriptors:"); +_MTL_PRIVATE_DEF_SEL(setPrivateFunctions_, + "setPrivateFunctions:"); +_MTL_PRIVATE_DEF_SEL(setPurgeableState_, + "setPurgeableState:"); +_MTL_PRIVATE_DEF_SEL(setRAddressMode_, + "setRAddressMode:"); +_MTL_PRIVATE_DEF_SEL(setRadiusBuffer_, + "setRadiusBuffer:"); +_MTL_PRIVATE_DEF_SEL(setRadiusBufferOffset_, + "setRadiusBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setRadiusBuffers_, + "setRadiusBuffers:"); +_MTL_PRIVATE_DEF_SEL(setRadiusFormat_, + "setRadiusFormat:"); +_MTL_PRIVATE_DEF_SEL(setRadiusStride_, + "setRadiusStride:"); +_MTL_PRIVATE_DEF_SEL(setRasterSampleCount_, + "setRasterSampleCount:"); +_MTL_PRIVATE_DEF_SEL(setRasterizationEnabled_, + "setRasterizationEnabled:"); +_MTL_PRIVATE_DEF_SEL(setRasterizationRateMap_, + "setRasterizationRateMap:"); +_MTL_PRIVATE_DEF_SEL(setReadMask_, + "setReadMask:"); +_MTL_PRIVATE_DEF_SEL(setReductionMode_, + "setReductionMode:"); +_MTL_PRIVATE_DEF_SEL(setRenderPipelineState_, + "setRenderPipelineState:"); +_MTL_PRIVATE_DEF_SEL(setRenderPipelineState_atIndex_, + "setRenderPipelineState:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setRenderPipelineStates_withRange_, + "setRenderPipelineStates:withRange:"); +_MTL_PRIVATE_DEF_SEL(setRenderTargetArrayLength_, + "setRenderTargetArrayLength:"); +_MTL_PRIVATE_DEF_SEL(setRenderTargetHeight_, + "setRenderTargetHeight:"); +_MTL_PRIVATE_DEF_SEL(setRenderTargetWidth_, + "setRenderTargetWidth:"); +_MTL_PRIVATE_DEF_SEL(setRequiredThreadsPerMeshThreadgroup_, + "setRequiredThreadsPerMeshThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(setRequiredThreadsPerObjectThreadgroup_, + "setRequiredThreadsPerObjectThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(setRequiredThreadsPerThreadgroup_, + "setRequiredThreadsPerThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(setResolveDepthPlane_, + "setResolveDepthPlane:"); +_MTL_PRIVATE_DEF_SEL(setResolveLevel_, + "setResolveLevel:"); +_MTL_PRIVATE_DEF_SEL(setResolveSlice_, + "setResolveSlice:"); +_MTL_PRIVATE_DEF_SEL(setResolveTexture_, + "setResolveTexture:"); +_MTL_PRIVATE_DEF_SEL(setResource_atBufferIndex_, + "setResource:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setResourceOptions_, + "setResourceOptions:"); +_MTL_PRIVATE_DEF_SEL(setResourceViewCount_, + "setResourceViewCount:"); +_MTL_PRIVATE_DEF_SEL(setRetainedReferences_, + "setRetainedReferences:"); +_MTL_PRIVATE_DEF_SEL(setRgbBlendOperation_, + "setRgbBlendOperation:"); +_MTL_PRIVATE_DEF_SEL(setSAddressMode_, + "setSAddressMode:"); +_MTL_PRIVATE_DEF_SEL(setSampleBuffer_, + "setSampleBuffer:"); +_MTL_PRIVATE_DEF_SEL(setSampleCount_, + "setSampleCount:"); +_MTL_PRIVATE_DEF_SEL(setSamplePositions_count_, + "setSamplePositions:count:"); +_MTL_PRIVATE_DEF_SEL(setSamplerState_atIndex_, + "setSamplerState:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setSamplerState_lodMinClamp_lodMaxClamp_atIndex_, + "setSamplerState:lodMinClamp:lodMaxClamp:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setSamplerStates_lodMinClamps_lodMaxClamps_withRange_, + "setSamplerStates:lodMinClamps:lodMaxClamps:withRange:"); +_MTL_PRIVATE_DEF_SEL(setSamplerStates_withRange_, + "setSamplerStates:withRange:"); +_MTL_PRIVATE_DEF_SEL(setScissorRect_, + "setScissorRect:"); +_MTL_PRIVATE_DEF_SEL(setScissorRects_count_, + "setScissorRects:count:"); +_MTL_PRIVATE_DEF_SEL(setScratchBufferAllocator_, + "setScratchBufferAllocator:"); +_MTL_PRIVATE_DEF_SEL(setScreenSize_, + "setScreenSize:"); +_MTL_PRIVATE_DEF_SEL(setSegmentControlPointCount_, + "setSegmentControlPointCount:"); +_MTL_PRIVATE_DEF_SEL(setSegmentCount_, + "setSegmentCount:"); +_MTL_PRIVATE_DEF_SEL(setShaderReflection_, + "setShaderReflection:"); +_MTL_PRIVATE_DEF_SEL(setShaderValidation_, + "setShaderValidation:"); +_MTL_PRIVATE_DEF_SEL(setShouldMaximizeConcurrentCompilation_, + "setShouldMaximizeConcurrentCompilation:"); +_MTL_PRIVATE_DEF_SEL(setSignaledValue_, + "setSignaledValue:"); +_MTL_PRIVATE_DEF_SEL(setSize_, + "setSize:"); +_MTL_PRIVATE_DEF_SEL(setSlice_, + "setSlice:"); +_MTL_PRIVATE_DEF_SEL(setSliceRange_, + "setSliceRange:"); +_MTL_PRIVATE_DEF_SEL(setSource_, + "setSource:"); +_MTL_PRIVATE_DEF_SEL(setSourceAlphaBlendFactor_, + "setSourceAlphaBlendFactor:"); +_MTL_PRIVATE_DEF_SEL(setSourceRGBBlendFactor_, + "setSourceRGBBlendFactor:"); +_MTL_PRIVATE_DEF_SEL(setSparsePageSize_, + "setSparsePageSize:"); +_MTL_PRIVATE_DEF_SEL(setSpecializedName_, + "setSpecializedName:"); +_MTL_PRIVATE_DEF_SEL(setStageInRegion_, + "setStageInRegion:"); +_MTL_PRIVATE_DEF_SEL(setStageInRegionWithIndirectBuffer_indirectBufferOffset_, + "setStageInRegionWithIndirectBuffer:indirectBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setStageInputDescriptor_, + "setStageInputDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setStartOfEncoderSampleIndex_, + "setStartOfEncoderSampleIndex:"); +_MTL_PRIVATE_DEF_SEL(setStartOfFragmentSampleIndex_, + "setStartOfFragmentSampleIndex:"); +_MTL_PRIVATE_DEF_SEL(setStartOfVertexSampleIndex_, + "setStartOfVertexSampleIndex:"); +_MTL_PRIVATE_DEF_SEL(setStaticLinkingDescriptor_, + "setStaticLinkingDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setStencilAttachment_, + "setStencilAttachment:"); +_MTL_PRIVATE_DEF_SEL(setStencilAttachmentPixelFormat_, + "setStencilAttachmentPixelFormat:"); +_MTL_PRIVATE_DEF_SEL(setStencilCompareFunction_, + "setStencilCompareFunction:"); +_MTL_PRIVATE_DEF_SEL(setStencilFailureOperation_, + "setStencilFailureOperation:"); +_MTL_PRIVATE_DEF_SEL(setStencilFrontReferenceValue_backReferenceValue_, + "setStencilFrontReferenceValue:backReferenceValue:"); +_MTL_PRIVATE_DEF_SEL(setStencilReferenceValue_, + "setStencilReferenceValue:"); +_MTL_PRIVATE_DEF_SEL(setStencilResolveFilter_, + "setStencilResolveFilter:"); +_MTL_PRIVATE_DEF_SEL(setStencilStoreAction_, + "setStencilStoreAction:"); +_MTL_PRIVATE_DEF_SEL(setStencilStoreActionOptions_, + "setStencilStoreActionOptions:"); +_MTL_PRIVATE_DEF_SEL(setStepFunction_, + "setStepFunction:"); +_MTL_PRIVATE_DEF_SEL(setStepRate_, + "setStepRate:"); +_MTL_PRIVATE_DEF_SEL(setStorageMode_, + "setStorageMode:"); +_MTL_PRIVATE_DEF_SEL(setStoreAction_, + "setStoreAction:"); +_MTL_PRIVATE_DEF_SEL(setStoreActionOptions_, + "setStoreActionOptions:"); +_MTL_PRIVATE_DEF_SEL(setStride_, + "setStride:"); +_MTL_PRIVATE_DEF_SEL(setStrides_, + "setStrides:"); +_MTL_PRIVATE_DEF_SEL(setSupportAddingBinaryFunctions_, + "setSupportAddingBinaryFunctions:"); +_MTL_PRIVATE_DEF_SEL(setSupportAddingFragmentBinaryFunctions_, + "setSupportAddingFragmentBinaryFunctions:"); +_MTL_PRIVATE_DEF_SEL(setSupportAddingVertexBinaryFunctions_, + "setSupportAddingVertexBinaryFunctions:"); +_MTL_PRIVATE_DEF_SEL(setSupportArgumentBuffers_, + "setSupportArgumentBuffers:"); +_MTL_PRIVATE_DEF_SEL(setSupportAttributeStrides_, + "setSupportAttributeStrides:"); +_MTL_PRIVATE_DEF_SEL(setSupportBinaryLinking_, + "setSupportBinaryLinking:"); +_MTL_PRIVATE_DEF_SEL(setSupportColorAttachmentMapping_, + "setSupportColorAttachmentMapping:"); +_MTL_PRIVATE_DEF_SEL(setSupportDynamicAttributeStride_, + "setSupportDynamicAttributeStride:"); +_MTL_PRIVATE_DEF_SEL(setSupportFragmentBinaryLinking_, + "setSupportFragmentBinaryLinking:"); +_MTL_PRIVATE_DEF_SEL(setSupportIndirectCommandBuffers_, + "setSupportIndirectCommandBuffers:"); +_MTL_PRIVATE_DEF_SEL(setSupportMeshBinaryLinking_, + "setSupportMeshBinaryLinking:"); +_MTL_PRIVATE_DEF_SEL(setSupportObjectBinaryLinking_, + "setSupportObjectBinaryLinking:"); +_MTL_PRIVATE_DEF_SEL(setSupportRayTracing_, + "setSupportRayTracing:"); +_MTL_PRIVATE_DEF_SEL(setSupportVertexBinaryLinking_, + "setSupportVertexBinaryLinking:"); +_MTL_PRIVATE_DEF_SEL(setSwizzle_, + "setSwizzle:"); +_MTL_PRIVATE_DEF_SEL(setTAddressMode_, + "setTAddressMode:"); +_MTL_PRIVATE_DEF_SEL(setTessellationControlPointIndexType_, + "setTessellationControlPointIndexType:"); +_MTL_PRIVATE_DEF_SEL(setTessellationFactorBuffer_offset_instanceStride_, + "setTessellationFactorBuffer:offset:instanceStride:"); +_MTL_PRIVATE_DEF_SEL(setTessellationFactorFormat_, + "setTessellationFactorFormat:"); +_MTL_PRIVATE_DEF_SEL(setTessellationFactorScale_, + "setTessellationFactorScale:"); +_MTL_PRIVATE_DEF_SEL(setTessellationFactorScaleEnabled_, + "setTessellationFactorScaleEnabled:"); +_MTL_PRIVATE_DEF_SEL(setTessellationFactorStepFunction_, + "setTessellationFactorStepFunction:"); +_MTL_PRIVATE_DEF_SEL(setTessellationOutputWindingOrder_, + "setTessellationOutputWindingOrder:"); +_MTL_PRIVATE_DEF_SEL(setTessellationPartitionMode_, + "setTessellationPartitionMode:"); +_MTL_PRIVATE_DEF_SEL(setTexture_, + "setTexture:"); +_MTL_PRIVATE_DEF_SEL(setTexture_atIndex_, + "setTexture:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setTextureType_, + "setTextureType:"); +_MTL_PRIVATE_DEF_SEL(setTextureView_atIndex_, + "setTextureView:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setTextureView_descriptor_atIndex_, + "setTextureView:descriptor:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setTextureViewFromBuffer_descriptor_offset_bytesPerRow_atIndex_, + "setTextureViewFromBuffer:descriptor:offset:bytesPerRow:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setTextures_withRange_, + "setTextures:withRange:"); +_MTL_PRIVATE_DEF_SEL(setThreadGroupSizeIsMultipleOfThreadExecutionWidth_, + "setThreadGroupSizeIsMultipleOfThreadExecutionWidth:"); +_MTL_PRIVATE_DEF_SEL(setThreadgroupMemoryLength_, + "setThreadgroupMemoryLength:"); +_MTL_PRIVATE_DEF_SEL(setThreadgroupMemoryLength_atIndex_, + "setThreadgroupMemoryLength:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setThreadgroupMemoryLength_offset_atIndex_, + "setThreadgroupMemoryLength:offset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setThreadgroupSizeMatchesTileSize_, + "setThreadgroupSizeMatchesTileSize:"); +_MTL_PRIVATE_DEF_SEL(setTileAccelerationStructure_atBufferIndex_, + "setTileAccelerationStructure:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setTileAdditionalBinaryFunctions_, + "setTileAdditionalBinaryFunctions:"); +_MTL_PRIVATE_DEF_SEL(setTileBuffer_offset_atIndex_, + "setTileBuffer:offset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setTileBufferOffset_atIndex_, + "setTileBufferOffset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setTileBuffers_offsets_withRange_, + "setTileBuffers:offsets:withRange:"); +_MTL_PRIVATE_DEF_SEL(setTileBytes_length_atIndex_, + "setTileBytes:length:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setTileFunction_, + "setTileFunction:"); +_MTL_PRIVATE_DEF_SEL(setTileFunctionDescriptor_, + "setTileFunctionDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setTileHeight_, + "setTileHeight:"); +_MTL_PRIVATE_DEF_SEL(setTileIntersectionFunctionTable_atBufferIndex_, + "setTileIntersectionFunctionTable:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setTileIntersectionFunctionTables_withBufferRange_, + "setTileIntersectionFunctionTables:withBufferRange:"); +_MTL_PRIVATE_DEF_SEL(setTileSamplerState_atIndex_, + "setTileSamplerState:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setTileSamplerState_lodMinClamp_lodMaxClamp_atIndex_, + "setTileSamplerState:lodMinClamp:lodMaxClamp:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setTileSamplerStates_lodMinClamps_lodMaxClamps_withRange_, + "setTileSamplerStates:lodMinClamps:lodMaxClamps:withRange:"); +_MTL_PRIVATE_DEF_SEL(setTileSamplerStates_withRange_, + "setTileSamplerStates:withRange:"); +_MTL_PRIVATE_DEF_SEL(setTileTexture_atIndex_, + "setTileTexture:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setTileTextures_withRange_, + "setTileTextures:withRange:"); +_MTL_PRIVATE_DEF_SEL(setTileVisibleFunctionTable_atBufferIndex_, + "setTileVisibleFunctionTable:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setTileVisibleFunctionTables_withBufferRange_, + "setTileVisibleFunctionTables:withBufferRange:"); +_MTL_PRIVATE_DEF_SEL(setTileWidth_, + "setTileWidth:"); +_MTL_PRIVATE_DEF_SEL(setTransformationMatrixBuffer_, + "setTransformationMatrixBuffer:"); +_MTL_PRIVATE_DEF_SEL(setTransformationMatrixBufferOffset_, + "setTransformationMatrixBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setTransformationMatrixLayout_, + "setTransformationMatrixLayout:"); +_MTL_PRIVATE_DEF_SEL(setTriangleCount_, + "setTriangleCount:"); +_MTL_PRIVATE_DEF_SEL(setTriangleFillMode_, + "setTriangleFillMode:"); +_MTL_PRIVATE_DEF_SEL(setType_, + "setType:"); +_MTL_PRIVATE_DEF_SEL(setUrl_, + "setUrl:"); +_MTL_PRIVATE_DEF_SEL(setUsage_, + "setUsage:"); +_MTL_PRIVATE_DEF_SEL(setVertexAccelerationStructure_atBufferIndex_, + "setVertexAccelerationStructure:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexAdditionalBinaryFunctions_, + "setVertexAdditionalBinaryFunctions:"); +_MTL_PRIVATE_DEF_SEL(setVertexAmplificationCount_viewMappings_, + "setVertexAmplificationCount:viewMappings:"); +_MTL_PRIVATE_DEF_SEL(setVertexBuffer_, + "setVertexBuffer:"); +_MTL_PRIVATE_DEF_SEL(setVertexBuffer_offset_atIndex_, + "setVertexBuffer:offset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexBuffer_offset_attributeStride_atIndex_, + "setVertexBuffer:offset:attributeStride:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexBufferOffset_, + "setVertexBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setVertexBufferOffset_atIndex_, + "setVertexBufferOffset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexBufferOffset_attributeStride_atIndex_, + "setVertexBufferOffset:attributeStride:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexBuffers_, + "setVertexBuffers:"); +_MTL_PRIVATE_DEF_SEL(setVertexBuffers_offsets_attributeStrides_withRange_, + "setVertexBuffers:offsets:attributeStrides:withRange:"); +_MTL_PRIVATE_DEF_SEL(setVertexBuffers_offsets_withRange_, + "setVertexBuffers:offsets:withRange:"); +_MTL_PRIVATE_DEF_SEL(setVertexBytes_length_atIndex_, + "setVertexBytes:length:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexBytes_length_attributeStride_atIndex_, + "setVertexBytes:length:attributeStride:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexDescriptor_, + "setVertexDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setVertexFormat_, + "setVertexFormat:"); +_MTL_PRIVATE_DEF_SEL(setVertexFunction_, + "setVertexFunction:"); +_MTL_PRIVATE_DEF_SEL(setVertexFunctionDescriptor_, + "setVertexFunctionDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setVertexIntersectionFunctionTable_atBufferIndex_, + "setVertexIntersectionFunctionTable:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexIntersectionFunctionTables_withBufferRange_, + "setVertexIntersectionFunctionTables:withBufferRange:"); +_MTL_PRIVATE_DEF_SEL(setVertexLinkedFunctions_, + "setVertexLinkedFunctions:"); +_MTL_PRIVATE_DEF_SEL(setVertexPreloadedLibraries_, + "setVertexPreloadedLibraries:"); +_MTL_PRIVATE_DEF_SEL(setVertexSamplerState_atIndex_, + "setVertexSamplerState:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexSamplerState_lodMinClamp_lodMaxClamp_atIndex_, + "setVertexSamplerState:lodMinClamp:lodMaxClamp:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexSamplerStates_lodMinClamps_lodMaxClamps_withRange_, + "setVertexSamplerStates:lodMinClamps:lodMaxClamps:withRange:"); +_MTL_PRIVATE_DEF_SEL(setVertexSamplerStates_withRange_, + "setVertexSamplerStates:withRange:"); +_MTL_PRIVATE_DEF_SEL(setVertexStaticLinkingDescriptor_, + "setVertexStaticLinkingDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setVertexStride_, + "setVertexStride:"); +_MTL_PRIVATE_DEF_SEL(setVertexTexture_atIndex_, + "setVertexTexture:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexTextures_withRange_, + "setVertexTextures:withRange:"); +_MTL_PRIVATE_DEF_SEL(setVertexVisibleFunctionTable_atBufferIndex_, + "setVertexVisibleFunctionTable:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexVisibleFunctionTables_withBufferRange_, + "setVertexVisibleFunctionTables:withBufferRange:"); +_MTL_PRIVATE_DEF_SEL(setViewport_, + "setViewport:"); +_MTL_PRIVATE_DEF_SEL(setViewports_count_, + "setViewports:count:"); +_MTL_PRIVATE_DEF_SEL(setVisibilityResultBuffer_, + "setVisibilityResultBuffer:"); +_MTL_PRIVATE_DEF_SEL(setVisibilityResultMode_offset_, + "setVisibilityResultMode:offset:"); +_MTL_PRIVATE_DEF_SEL(setVisibilityResultType_, + "setVisibilityResultType:"); +_MTL_PRIVATE_DEF_SEL(setVisibleFunctionTable_atBufferIndex_, + "setVisibleFunctionTable:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setVisibleFunctionTable_atIndex_, + "setVisibleFunctionTable:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setVisibleFunctionTables_withBufferRange_, + "setVisibleFunctionTables:withBufferRange:"); +_MTL_PRIVATE_DEF_SEL(setVisibleFunctionTables_withRange_, + "setVisibleFunctionTables:withRange:"); +_MTL_PRIVATE_DEF_SEL(setWidth_, + "setWidth:"); +_MTL_PRIVATE_DEF_SEL(setWriteMask_, + "setWriteMask:"); +_MTL_PRIVATE_DEF_SEL(shaderReflection, + "shaderReflection"); +_MTL_PRIVATE_DEF_SEL(shaderValidation, + "shaderValidation"); +_MTL_PRIVATE_DEF_SEL(sharedCaptureManager, + "sharedCaptureManager"); +_MTL_PRIVATE_DEF_SEL(sharedListener, + "sharedListener"); +_MTL_PRIVATE_DEF_SEL(shouldMaximizeConcurrentCompilation, + "shouldMaximizeConcurrentCompilation"); +_MTL_PRIVATE_DEF_SEL(signalDrawable_, + "signalDrawable:"); +_MTL_PRIVATE_DEF_SEL(signalEvent_value_, + "signalEvent:value:"); +_MTL_PRIVATE_DEF_SEL(signaledValue, + "signaledValue"); +_MTL_PRIVATE_DEF_SEL(size, + "size"); +_MTL_PRIVATE_DEF_SEL(sizeOfCounterHeapEntry_, + "sizeOfCounterHeapEntry:"); +_MTL_PRIVATE_DEF_SEL(slice, + "slice"); +_MTL_PRIVATE_DEF_SEL(sliceRange, + "sliceRange"); +_MTL_PRIVATE_DEF_SEL(source, + "source"); +_MTL_PRIVATE_DEF_SEL(sourceAlphaBlendFactor, + "sourceAlphaBlendFactor"); +_MTL_PRIVATE_DEF_SEL(sourceRGBBlendFactor, + "sourceRGBBlendFactor"); +_MTL_PRIVATE_DEF_SEL(sparseBufferTier, + "sparseBufferTier"); +_MTL_PRIVATE_DEF_SEL(sparsePageSize, + "sparsePageSize"); +_MTL_PRIVATE_DEF_SEL(sparseTextureTier, + "sparseTextureTier"); +_MTL_PRIVATE_DEF_SEL(sparseTileSizeInBytes, + "sparseTileSizeInBytes"); +_MTL_PRIVATE_DEF_SEL(sparseTileSizeInBytesForSparsePageSize_, + "sparseTileSizeInBytesForSparsePageSize:"); +_MTL_PRIVATE_DEF_SEL(sparseTileSizeWithTextureType_pixelFormat_sampleCount_, + "sparseTileSizeWithTextureType:pixelFormat:sampleCount:"); +_MTL_PRIVATE_DEF_SEL(sparseTileSizeWithTextureType_pixelFormat_sampleCount_sparsePageSize_, + "sparseTileSizeWithTextureType:pixelFormat:sampleCount:sparsePageSize:"); +_MTL_PRIVATE_DEF_SEL(specializedName, + "specializedName"); +_MTL_PRIVATE_DEF_SEL(stageInputAttributes, + "stageInputAttributes"); +_MTL_PRIVATE_DEF_SEL(stageInputDescriptor, + "stageInputDescriptor"); +_MTL_PRIVATE_DEF_SEL(stageInputOutputDescriptor, + "stageInputOutputDescriptor"); +_MTL_PRIVATE_DEF_SEL(stages, + "stages"); +_MTL_PRIVATE_DEF_SEL(startCaptureWithCommandQueue_, + "startCaptureWithCommandQueue:"); +_MTL_PRIVATE_DEF_SEL(startCaptureWithDescriptor_error_, + "startCaptureWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(startCaptureWithDevice_, + "startCaptureWithDevice:"); +_MTL_PRIVATE_DEF_SEL(startCaptureWithScope_, + "startCaptureWithScope:"); +_MTL_PRIVATE_DEF_SEL(startOfEncoderSampleIndex, + "startOfEncoderSampleIndex"); +_MTL_PRIVATE_DEF_SEL(startOfFragmentSampleIndex, + "startOfFragmentSampleIndex"); +_MTL_PRIVATE_DEF_SEL(startOfVertexSampleIndex, + "startOfVertexSampleIndex"); +_MTL_PRIVATE_DEF_SEL(staticLinkingDescriptor, + "staticLinkingDescriptor"); +_MTL_PRIVATE_DEF_SEL(staticThreadgroupMemoryLength, + "staticThreadgroupMemoryLength"); +_MTL_PRIVATE_DEF_SEL(status, + "status"); +_MTL_PRIVATE_DEF_SEL(stencilAttachment, + "stencilAttachment"); +_MTL_PRIVATE_DEF_SEL(stencilAttachmentPixelFormat, + "stencilAttachmentPixelFormat"); +_MTL_PRIVATE_DEF_SEL(stencilCompareFunction, + "stencilCompareFunction"); +_MTL_PRIVATE_DEF_SEL(stencilFailureOperation, + "stencilFailureOperation"); +_MTL_PRIVATE_DEF_SEL(stencilResolveFilter, + "stencilResolveFilter"); +_MTL_PRIVATE_DEF_SEL(stepFunction, + "stepFunction"); +_MTL_PRIVATE_DEF_SEL(stepRate, + "stepRate"); +_MTL_PRIVATE_DEF_SEL(stopCapture, + "stopCapture"); +_MTL_PRIVATE_DEF_SEL(storageMode, + "storageMode"); +_MTL_PRIVATE_DEF_SEL(storeAction, + "storeAction"); +_MTL_PRIVATE_DEF_SEL(storeActionOptions, + "storeActionOptions"); +_MTL_PRIVATE_DEF_SEL(stride, + "stride"); +_MTL_PRIVATE_DEF_SEL(strides, + "strides"); +_MTL_PRIVATE_DEF_SEL(structType, + "structType"); +_MTL_PRIVATE_DEF_SEL(supportAddingBinaryFunctions, + "supportAddingBinaryFunctions"); +_MTL_PRIVATE_DEF_SEL(supportAddingFragmentBinaryFunctions, + "supportAddingFragmentBinaryFunctions"); +_MTL_PRIVATE_DEF_SEL(supportAddingVertexBinaryFunctions, + "supportAddingVertexBinaryFunctions"); +_MTL_PRIVATE_DEF_SEL(supportArgumentBuffers, + "supportArgumentBuffers"); +_MTL_PRIVATE_DEF_SEL(supportAttributeStrides, + "supportAttributeStrides"); +_MTL_PRIVATE_DEF_SEL(supportBinaryLinking, + "supportBinaryLinking"); +_MTL_PRIVATE_DEF_SEL(supportColorAttachmentMapping, + "supportColorAttachmentMapping"); +_MTL_PRIVATE_DEF_SEL(supportDynamicAttributeStride, + "supportDynamicAttributeStride"); +_MTL_PRIVATE_DEF_SEL(supportFragmentBinaryLinking, + "supportFragmentBinaryLinking"); +_MTL_PRIVATE_DEF_SEL(supportIndirectCommandBuffers, + "supportIndirectCommandBuffers"); +_MTL_PRIVATE_DEF_SEL(supportMeshBinaryLinking, + "supportMeshBinaryLinking"); +_MTL_PRIVATE_DEF_SEL(supportObjectBinaryLinking, + "supportObjectBinaryLinking"); +_MTL_PRIVATE_DEF_SEL(supportRayTracing, + "supportRayTracing"); +_MTL_PRIVATE_DEF_SEL(supportVertexBinaryLinking, + "supportVertexBinaryLinking"); +_MTL_PRIVATE_DEF_SEL(supports32BitFloatFiltering, + "supports32BitFloatFiltering"); +_MTL_PRIVATE_DEF_SEL(supports32BitMSAA, + "supports32BitMSAA"); +_MTL_PRIVATE_DEF_SEL(supportsBCTextureCompression, + "supportsBCTextureCompression"); +_MTL_PRIVATE_DEF_SEL(supportsCounterSampling_, + "supportsCounterSampling:"); +_MTL_PRIVATE_DEF_SEL(supportsDestination_, + "supportsDestination:"); +_MTL_PRIVATE_DEF_SEL(supportsDynamicLibraries, + "supportsDynamicLibraries"); +_MTL_PRIVATE_DEF_SEL(supportsFamily_, + "supportsFamily:"); +_MTL_PRIVATE_DEF_SEL(supportsFeatureSet_, + "supportsFeatureSet:"); +_MTL_PRIVATE_DEF_SEL(supportsFunctionPointers, + "supportsFunctionPointers"); +_MTL_PRIVATE_DEF_SEL(supportsFunctionPointersFromRender, + "supportsFunctionPointersFromRender"); +_MTL_PRIVATE_DEF_SEL(supportsPrimitiveMotionBlur, + "supportsPrimitiveMotionBlur"); +_MTL_PRIVATE_DEF_SEL(supportsPullModelInterpolation, + "supportsPullModelInterpolation"); +_MTL_PRIVATE_DEF_SEL(supportsQueryTextureLOD, + "supportsQueryTextureLOD"); +_MTL_PRIVATE_DEF_SEL(supportsRasterizationRateMapWithLayerCount_, + "supportsRasterizationRateMapWithLayerCount:"); +_MTL_PRIVATE_DEF_SEL(supportsRaytracing, + "supportsRaytracing"); +_MTL_PRIVATE_DEF_SEL(supportsRaytracingFromRender, + "supportsRaytracingFromRender"); +_MTL_PRIVATE_DEF_SEL(supportsRenderDynamicLibraries, + "supportsRenderDynamicLibraries"); +_MTL_PRIVATE_DEF_SEL(supportsShaderBarycentricCoordinates, + "supportsShaderBarycentricCoordinates"); +_MTL_PRIVATE_DEF_SEL(supportsTextureSampleCount_, + "supportsTextureSampleCount:"); +_MTL_PRIVATE_DEF_SEL(supportsVertexAmplificationCount_, + "supportsVertexAmplificationCount:"); +_MTL_PRIVATE_DEF_SEL(swizzle, + "swizzle"); +_MTL_PRIVATE_DEF_SEL(synchronizeResource_, + "synchronizeResource:"); +_MTL_PRIVATE_DEF_SEL(synchronizeTexture_slice_level_, + "synchronizeTexture:slice:level:"); +_MTL_PRIVATE_DEF_SEL(tAddressMode, + "tAddressMode"); +_MTL_PRIVATE_DEF_SEL(tailSizeInBytes, + "tailSizeInBytes"); +_MTL_PRIVATE_DEF_SEL(tensorDataType, + "tensorDataType"); +_MTL_PRIVATE_DEF_SEL(tensorReferenceType, + "tensorReferenceType"); +_MTL_PRIVATE_DEF_SEL(tensorSizeAndAlignWithDescriptor_, + "tensorSizeAndAlignWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(tessellationControlPointIndexType, + "tessellationControlPointIndexType"); +_MTL_PRIVATE_DEF_SEL(tessellationFactorFormat, + "tessellationFactorFormat"); +_MTL_PRIVATE_DEF_SEL(tessellationFactorStepFunction, + "tessellationFactorStepFunction"); +_MTL_PRIVATE_DEF_SEL(tessellationOutputWindingOrder, + "tessellationOutputWindingOrder"); +_MTL_PRIVATE_DEF_SEL(tessellationPartitionMode, + "tessellationPartitionMode"); +_MTL_PRIVATE_DEF_SEL(texture, + "texture"); +_MTL_PRIVATE_DEF_SEL(texture2DDescriptorWithPixelFormat_width_height_mipmapped_, + "texture2DDescriptorWithPixelFormat:width:height:mipmapped:"); +_MTL_PRIVATE_DEF_SEL(textureBarrier, + "textureBarrier"); +_MTL_PRIVATE_DEF_SEL(textureBufferDescriptorWithPixelFormat_width_resourceOptions_usage_, + "textureBufferDescriptorWithPixelFormat:width:resourceOptions:usage:"); +_MTL_PRIVATE_DEF_SEL(textureCubeDescriptorWithPixelFormat_size_mipmapped_, + "textureCubeDescriptorWithPixelFormat:size:mipmapped:"); +_MTL_PRIVATE_DEF_SEL(textureDataType, + "textureDataType"); +_MTL_PRIVATE_DEF_SEL(textureReferenceType, + "textureReferenceType"); +_MTL_PRIVATE_DEF_SEL(textureType, + "textureType"); +_MTL_PRIVATE_DEF_SEL(threadExecutionWidth, + "threadExecutionWidth"); +_MTL_PRIVATE_DEF_SEL(threadGroupSizeIsMultipleOfThreadExecutionWidth, + "threadGroupSizeIsMultipleOfThreadExecutionWidth"); +_MTL_PRIVATE_DEF_SEL(threadgroupMemoryAlignment, + "threadgroupMemoryAlignment"); +_MTL_PRIVATE_DEF_SEL(threadgroupMemoryDataSize, + "threadgroupMemoryDataSize"); +_MTL_PRIVATE_DEF_SEL(threadgroupMemoryLength, + "threadgroupMemoryLength"); +_MTL_PRIVATE_DEF_SEL(threadgroupSizeMatchesTileSize, + "threadgroupSizeMatchesTileSize"); +_MTL_PRIVATE_DEF_SEL(tileAdditionalBinaryFunctions, + "tileAdditionalBinaryFunctions"); +_MTL_PRIVATE_DEF_SEL(tileArguments, + "tileArguments"); +_MTL_PRIVATE_DEF_SEL(tileBindings, + "tileBindings"); +_MTL_PRIVATE_DEF_SEL(tileBuffers, + "tileBuffers"); +_MTL_PRIVATE_DEF_SEL(tileFunction, + "tileFunction"); +_MTL_PRIVATE_DEF_SEL(tileFunctionDescriptor, + "tileFunctionDescriptor"); +_MTL_PRIVATE_DEF_SEL(tileHeight, + "tileHeight"); +_MTL_PRIVATE_DEF_SEL(tileLinkingDescriptor, + "tileLinkingDescriptor"); +_MTL_PRIVATE_DEF_SEL(tileWidth, + "tileWidth"); +_MTL_PRIVATE_DEF_SEL(transformationMatrixBuffer, + "transformationMatrixBuffer"); +_MTL_PRIVATE_DEF_SEL(transformationMatrixBufferOffset, + "transformationMatrixBufferOffset"); +_MTL_PRIVATE_DEF_SEL(transformationMatrixLayout, + "transformationMatrixLayout"); +_MTL_PRIVATE_DEF_SEL(triangleCount, + "triangleCount"); +_MTL_PRIVATE_DEF_SEL(tryCancel, + "tryCancel"); +_MTL_PRIVATE_DEF_SEL(type, + "type"); +_MTL_PRIVATE_DEF_SEL(updateBufferMappings_heap_operations_count_, + "updateBufferMappings:heap:operations:count:"); +_MTL_PRIVATE_DEF_SEL(updateFence_, + "updateFence:"); +_MTL_PRIVATE_DEF_SEL(updateFence_afterEncoderStages_, + "updateFence:afterEncoderStages:"); +_MTL_PRIVATE_DEF_SEL(updateFence_afterStages_, + "updateFence:afterStages:"); +_MTL_PRIVATE_DEF_SEL(updateTextureMapping_mode_indirectBuffer_indirectBufferOffset_, + "updateTextureMapping:mode:indirectBuffer:indirectBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(updateTextureMapping_mode_region_mipLevel_slice_, + "updateTextureMapping:mode:region:mipLevel:slice:"); +_MTL_PRIVATE_DEF_SEL(updateTextureMappings_heap_operations_count_, + "updateTextureMappings:heap:operations:count:"); +_MTL_PRIVATE_DEF_SEL(updateTextureMappings_mode_regions_mipLevels_slices_numRegions_, + "updateTextureMappings:mode:regions:mipLevels:slices:numRegions:"); +_MTL_PRIVATE_DEF_SEL(url, + "url"); +_MTL_PRIVATE_DEF_SEL(usage, + "usage"); +_MTL_PRIVATE_DEF_SEL(useHeap_, + "useHeap:"); +_MTL_PRIVATE_DEF_SEL(useHeap_stages_, + "useHeap:stages:"); +_MTL_PRIVATE_DEF_SEL(useHeaps_count_, + "useHeaps:count:"); +_MTL_PRIVATE_DEF_SEL(useHeaps_count_stages_, + "useHeaps:count:stages:"); +_MTL_PRIVATE_DEF_SEL(useResidencySet_, + "useResidencySet:"); +_MTL_PRIVATE_DEF_SEL(useResidencySets_count_, + "useResidencySets:count:"); +_MTL_PRIVATE_DEF_SEL(useResource_usage_, + "useResource:usage:"); +_MTL_PRIVATE_DEF_SEL(useResource_usage_stages_, + "useResource:usage:stages:"); +_MTL_PRIVATE_DEF_SEL(useResources_count_usage_, + "useResources:count:usage:"); +_MTL_PRIVATE_DEF_SEL(useResources_count_usage_stages_, + "useResources:count:usage:stages:"); +_MTL_PRIVATE_DEF_SEL(usedSize, + "usedSize"); +_MTL_PRIVATE_DEF_SEL(vertexAdditionalBinaryFunctions, + "vertexAdditionalBinaryFunctions"); +_MTL_PRIVATE_DEF_SEL(vertexArguments, + "vertexArguments"); +_MTL_PRIVATE_DEF_SEL(vertexAttributes, + "vertexAttributes"); +_MTL_PRIVATE_DEF_SEL(vertexBindings, + "vertexBindings"); +_MTL_PRIVATE_DEF_SEL(vertexBuffer, + "vertexBuffer"); +_MTL_PRIVATE_DEF_SEL(vertexBufferOffset, + "vertexBufferOffset"); +_MTL_PRIVATE_DEF_SEL(vertexBuffers, + "vertexBuffers"); +_MTL_PRIVATE_DEF_SEL(vertexDescriptor, + "vertexDescriptor"); +_MTL_PRIVATE_DEF_SEL(vertexFormat, + "vertexFormat"); +_MTL_PRIVATE_DEF_SEL(vertexFunction, + "vertexFunction"); +_MTL_PRIVATE_DEF_SEL(vertexFunctionDescriptor, + "vertexFunctionDescriptor"); +_MTL_PRIVATE_DEF_SEL(vertexLinkedFunctions, + "vertexLinkedFunctions"); +_MTL_PRIVATE_DEF_SEL(vertexLinkingDescriptor, + "vertexLinkingDescriptor"); +_MTL_PRIVATE_DEF_SEL(vertexPreloadedLibraries, + "vertexPreloadedLibraries"); +_MTL_PRIVATE_DEF_SEL(vertexStaticLinkingDescriptor, + "vertexStaticLinkingDescriptor"); +_MTL_PRIVATE_DEF_SEL(vertexStride, + "vertexStride"); +_MTL_PRIVATE_DEF_SEL(vertical, + "vertical"); +_MTL_PRIVATE_DEF_SEL(verticalSampleStorage, + "verticalSampleStorage"); +_MTL_PRIVATE_DEF_SEL(visibilityResultBuffer, + "visibilityResultBuffer"); +_MTL_PRIVATE_DEF_SEL(visibilityResultType, + "visibilityResultType"); +_MTL_PRIVATE_DEF_SEL(visibleFunctionTableDescriptor, + "visibleFunctionTableDescriptor"); +_MTL_PRIVATE_DEF_SEL(waitForDrawable_, + "waitForDrawable:"); +_MTL_PRIVATE_DEF_SEL(waitForEvent_value_, + "waitForEvent:value:"); +_MTL_PRIVATE_DEF_SEL(waitForFence_, + "waitForFence:"); +_MTL_PRIVATE_DEF_SEL(waitForFence_beforeEncoderStages_, + "waitForFence:beforeEncoderStages:"); +_MTL_PRIVATE_DEF_SEL(waitForFence_beforeStages_, + "waitForFence:beforeStages:"); +_MTL_PRIVATE_DEF_SEL(waitUntilCompleted, + "waitUntilCompleted"); +_MTL_PRIVATE_DEF_SEL(waitUntilScheduled, + "waitUntilScheduled"); +_MTL_PRIVATE_DEF_SEL(waitUntilSignaledValue_timeoutMS_, + "waitUntilSignaledValue:timeoutMS:"); +_MTL_PRIVATE_DEF_SEL(width, + "width"); +_MTL_PRIVATE_DEF_SEL(writeCompactedAccelerationStructureSize_toBuffer_, + "writeCompactedAccelerationStructureSize:toBuffer:"); +_MTL_PRIVATE_DEF_SEL(writeCompactedAccelerationStructureSize_toBuffer_offset_, + "writeCompactedAccelerationStructureSize:toBuffer:offset:"); +_MTL_PRIVATE_DEF_SEL(writeCompactedAccelerationStructureSize_toBuffer_offset_sizeDataType_, + "writeCompactedAccelerationStructureSize:toBuffer:offset:sizeDataType:"); +_MTL_PRIVATE_DEF_SEL(writeMask, + "writeMask"); +_MTL_PRIVATE_DEF_SEL(writeTimestampIntoHeap_atIndex_, + "writeTimestampIntoHeap:atIndex:"); +_MTL_PRIVATE_DEF_SEL(writeTimestampWithGranularity_afterStage_intoHeap_atIndex_, + "writeTimestampWithGranularity:afterStage:intoHeap:atIndex:"); +_MTL_PRIVATE_DEF_SEL(writeTimestampWithGranularity_intoHeap_atIndex_, + "writeTimestampWithGranularity:intoHeap:atIndex:"); + +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLHeap.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLHeap.hpp new file mode 100644 index 00000000..251b284a --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLHeap.hpp @@ -0,0 +1,318 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLHeap.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLAllocation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" + +namespace MTL +{ +class AccelerationStructure; +class AccelerationStructureDescriptor; +class Buffer; +class Device; +class HeapDescriptor; +class Texture; +class TextureDescriptor; +_MTL_ENUM(NS::Integer, HeapType) { + HeapTypeAutomatic = 0, + HeapTypePlacement = 1, + HeapTypeSparse = 2, +}; + +class HeapDescriptor : public NS::Copying +{ +public: + static HeapDescriptor* alloc(); + + CPUCacheMode cpuCacheMode() const; + + HazardTrackingMode hazardTrackingMode() const; + + HeapDescriptor* init(); + + SparsePageSize maxCompatiblePlacementSparsePageSize() const; + + ResourceOptions resourceOptions() const; + + void setCpuCacheMode(MTL::CPUCacheMode cpuCacheMode); + + void setHazardTrackingMode(MTL::HazardTrackingMode hazardTrackingMode); + + void setMaxCompatiblePlacementSparsePageSize(MTL::SparsePageSize maxCompatiblePlacementSparsePageSize); + + void setResourceOptions(MTL::ResourceOptions resourceOptions); + + void setSize(NS::UInteger size); + + void setSparsePageSize(MTL::SparsePageSize sparsePageSize); + + void setStorageMode(MTL::StorageMode storageMode); + + void setType(MTL::HeapType type); + + NS::UInteger size() const; + SparsePageSize sparsePageSize() const; + + StorageMode storageMode() const; + + HeapType type() const; +}; +class Heap : public NS::Referencing +{ +public: + CPUCacheMode cpuCacheMode() const; + + NS::UInteger currentAllocatedSize() const; + + Device* device() const; + + HazardTrackingMode hazardTrackingMode() const; + + NS::String* label() const; + + NS::UInteger maxAvailableSize(NS::UInteger alignment); + + AccelerationStructure* newAccelerationStructure(NS::UInteger size); + AccelerationStructure* newAccelerationStructure(const MTL::AccelerationStructureDescriptor* descriptor); + AccelerationStructure* newAccelerationStructure(NS::UInteger size, NS::UInteger offset); + AccelerationStructure* newAccelerationStructure(const MTL::AccelerationStructureDescriptor* descriptor, NS::UInteger offset); + + Buffer* newBuffer(NS::UInteger length, MTL::ResourceOptions options); + Buffer* newBuffer(NS::UInteger length, MTL::ResourceOptions options, NS::UInteger offset); + + Texture* newTexture(const MTL::TextureDescriptor* descriptor); + Texture* newTexture(const MTL::TextureDescriptor* descriptor, NS::UInteger offset); + + ResourceOptions resourceOptions() const; + + void setLabel(const NS::String* label); + + PurgeableState setPurgeableState(MTL::PurgeableState state); + + NS::UInteger size() const; + + StorageMode storageMode() const; + + HeapType type() const; + + NS::UInteger usedSize() const; +}; + +} +_MTL_INLINE MTL::HeapDescriptor* MTL::HeapDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLHeapDescriptor)); +} + +_MTL_INLINE MTL::CPUCacheMode MTL::HeapDescriptor::cpuCacheMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(cpuCacheMode)); +} + +_MTL_INLINE MTL::HazardTrackingMode MTL::HeapDescriptor::hazardTrackingMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(hazardTrackingMode)); +} + +_MTL_INLINE MTL::HeapDescriptor* MTL::HeapDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::SparsePageSize MTL::HeapDescriptor::maxCompatiblePlacementSparsePageSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxCompatiblePlacementSparsePageSize)); +} + +_MTL_INLINE MTL::ResourceOptions MTL::HeapDescriptor::resourceOptions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resourceOptions)); +} + +_MTL_INLINE void MTL::HeapDescriptor::setCpuCacheMode(MTL::CPUCacheMode cpuCacheMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCpuCacheMode_), cpuCacheMode); +} + +_MTL_INLINE void MTL::HeapDescriptor::setHazardTrackingMode(MTL::HazardTrackingMode hazardTrackingMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setHazardTrackingMode_), hazardTrackingMode); +} + +_MTL_INLINE void MTL::HeapDescriptor::setMaxCompatiblePlacementSparsePageSize(MTL::SparsePageSize maxCompatiblePlacementSparsePageSize) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxCompatiblePlacementSparsePageSize_), maxCompatiblePlacementSparsePageSize); +} + +_MTL_INLINE void MTL::HeapDescriptor::setResourceOptions(MTL::ResourceOptions resourceOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setResourceOptions_), resourceOptions); +} + +_MTL_INLINE void MTL::HeapDescriptor::setSize(NS::UInteger size) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSize_), size); +} + +_MTL_INLINE void MTL::HeapDescriptor::setSparsePageSize(MTL::SparsePageSize sparsePageSize) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSparsePageSize_), sparsePageSize); +} + +_MTL_INLINE void MTL::HeapDescriptor::setStorageMode(MTL::StorageMode storageMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStorageMode_), storageMode); +} + +_MTL_INLINE void MTL::HeapDescriptor::setType(MTL::HeapType type) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setType_), type); +} + +_MTL_INLINE NS::UInteger MTL::HeapDescriptor::size() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(size)); +} + +_MTL_INLINE MTL::SparsePageSize MTL::HeapDescriptor::sparsePageSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sparsePageSize)); +} + +_MTL_INLINE MTL::StorageMode MTL::HeapDescriptor::storageMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(storageMode)); +} + +_MTL_INLINE MTL::HeapType MTL::HeapDescriptor::type() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(type)); +} + +_MTL_INLINE MTL::CPUCacheMode MTL::Heap::cpuCacheMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(cpuCacheMode)); +} + +_MTL_INLINE NS::UInteger MTL::Heap::currentAllocatedSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(currentAllocatedSize)); +} + +_MTL_INLINE MTL::Device* MTL::Heap::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE MTL::HazardTrackingMode MTL::Heap::hazardTrackingMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(hazardTrackingMode)); +} + +_MTL_INLINE NS::String* MTL::Heap::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::UInteger MTL::Heap::maxAvailableSize(NS::UInteger alignment) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxAvailableSizeWithAlignment_), alignment); +} + +_MTL_INLINE MTL::AccelerationStructure* MTL::Heap::newAccelerationStructure(NS::UInteger size) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newAccelerationStructureWithSize_), size); +} + +_MTL_INLINE MTL::AccelerationStructure* MTL::Heap::newAccelerationStructure(const MTL::AccelerationStructureDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newAccelerationStructureWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::AccelerationStructure* MTL::Heap::newAccelerationStructure(NS::UInteger size, NS::UInteger offset) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newAccelerationStructureWithSize_offset_), size, offset); +} + +_MTL_INLINE MTL::AccelerationStructure* MTL::Heap::newAccelerationStructure(const MTL::AccelerationStructureDescriptor* descriptor, NS::UInteger offset) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newAccelerationStructureWithDescriptor_offset_), descriptor, offset); +} + +_MTL_INLINE MTL::Buffer* MTL::Heap::newBuffer(NS::UInteger length, MTL::ResourceOptions options) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newBufferWithLength_options_), length, options); +} + +_MTL_INLINE MTL::Buffer* MTL::Heap::newBuffer(NS::UInteger length, MTL::ResourceOptions options, NS::UInteger offset) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newBufferWithLength_options_offset_), length, options, offset); +} + +_MTL_INLINE MTL::Texture* MTL::Heap::newTexture(const MTL::TextureDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTextureWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::Texture* MTL::Heap::newTexture(const MTL::TextureDescriptor* descriptor, NS::UInteger offset) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTextureWithDescriptor_offset_), descriptor, offset); +} + +_MTL_INLINE MTL::ResourceOptions MTL::Heap::resourceOptions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resourceOptions)); +} + +_MTL_INLINE void MTL::Heap::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE MTL::PurgeableState MTL::Heap::setPurgeableState(MTL::PurgeableState state) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(setPurgeableState_), state); +} + +_MTL_INLINE NS::UInteger MTL::Heap::size() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(size)); +} + +_MTL_INLINE MTL::StorageMode MTL::Heap::storageMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(storageMode)); +} + +_MTL_INLINE MTL::HeapType MTL::Heap::type() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(type)); +} + +_MTL_INLINE NS::UInteger MTL::Heap::usedSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(usedSize)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLIOCommandBuffer.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLIOCommandBuffer.hpp new file mode 100644 index 00000000..9402318d --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLIOCommandBuffer.hpp @@ -0,0 +1,182 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLIOCommandBuffer.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL +{ +class Buffer; +class IOCommandBuffer; +class IOFileHandle; +class SharedEvent; +class Texture; +_MTL_ENUM(NS::Integer, IOStatus) { + IOStatusPending = 0, + IOStatusCancelled = 1, + IOStatusError = 2, + IOStatusComplete = 3, +}; + +using IOCommandBufferHandler = void (^)(MTL::IOCommandBuffer*); +using IOCommandBufferHandlerFunction = std::function; + +class IOCommandBuffer : public NS::Referencing +{ +public: + void addBarrier(); + + void addCompletedHandler(const MTL::IOCommandBufferHandler block); + void addCompletedHandler(const MTL::IOCommandBufferHandlerFunction& function); + + void commit(); + + void copyStatusToBuffer(const MTL::Buffer* buffer, NS::UInteger offset); + + void enqueue(); + + NS::Error* error() const; + + NS::String* label() const; + + void loadBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger size, const MTL::IOFileHandle* sourceHandle, NS::UInteger sourceHandleOffset); + + void loadBytes(const void* pointer, NS::UInteger size, const MTL::IOFileHandle* sourceHandle, NS::UInteger sourceHandleOffset); + + void loadTexture(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level, MTL::Size size, NS::UInteger sourceBytesPerRow, NS::UInteger sourceBytesPerImage, MTL::Origin destinationOrigin, const MTL::IOFileHandle* sourceHandle, NS::UInteger sourceHandleOffset); + + void popDebugGroup(); + + void pushDebugGroup(const NS::String* string); + + void setLabel(const NS::String* label); + + void signalEvent(const MTL::SharedEvent* event, uint64_t value); + + IOStatus status() const; + + void tryCancel(); + + void wait(const MTL::SharedEvent* event, uint64_t value); + void waitUntilCompleted(); +}; + +} +_MTL_INLINE void MTL::IOCommandBuffer::addBarrier() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addBarrier)); +} + +_MTL_INLINE void MTL::IOCommandBuffer::addCompletedHandler(const MTL::IOCommandBufferHandler block) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addCompletedHandler_), block); +} + +_MTL_INLINE void MTL::IOCommandBuffer::addCompletedHandler(const MTL::IOCommandBufferHandlerFunction& function) +{ + __block MTL::IOCommandBufferHandlerFunction blockFunction = function; + addCompletedHandler(^(MTL::IOCommandBuffer* pCommandBuffer) { blockFunction(pCommandBuffer); }); +} + +_MTL_INLINE void MTL::IOCommandBuffer::commit() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(commit)); +} + +_MTL_INLINE void MTL::IOCommandBuffer::copyStatusToBuffer(const MTL::Buffer* buffer, NS::UInteger offset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyStatusToBuffer_offset_), buffer, offset); +} + +_MTL_INLINE void MTL::IOCommandBuffer::enqueue() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(enqueue)); +} + +_MTL_INLINE NS::Error* MTL::IOCommandBuffer::error() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(error)); +} + +_MTL_INLINE NS::String* MTL::IOCommandBuffer::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::IOCommandBuffer::loadBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger size, const MTL::IOFileHandle* sourceHandle, NS::UInteger sourceHandleOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(loadBuffer_offset_size_sourceHandle_sourceHandleOffset_), buffer, offset, size, sourceHandle, sourceHandleOffset); +} + +_MTL_INLINE void MTL::IOCommandBuffer::loadBytes(const void* pointer, NS::UInteger size, const MTL::IOFileHandle* sourceHandle, NS::UInteger sourceHandleOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(loadBytes_size_sourceHandle_sourceHandleOffset_), pointer, size, sourceHandle, sourceHandleOffset); +} + +_MTL_INLINE void MTL::IOCommandBuffer::loadTexture(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level, MTL::Size size, NS::UInteger sourceBytesPerRow, NS::UInteger sourceBytesPerImage, MTL::Origin destinationOrigin, const MTL::IOFileHandle* sourceHandle, NS::UInteger sourceHandleOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(loadTexture_slice_level_size_sourceBytesPerRow_sourceBytesPerImage_destinationOrigin_sourceHandle_sourceHandleOffset_), texture, slice, level, size, sourceBytesPerRow, sourceBytesPerImage, destinationOrigin, sourceHandle, sourceHandleOffset); +} + +_MTL_INLINE void MTL::IOCommandBuffer::popDebugGroup() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(popDebugGroup)); +} + +_MTL_INLINE void MTL::IOCommandBuffer::pushDebugGroup(const NS::String* string) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(pushDebugGroup_), string); +} + +_MTL_INLINE void MTL::IOCommandBuffer::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::IOCommandBuffer::signalEvent(const MTL::SharedEvent* event, uint64_t value) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(signalEvent_value_), event, value); +} + +_MTL_INLINE MTL::IOStatus MTL::IOCommandBuffer::status() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(status)); +} + +_MTL_INLINE void MTL::IOCommandBuffer::tryCancel() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(tryCancel)); +} + +_MTL_INLINE void MTL::IOCommandBuffer::wait(const MTL::SharedEvent* event, uint64_t value) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitForEvent_value_), event, value); +} + +_MTL_INLINE void MTL::IOCommandBuffer::waitUntilCompleted() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitUntilCompleted)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLIOCommandQueue.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLIOCommandQueue.hpp new file mode 100644 index 00000000..78de5d82 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLIOCommandQueue.hpp @@ -0,0 +1,211 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLIOCommandQueue.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class Buffer; +class IOCommandBuffer; +class IOCommandQueueDescriptor; +class IOScratchBuffer; +class IOScratchBufferAllocator; +_MTL_ENUM(NS::Integer, IOPriority) { + IOPriorityHigh = 0, + IOPriorityNormal = 1, + IOPriorityLow = 2, +}; + +_MTL_ENUM(NS::Integer, IOCommandQueueType) { + IOCommandQueueTypeConcurrent = 0, + IOCommandQueueTypeSerial = 1, +}; + +_MTL_ENUM(NS::Integer, IOError) { + IOErrorURLInvalid = 1, + IOErrorInternal = 2, +}; + +_MTL_CONST(NS::ErrorDomain, IOErrorDomain); +class IOCommandQueue : public NS::Referencing +{ +public: + IOCommandBuffer* commandBuffer(); + IOCommandBuffer* commandBufferWithUnretainedReferences(); + + void enqueueBarrier(); + + NS::String* label() const; + void setLabel(const NS::String* label); +}; +class IOScratchBuffer : public NS::Referencing +{ +public: + Buffer* buffer() const; +}; +class IOScratchBufferAllocator : public NS::Referencing +{ +public: + IOScratchBuffer* newScratchBuffer(NS::UInteger minimumSize); +}; +class IOCommandQueueDescriptor : public NS::Copying +{ +public: + static IOCommandQueueDescriptor* alloc(); + + IOCommandQueueDescriptor* init(); + + NS::UInteger maxCommandBufferCount() const; + + NS::UInteger maxCommandsInFlight() const; + + IOPriority priority() const; + + IOScratchBufferAllocator* scratchBufferAllocator() const; + + void setMaxCommandBufferCount(NS::UInteger maxCommandBufferCount); + + void setMaxCommandsInFlight(NS::UInteger maxCommandsInFlight); + + void setPriority(MTL::IOPriority priority); + + void setScratchBufferAllocator(const MTL::IOScratchBufferAllocator* scratchBufferAllocator); + + void setType(MTL::IOCommandQueueType type); + IOCommandQueueType type() const; +}; +class IOFileHandle : public NS::Referencing +{ +public: + NS::String* label() const; + void setLabel(const NS::String* label); +}; + +} +_MTL_PRIVATE_DEF_WEAK_CONST(NS::ErrorDomain, IOErrorDomain); +_MTL_INLINE MTL::IOCommandBuffer* MTL::IOCommandQueue::commandBuffer() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(commandBuffer)); +} + +_MTL_INLINE MTL::IOCommandBuffer* MTL::IOCommandQueue::commandBufferWithUnretainedReferences() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(commandBufferWithUnretainedReferences)); +} + +_MTL_INLINE void MTL::IOCommandQueue::enqueueBarrier() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(enqueueBarrier)); +} + +_MTL_INLINE NS::String* MTL::IOCommandQueue::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::IOCommandQueue::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE MTL::Buffer* MTL::IOScratchBuffer::buffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(buffer)); +} + +_MTL_INLINE MTL::IOScratchBuffer* MTL::IOScratchBufferAllocator::newScratchBuffer(NS::UInteger minimumSize) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newScratchBufferWithMinimumSize_), minimumSize); +} + +_MTL_INLINE MTL::IOCommandQueueDescriptor* MTL::IOCommandQueueDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLIOCommandQueueDescriptor)); +} + +_MTL_INLINE MTL::IOCommandQueueDescriptor* MTL::IOCommandQueueDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL::IOCommandQueueDescriptor::maxCommandBufferCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxCommandBufferCount)); +} + +_MTL_INLINE NS::UInteger MTL::IOCommandQueueDescriptor::maxCommandsInFlight() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxCommandsInFlight)); +} + +_MTL_INLINE MTL::IOPriority MTL::IOCommandQueueDescriptor::priority() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(priority)); +} + +_MTL_INLINE MTL::IOScratchBufferAllocator* MTL::IOCommandQueueDescriptor::scratchBufferAllocator() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(scratchBufferAllocator)); +} + +_MTL_INLINE void MTL::IOCommandQueueDescriptor::setMaxCommandBufferCount(NS::UInteger maxCommandBufferCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxCommandBufferCount_), maxCommandBufferCount); +} + +_MTL_INLINE void MTL::IOCommandQueueDescriptor::setMaxCommandsInFlight(NS::UInteger maxCommandsInFlight) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxCommandsInFlight_), maxCommandsInFlight); +} + +_MTL_INLINE void MTL::IOCommandQueueDescriptor::setPriority(MTL::IOPriority priority) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPriority_), priority); +} + +_MTL_INLINE void MTL::IOCommandQueueDescriptor::setScratchBufferAllocator(const MTL::IOScratchBufferAllocator* scratchBufferAllocator) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setScratchBufferAllocator_), scratchBufferAllocator); +} + +_MTL_INLINE void MTL::IOCommandQueueDescriptor::setType(MTL::IOCommandQueueType type) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setType_), type); +} + +_MTL_INLINE MTL::IOCommandQueueType MTL::IOCommandQueueDescriptor::type() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(type)); +} + +_MTL_INLINE NS::String* MTL::IOFileHandle::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::IOFileHandle::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLIOCompressor.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLIOCompressor.hpp new file mode 100644 index 00000000..920fa611 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLIOCompressor.hpp @@ -0,0 +1,94 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLIOCompressor.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLDevice.hpp" + +#include "../Foundation/Foundation.hpp" + +namespace MTL +{ +using IOCompressionContext=void*; + +_MTL_ENUM(NS::Integer, IOCompressionStatus) { + IOCompressionStatusComplete = 0, + IOCompressionStatusError = 1, +}; + +size_t IOCompressionContextDefaultChunkSize(); + +IOCompressionContext IOCreateCompressionContext(const char* path, IOCompressionMethod type, size_t chunkSize); + +void IOCompressionContextAppendData(IOCompressionContext context, const void* data, size_t size); + +IOCompressionStatus IOFlushAndDestroyCompressionContext(IOCompressionContext context); + +} + +#if defined(MTL_PRIVATE_IMPLEMENTATION) + +namespace MTL::Private { + +MTL_DEF_FUNC(MTLIOCompressionContextDefaultChunkSize, size_t (*)(void)); + +MTL_DEF_FUNC( MTLIOCreateCompressionContext, void* (*)(const char*, MTL::IOCompressionMethod, size_t) ); + +MTL_DEF_FUNC( MTLIOCompressionContextAppendData, void (*)(void*, const void*, size_t) ); + +MTL_DEF_FUNC( MTLIOFlushAndDestroyCompressionContext, MTL::IOCompressionStatus (*)(void*) ); + +} + +_NS_EXPORT size_t MTL::IOCompressionContextDefaultChunkSize() +{ + return MTL::Private::MTLIOCompressionContextDefaultChunkSize(); +} + +_NS_EXPORT void* MTL::IOCreateCompressionContext(const char* path, IOCompressionMethod type, size_t chunkSize) +{ + if ( MTL::Private::MTLIOCreateCompressionContext ) + { + return MTL::Private::MTLIOCreateCompressionContext( path, type, chunkSize ); + } + return nullptr; +} + +_NS_EXPORT void MTL::IOCompressionContextAppendData(void* context, const void* data, size_t size) +{ + if ( MTL::Private::MTLIOCompressionContextAppendData ) + { + MTL::Private::MTLIOCompressionContextAppendData( context, data, size ); + } +} + +_NS_EXPORT MTL::IOCompressionStatus MTL::IOFlushAndDestroyCompressionContext(void* context) +{ + if ( MTL::Private::MTLIOFlushAndDestroyCompressionContext ) + { + return MTL::Private::MTLIOFlushAndDestroyCompressionContext( context ); + } + return MTL::IOCompressionStatusError; +} + +#endif diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLIndirectCommandBuffer.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLIndirectCommandBuffer.hpp new file mode 100644 index 00000000..6944d562 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLIndirectCommandBuffer.hpp @@ -0,0 +1,376 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLIndirectCommandBuffer.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL +{ +class IndirectCommandBufferDescriptor; +class IndirectComputeCommand; +class IndirectRenderCommand; + +_MTL_OPTIONS(NS::UInteger, IndirectCommandType) { + IndirectCommandTypeDraw = 1, + IndirectCommandTypeDrawIndexed = 1 << 1, + IndirectCommandTypeDrawPatches = 1 << 2, + IndirectCommandTypeDrawIndexedPatches = 1 << 3, + IndirectCommandTypeConcurrentDispatch = 1 << 5, + IndirectCommandTypeConcurrentDispatchThreads = 1 << 6, + IndirectCommandTypeDrawMeshThreadgroups = 1 << 7, + IndirectCommandTypeDrawMeshThreads = 1 << 8, +}; + +struct IndirectCommandBufferExecutionRange +{ + uint32_t location; + uint32_t length; +} _MTL_PACKED; + +class IndirectCommandBufferDescriptor : public NS::Copying +{ +public: + static IndirectCommandBufferDescriptor* alloc(); + + IndirectCommandType commandTypes() const; + + bool inheritBuffers() const; + + bool inheritCullMode() const; + + bool inheritDepthBias() const; + + bool inheritDepthClipMode() const; + + bool inheritDepthStencilState() const; + + bool inheritFrontFacingWinding() const; + + bool inheritPipelineState() const; + + bool inheritTriangleFillMode() const; + + IndirectCommandBufferDescriptor* init(); + + NS::UInteger maxFragmentBufferBindCount() const; + + NS::UInteger maxKernelBufferBindCount() const; + + NS::UInteger maxKernelThreadgroupMemoryBindCount() const; + + NS::UInteger maxMeshBufferBindCount() const; + + NS::UInteger maxObjectBufferBindCount() const; + + NS::UInteger maxObjectThreadgroupMemoryBindCount() const; + + NS::UInteger maxVertexBufferBindCount() const; + + void setCommandTypes(MTL::IndirectCommandType commandTypes); + + void setInheritBuffers(bool inheritBuffers); + + void setInheritCullMode(bool inheritCullMode); + + void setInheritDepthBias(bool inheritDepthBias); + + void setInheritDepthClipMode(bool inheritDepthClipMode); + + void setInheritDepthStencilState(bool inheritDepthStencilState); + + void setInheritFrontFacingWinding(bool inheritFrontFacingWinding); + + void setInheritPipelineState(bool inheritPipelineState); + + void setInheritTriangleFillMode(bool inheritTriangleFillMode); + + void setMaxFragmentBufferBindCount(NS::UInteger maxFragmentBufferBindCount); + + void setMaxKernelBufferBindCount(NS::UInteger maxKernelBufferBindCount); + + void setMaxKernelThreadgroupMemoryBindCount(NS::UInteger maxKernelThreadgroupMemoryBindCount); + + void setMaxMeshBufferBindCount(NS::UInteger maxMeshBufferBindCount); + + void setMaxObjectBufferBindCount(NS::UInteger maxObjectBufferBindCount); + + void setMaxObjectThreadgroupMemoryBindCount(NS::UInteger maxObjectThreadgroupMemoryBindCount); + + void setMaxVertexBufferBindCount(NS::UInteger maxVertexBufferBindCount); + + void setSupportColorAttachmentMapping(bool supportColorAttachmentMapping); + + void setSupportDynamicAttributeStride(bool supportDynamicAttributeStride); + + void setSupportRayTracing(bool supportRayTracing); + + bool supportColorAttachmentMapping() const; + + bool supportDynamicAttributeStride() const; + + bool supportRayTracing() const; +}; +class IndirectCommandBuffer : public NS::Referencing +{ +public: + ResourceID gpuResourceID() const; + + IndirectComputeCommand* indirectComputeCommand(NS::UInteger commandIndex); + + IndirectRenderCommand* indirectRenderCommand(NS::UInteger commandIndex); + + void reset(NS::Range range); + + NS::UInteger size() const; +}; + +} + +_MTL_INLINE MTL::IndirectCommandBufferDescriptor* MTL::IndirectCommandBufferDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLIndirectCommandBufferDescriptor)); +} + +_MTL_INLINE MTL::IndirectCommandType MTL::IndirectCommandBufferDescriptor::commandTypes() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(commandTypes)); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::inheritBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inheritBuffers)); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::inheritCullMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inheritCullMode)); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::inheritDepthBias() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inheritDepthBias)); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::inheritDepthClipMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inheritDepthClipMode)); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::inheritDepthStencilState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inheritDepthStencilState)); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::inheritFrontFacingWinding() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inheritFrontFacingWinding)); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::inheritPipelineState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inheritPipelineState)); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::inheritTriangleFillMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inheritTriangleFillMode)); +} + +_MTL_INLINE MTL::IndirectCommandBufferDescriptor* MTL::IndirectCommandBufferDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL::IndirectCommandBufferDescriptor::maxFragmentBufferBindCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxFragmentBufferBindCount)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectCommandBufferDescriptor::maxKernelBufferBindCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxKernelBufferBindCount)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectCommandBufferDescriptor::maxKernelThreadgroupMemoryBindCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxKernelThreadgroupMemoryBindCount)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectCommandBufferDescriptor::maxMeshBufferBindCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxMeshBufferBindCount)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectCommandBufferDescriptor::maxObjectBufferBindCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxObjectBufferBindCount)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectCommandBufferDescriptor::maxObjectThreadgroupMemoryBindCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxObjectThreadgroupMemoryBindCount)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectCommandBufferDescriptor::maxVertexBufferBindCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxVertexBufferBindCount)); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setCommandTypes(MTL::IndirectCommandType commandTypes) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCommandTypes_), commandTypes); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setInheritBuffers(bool inheritBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInheritBuffers_), inheritBuffers); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setInheritCullMode(bool inheritCullMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInheritCullMode_), inheritCullMode); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setInheritDepthBias(bool inheritDepthBias) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInheritDepthBias_), inheritDepthBias); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setInheritDepthClipMode(bool inheritDepthClipMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInheritDepthClipMode_), inheritDepthClipMode); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setInheritDepthStencilState(bool inheritDepthStencilState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInheritDepthStencilState_), inheritDepthStencilState); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setInheritFrontFacingWinding(bool inheritFrontFacingWinding) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInheritFrontFacingWinding_), inheritFrontFacingWinding); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setInheritPipelineState(bool inheritPipelineState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInheritPipelineState_), inheritPipelineState); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setInheritTriangleFillMode(bool inheritTriangleFillMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInheritTriangleFillMode_), inheritTriangleFillMode); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setMaxFragmentBufferBindCount(NS::UInteger maxFragmentBufferBindCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxFragmentBufferBindCount_), maxFragmentBufferBindCount); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setMaxKernelBufferBindCount(NS::UInteger maxKernelBufferBindCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxKernelBufferBindCount_), maxKernelBufferBindCount); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setMaxKernelThreadgroupMemoryBindCount(NS::UInteger maxKernelThreadgroupMemoryBindCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxKernelThreadgroupMemoryBindCount_), maxKernelThreadgroupMemoryBindCount); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setMaxMeshBufferBindCount(NS::UInteger maxMeshBufferBindCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxMeshBufferBindCount_), maxMeshBufferBindCount); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setMaxObjectBufferBindCount(NS::UInteger maxObjectBufferBindCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxObjectBufferBindCount_), maxObjectBufferBindCount); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setMaxObjectThreadgroupMemoryBindCount(NS::UInteger maxObjectThreadgroupMemoryBindCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxObjectThreadgroupMemoryBindCount_), maxObjectThreadgroupMemoryBindCount); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setMaxVertexBufferBindCount(NS::UInteger maxVertexBufferBindCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxVertexBufferBindCount_), maxVertexBufferBindCount); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setSupportColorAttachmentMapping(bool supportColorAttachmentMapping) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportColorAttachmentMapping_), supportColorAttachmentMapping); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setSupportDynamicAttributeStride(bool supportDynamicAttributeStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportDynamicAttributeStride_), supportDynamicAttributeStride); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setSupportRayTracing(bool supportRayTracing) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportRayTracing_), supportRayTracing); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::supportColorAttachmentMapping() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportColorAttachmentMapping)); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::supportDynamicAttributeStride() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportDynamicAttributeStride)); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::supportRayTracing() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportRayTracing)); +} + +_MTL_INLINE MTL::ResourceID MTL::IndirectCommandBuffer::gpuResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuResourceID)); +} + +_MTL_INLINE MTL::IndirectComputeCommand* MTL::IndirectCommandBuffer::indirectComputeCommand(NS::UInteger commandIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indirectComputeCommandAtIndex_), commandIndex); +} + +_MTL_INLINE MTL::IndirectRenderCommand* MTL::IndirectCommandBuffer::indirectRenderCommand(NS::UInteger commandIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indirectRenderCommandAtIndex_), commandIndex); +} + +_MTL_INLINE void MTL::IndirectCommandBuffer::reset(NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(resetWithRange_), range); +} + +_MTL_INLINE NS::UInteger MTL::IndirectCommandBuffer::size() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(size)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLIndirectCommandEncoder.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLIndirectCommandEncoder.hpp new file mode 100644 index 00000000..9e1a9400 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLIndirectCommandEncoder.hpp @@ -0,0 +1,272 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLIndirectCommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLArgument.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLRenderCommandEncoder.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class Buffer; +class ComputePipelineState; +class RenderPipelineState; + +class IndirectRenderCommand : public NS::Referencing +{ +public: + void clearBarrier(); + + void drawIndexedPatches(NS::UInteger numberOfPatchControlPoints, NS::UInteger patchStart, NS::UInteger patchCount, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, const MTL::Buffer* controlPointIndexBuffer, NS::UInteger controlPointIndexBufferOffset, NS::UInteger instanceCount, NS::UInteger baseInstance, const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger instanceStride); + + void drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, const MTL::Buffer* indexBuffer, NS::UInteger indexBufferOffset, NS::UInteger instanceCount, NS::Integer baseVertex, NS::UInteger baseInstance); + + void drawMeshThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup); + + void drawMeshThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup); + + void drawPatches(NS::UInteger numberOfPatchControlPoints, NS::UInteger patchStart, NS::UInteger patchCount, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, NS::UInteger instanceCount, NS::UInteger baseInstance, const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger instanceStride); + + void drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount, NS::UInteger instanceCount, NS::UInteger baseInstance); + + void reset(); + + void setBarrier(); + + void setCullMode(MTL::CullMode cullMode); + + void setDepthBias(float depthBias, float slopeScale, float clamp); + + void setDepthClipMode(MTL::DepthClipMode depthClipMode); + + void setDepthStencilState(const MTL::DepthStencilState* depthStencilState); + + void setFragmentBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + + void setFrontFacingWinding(MTL::Winding frontFacingWindning); + + void setMeshBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + + void setObjectBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + + void setObjectThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index); + + void setRenderPipelineState(const MTL::RenderPipelineState* pipelineState); + + void setTriangleFillMode(MTL::TriangleFillMode fillMode); + + void setVertexBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + void setVertexBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger stride, NS::UInteger index); +}; +class IndirectComputeCommand : public NS::Referencing +{ +public: + void clearBarrier(); + + void concurrentDispatchThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerThreadgroup); + + void concurrentDispatchThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerThreadgroup); + + void reset(); + + void setBarrier(); + + void setComputePipelineState(const MTL::ComputePipelineState* pipelineState); + + void setImageblockWidth(NS::UInteger width, NS::UInteger height); + + void setKernelBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + void setKernelBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger stride, NS::UInteger index); + + void setStageInRegion(MTL::Region region); + + void setThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index); +}; + +} +_MTL_INLINE void MTL::IndirectRenderCommand::clearBarrier() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(clearBarrier)); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::drawIndexedPatches(NS::UInteger numberOfPatchControlPoints, NS::UInteger patchStart, NS::UInteger patchCount, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, const MTL::Buffer* controlPointIndexBuffer, NS::UInteger controlPointIndexBufferOffset, NS::UInteger instanceCount, NS::UInteger baseInstance, const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger instanceStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPatches_patchStart_patchCount_patchIndexBuffer_patchIndexBufferOffset_controlPointIndexBuffer_controlPointIndexBufferOffset_instanceCount_baseInstance_tessellationFactorBuffer_tessellationFactorBufferOffset_tessellationFactorBufferInstanceStride_), numberOfPatchControlPoints, patchStart, patchCount, patchIndexBuffer, patchIndexBufferOffset, controlPointIndexBuffer, controlPointIndexBufferOffset, instanceCount, baseInstance, buffer, offset, instanceStride); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, const MTL::Buffer* indexBuffer, NS::UInteger indexBufferOffset, NS::UInteger instanceCount, NS::Integer baseVertex, NS::UInteger baseInstance) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferOffset_instanceCount_baseVertex_baseInstance_), primitiveType, indexCount, indexType, indexBuffer, indexBufferOffset, instanceCount, baseVertex, baseInstance); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::drawMeshThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawMeshThreadgroups_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_), threadgroupsPerGrid, threadsPerObjectThreadgroup, threadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::drawMeshThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawMeshThreads_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_), threadsPerGrid, threadsPerObjectThreadgroup, threadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::drawPatches(NS::UInteger numberOfPatchControlPoints, NS::UInteger patchStart, NS::UInteger patchCount, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, NS::UInteger instanceCount, NS::UInteger baseInstance, const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger instanceStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPatches_patchStart_patchCount_patchIndexBuffer_patchIndexBufferOffset_instanceCount_baseInstance_tessellationFactorBuffer_tessellationFactorBufferOffset_tessellationFactorBufferInstanceStride_), numberOfPatchControlPoints, patchStart, patchCount, patchIndexBuffer, patchIndexBufferOffset, instanceCount, baseInstance, buffer, offset, instanceStride); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount, NS::UInteger instanceCount, NS::UInteger baseInstance) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPrimitives_vertexStart_vertexCount_instanceCount_baseInstance_), primitiveType, vertexStart, vertexCount, instanceCount, baseInstance); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setBarrier() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBarrier)); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setCullMode(MTL::CullMode cullMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCullMode_), cullMode); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setDepthBias(float depthBias, float slopeScale, float clamp) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthBias_slopeScale_clamp_), depthBias, slopeScale, clamp); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setDepthClipMode(MTL::DepthClipMode depthClipMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthClipMode_), depthClipMode); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setDepthStencilState(const MTL::DepthStencilState* depthStencilState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStencilState_), depthStencilState); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setFragmentBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setFrontFacingWinding(MTL::Winding frontFacingWindning) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFrontFacingWinding_), frontFacingWindning); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setMeshBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setObjectBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setObjectThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectThreadgroupMemoryLength_atIndex_), length, index); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setRenderPipelineState(const MTL::RenderPipelineState* pipelineState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRenderPipelineState_), pipelineState); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setTriangleFillMode(MTL::TriangleFillMode fillMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTriangleFillMode_), fillMode); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setVertexBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setVertexBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger stride, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBuffer_offset_attributeStride_atIndex_), buffer, offset, stride, index); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::clearBarrier() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(clearBarrier)); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::concurrentDispatchThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(concurrentDispatchThreadgroups_threadsPerThreadgroup_), threadgroupsPerGrid, threadsPerThreadgroup); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::concurrentDispatchThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(concurrentDispatchThreads_threadsPerThreadgroup_), threadsPerGrid, threadsPerThreadgroup); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::setBarrier() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBarrier)); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::setComputePipelineState(const MTL::ComputePipelineState* pipelineState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setComputePipelineState_), pipelineState); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::setImageblockWidth(NS::UInteger width, NS::UInteger height) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setImageblockWidth_height_), width, height); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::setKernelBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setKernelBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::setKernelBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger stride, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setKernelBuffer_offset_attributeStride_atIndex_), buffer, offset, stride, index); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::setStageInRegion(MTL::Region region) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStageInRegion_), region); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::setThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadgroupMemoryLength_atIndex_), length, index); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLIntersectionFunctionTable.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLIntersectionFunctionTable.hpp new file mode 100644 index 00000000..436653b2 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLIntersectionFunctionTable.hpp @@ -0,0 +1,173 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLIntersectionFunctionTable.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL +{ +class Buffer; +class FunctionHandle; +class IntersectionFunctionTableDescriptor; +class VisibleFunctionTable; + +_MTL_OPTIONS(NS::UInteger, IntersectionFunctionSignature) { + IntersectionFunctionSignatureNone = 0, + IntersectionFunctionSignatureInstancing = 1, + IntersectionFunctionSignatureTriangleData = 1 << 1, + IntersectionFunctionSignatureWorldSpaceData = 1 << 2, + IntersectionFunctionSignatureInstanceMotion = 1 << 3, + IntersectionFunctionSignaturePrimitiveMotion = 1 << 4, + IntersectionFunctionSignatureExtendedLimits = 1 << 5, + IntersectionFunctionSignatureMaxLevels = 1 << 6, + IntersectionFunctionSignatureCurveData = 1 << 7, + IntersectionFunctionSignatureIntersectionFunctionBuffer = 1 << 8, + IntersectionFunctionSignatureUserData = 1 << 9, +}; + +struct IntersectionFunctionBufferArguments +{ + uint64_t intersectionFunctionBuffer; + uint64_t intersectionFunctionBufferSize; + uint64_t intersectionFunctionStride; +} _MTL_PACKED; + +class IntersectionFunctionTableDescriptor : public NS::Copying +{ +public: + static IntersectionFunctionTableDescriptor* alloc(); + + NS::UInteger functionCount() const; + + IntersectionFunctionTableDescriptor* init(); + + static IntersectionFunctionTableDescriptor* intersectionFunctionTableDescriptor(); + + void setFunctionCount(NS::UInteger functionCount); +}; +class IntersectionFunctionTable : public NS::Referencing +{ +public: + ResourceID gpuResourceID() const; + + void setBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + void setBuffers(const MTL::Buffer* const buffers[], const NS::UInteger offsets[], NS::Range range); + + void setFunction(const MTL::FunctionHandle* function, NS::UInteger index); + void setFunctions(const MTL::FunctionHandle* const functions[], NS::Range range); + + void setOpaqueCurveIntersectionFunction(MTL::IntersectionFunctionSignature signature, NS::UInteger index); + void setOpaqueCurveIntersectionFunction(MTL::IntersectionFunctionSignature signature, NS::Range range); + + void setOpaqueTriangleIntersectionFunction(MTL::IntersectionFunctionSignature signature, NS::UInteger index); + void setOpaqueTriangleIntersectionFunction(MTL::IntersectionFunctionSignature signature, NS::Range range); + + void setVisibleFunctionTable(const MTL::VisibleFunctionTable* functionTable, NS::UInteger bufferIndex); + void setVisibleFunctionTables(const MTL::VisibleFunctionTable* const functionTables[], NS::Range bufferRange); +}; + +} + +_MTL_INLINE MTL::IntersectionFunctionTableDescriptor* MTL::IntersectionFunctionTableDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLIntersectionFunctionTableDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::IntersectionFunctionTableDescriptor::functionCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionCount)); +} + +_MTL_INLINE MTL::IntersectionFunctionTableDescriptor* MTL::IntersectionFunctionTableDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::IntersectionFunctionTableDescriptor* MTL::IntersectionFunctionTableDescriptor::intersectionFunctionTableDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLIntersectionFunctionTableDescriptor), _MTL_PRIVATE_SEL(intersectionFunctionTableDescriptor)); +} + +_MTL_INLINE void MTL::IntersectionFunctionTableDescriptor::setFunctionCount(NS::UInteger functionCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctionCount_), functionCount); +} + +_MTL_INLINE MTL::ResourceID MTL::IntersectionFunctionTable::gpuResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuResourceID)); +} + +_MTL_INLINE void MTL::IntersectionFunctionTable::setBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::IntersectionFunctionTable::setBuffers(const MTL::Buffer* const buffers[], const NS::UInteger offsets[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBuffers_offsets_withRange_), buffers, offsets, range); +} + +_MTL_INLINE void MTL::IntersectionFunctionTable::setFunction(const MTL::FunctionHandle* function, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunction_atIndex_), function, index); +} + +_MTL_INLINE void MTL::IntersectionFunctionTable::setFunctions(const MTL::FunctionHandle* const functions[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctions_withRange_), functions, range); +} + +_MTL_INLINE void MTL::IntersectionFunctionTable::setOpaqueCurveIntersectionFunction(MTL::IntersectionFunctionSignature signature, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOpaqueCurveIntersectionFunctionWithSignature_atIndex_), signature, index); +} + +_MTL_INLINE void MTL::IntersectionFunctionTable::setOpaqueCurveIntersectionFunction(MTL::IntersectionFunctionSignature signature, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOpaqueCurveIntersectionFunctionWithSignature_withRange_), signature, range); +} + +_MTL_INLINE void MTL::IntersectionFunctionTable::setOpaqueTriangleIntersectionFunction(MTL::IntersectionFunctionSignature signature, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOpaqueTriangleIntersectionFunctionWithSignature_atIndex_), signature, index); +} + +_MTL_INLINE void MTL::IntersectionFunctionTable::setOpaqueTriangleIntersectionFunction(MTL::IntersectionFunctionSignature signature, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOpaqueTriangleIntersectionFunctionWithSignature_withRange_), signature, range); +} + +_MTL_INLINE void MTL::IntersectionFunctionTable::setVisibleFunctionTable(const MTL::VisibleFunctionTable* functionTable, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibleFunctionTable_atBufferIndex_), functionTable, bufferIndex); +} + +_MTL_INLINE void MTL::IntersectionFunctionTable::setVisibleFunctionTables(const MTL::VisibleFunctionTable* const functionTables[], NS::Range bufferRange) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibleFunctionTables_withBufferRange_), functionTables, bufferRange); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLLibrary.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLLibrary.hpp new file mode 100644 index 00000000..44aa3a7a --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLLibrary.hpp @@ -0,0 +1,786 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLLibrary.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDataType.hpp" +#include "MTLDefines.hpp" +#include "MTLFunctionDescriptor.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class Argument; +class ArgumentEncoder; +class Attribute; +class CompileOptions; +class Device; +class Function; +class FunctionConstant; +class FunctionConstantValues; +class FunctionDescriptor; +class FunctionReflection; +class IntersectionFunctionDescriptor; +class VertexAttribute; +_MTL_ENUM(NS::UInteger, PatchType) { + PatchTypeNone = 0, + PatchTypeTriangle = 1, + PatchTypeQuad = 2, +}; + +_MTL_ENUM(NS::UInteger, FunctionType) { + FunctionTypeVertex = 1, + FunctionTypeFragment = 2, + FunctionTypeKernel = 3, + FunctionTypeVisible = 5, + FunctionTypeIntersection = 6, + FunctionTypeMesh = 7, + FunctionTypeObject = 8, +}; + +_MTL_ENUM(NS::UInteger, LanguageVersion) { + LanguageVersion1_0 = 65536, + LanguageVersion1_1 = 65537, + LanguageVersion1_2 = 65538, + LanguageVersion2_0 = 131072, + LanguageVersion2_1 = 131073, + LanguageVersion2_2 = 131074, + LanguageVersion2_3 = 131075, + LanguageVersion2_4 = 131076, + LanguageVersion3_0 = 196608, + LanguageVersion3_1 = 196609, + LanguageVersion3_2 = 196610, + LanguageVersion4_0 = 262144, +}; + +_MTL_ENUM(NS::Integer, LibraryType) { + LibraryTypeExecutable = 0, + LibraryTypeDynamic = 1, +}; + +_MTL_ENUM(NS::Integer, LibraryOptimizationLevel) { + LibraryOptimizationLevelDefault = 0, + LibraryOptimizationLevelSize = 1, +}; + +_MTL_ENUM(NS::Integer, CompileSymbolVisibility) { + CompileSymbolVisibilityDefault = 0, + CompileSymbolVisibilityHidden = 1, +}; + +_MTL_ENUM(NS::Integer, MathMode) { + MathModeSafe = 0, + MathModeRelaxed = 1, + MathModeFast = 2, +}; + +_MTL_ENUM(NS::Integer, MathFloatingPointFunctions) { + MathFloatingPointFunctionsFast = 0, + MathFloatingPointFunctionsPrecise = 1, +}; + +_MTL_ENUM(NS::UInteger, LibraryError) { + LibraryErrorUnsupported = 1, + LibraryErrorInternal = 2, + LibraryErrorCompileFailure = 3, + LibraryErrorCompileWarning = 4, + LibraryErrorFunctionNotFound = 5, + LibraryErrorFileNotFound = 6, +}; + +using AutoreleasedArgument = MTL::Argument*; +using FunctionCompletionHandlerFunction = std::function; + +class VertexAttribute : public NS::Referencing +{ +public: + [[deprecated("please use isActive instead")]] + bool active() const; + + static VertexAttribute* alloc(); + + NS::UInteger attributeIndex() const; + + DataType attributeType() const; + + VertexAttribute* init(); + + bool isActive() const; + + bool isPatchControlPointData() const; + + bool isPatchData() const; + + NS::String* name() const; + + [[deprecated("please use isPatchControlPointData instead")]] + bool patchControlPointData() const; + + [[deprecated("please use isPatchData instead")]] + bool patchData() const; +}; +class Attribute : public NS::Referencing +{ +public: + [[deprecated("please use isActive instead")]] + bool active() const; + + static Attribute* alloc(); + + NS::UInteger attributeIndex() const; + + DataType attributeType() const; + + Attribute* init(); + + bool isActive() const; + + bool isPatchControlPointData() const; + + bool isPatchData() const; + + NS::String* name() const; + + [[deprecated("please use isPatchControlPointData instead")]] + bool patchControlPointData() const; + + [[deprecated("please use isPatchData instead")]] + bool patchData() const; +}; +class FunctionConstant : public NS::Referencing +{ +public: + static FunctionConstant* alloc(); + + NS::UInteger index() const; + + FunctionConstant* init(); + + NS::String* name() const; + + bool required() const; + + DataType type() const; +}; +class Function : public NS::Referencing +{ +public: + Device* device() const; + + NS::Dictionary* functionConstantsDictionary() const; + + FunctionType functionType() const; + + NS::String* label() const; + + NS::String* name() const; + + ArgumentEncoder* newArgumentEncoder(NS::UInteger bufferIndex); + ArgumentEncoder* newArgumentEncoder(NS::UInteger bufferIndex, const MTL::AutoreleasedArgument* reflection); + + FunctionOptions options() const; + + NS::Integer patchControlPointCount() const; + + PatchType patchType() const; + + void setLabel(const NS::String* label); + + NS::Array* stageInputAttributes() const; + + NS::Array* vertexAttributes() const; +}; +class CompileOptions : public NS::Copying +{ +public: + static CompileOptions* alloc(); + + bool allowReferencingUndefinedSymbols() const; + + CompileSymbolVisibility compileSymbolVisibility() const; + + bool enableLogging() const; + + bool fastMathEnabled() const; + + CompileOptions* init(); + + NS::String* installName() const; + + LanguageVersion languageVersion() const; + + NS::Array* libraries() const; + + LibraryType libraryType() const; + + MathFloatingPointFunctions mathFloatingPointFunctions() const; + + MathMode mathMode() const; + + NS::UInteger maxTotalThreadsPerThreadgroup() const; + + LibraryOptimizationLevel optimizationLevel() const; + + NS::Dictionary* preprocessorMacros() const; + + bool preserveInvariance() const; + + Size requiredThreadsPerThreadgroup() const; + + void setAllowReferencingUndefinedSymbols(bool allowReferencingUndefinedSymbols); + + void setCompileSymbolVisibility(MTL::CompileSymbolVisibility compileSymbolVisibility); + + void setEnableLogging(bool enableLogging); + + void setFastMathEnabled(bool fastMathEnabled); + + void setInstallName(const NS::String* installName); + + void setLanguageVersion(MTL::LanguageVersion languageVersion); + + void setLibraries(const NS::Array* libraries); + + void setLibraryType(MTL::LibraryType libraryType); + + void setMathFloatingPointFunctions(MTL::MathFloatingPointFunctions mathFloatingPointFunctions); + + void setMathMode(MTL::MathMode mathMode); + + void setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup); + + void setOptimizationLevel(MTL::LibraryOptimizationLevel optimizationLevel); + + void setPreprocessorMacros(const NS::Dictionary* preprocessorMacros); + + void setPreserveInvariance(bool preserveInvariance); + + void setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup); +}; +class FunctionReflection : public NS::Referencing +{ +public: + static FunctionReflection* alloc(); + + NS::Array* bindings() const; + + FunctionReflection* init(); +}; +class Library : public NS::Referencing +{ +public: + Device* device() const; + + NS::Array* functionNames() const; + + NS::String* installName() const; + + NS::String* label() const; + + Function* newFunction(const NS::String* functionName); + Function* newFunction(const NS::String* name, const MTL::FunctionConstantValues* constantValues, NS::Error** error); + void newFunction(const NS::String* name, const MTL::FunctionConstantValues* constantValues, void (^completionHandler)(MTL::Function*, NS::Error*)); + void newFunction(const MTL::FunctionDescriptor* descriptor, void (^completionHandler)(MTL::Function*, NS::Error*)); + Function* newFunction(const MTL::FunctionDescriptor* descriptor, NS::Error** error); + void newFunction(const NS::String* pFunctionName, const MTL::FunctionConstantValues* pConstantValues, const MTL::FunctionCompletionHandlerFunction& completionHandler); + void newFunction(const MTL::FunctionDescriptor* pDescriptor, const MTL::FunctionCompletionHandlerFunction& completionHandler); + + void newIntersectionFunction(const MTL::IntersectionFunctionDescriptor* descriptor, void (^completionHandler)(MTL::Function*, NS::Error*)); + Function* newIntersectionFunction(const MTL::IntersectionFunctionDescriptor* descriptor, NS::Error** error); + void newIntersectionFunction(const MTL::IntersectionFunctionDescriptor* pDescriptor, const MTL::FunctionCompletionHandlerFunction& completionHandler); + + FunctionReflection* reflectionForFunction(const NS::String* functionName); + + void setLabel(const NS::String* label); + + LibraryType type() const; +}; + +} +_MTL_INLINE bool MTL::VertexAttribute::active() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isActive)); +} + +_MTL_INLINE MTL::VertexAttribute* MTL::VertexAttribute::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLVertexAttribute)); +} + +_MTL_INLINE NS::UInteger MTL::VertexAttribute::attributeIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(attributeIndex)); +} + +_MTL_INLINE MTL::DataType MTL::VertexAttribute::attributeType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(attributeType)); +} + +_MTL_INLINE MTL::VertexAttribute* MTL::VertexAttribute::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE bool MTL::VertexAttribute::isActive() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isActive)); +} + +_MTL_INLINE bool MTL::VertexAttribute::isPatchControlPointData() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isPatchControlPointData)); +} + +_MTL_INLINE bool MTL::VertexAttribute::isPatchData() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isPatchData)); +} + +_MTL_INLINE NS::String* MTL::VertexAttribute::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE bool MTL::VertexAttribute::patchControlPointData() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isPatchControlPointData)); +} + +_MTL_INLINE bool MTL::VertexAttribute::patchData() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isPatchData)); +} + +_MTL_INLINE bool MTL::Attribute::active() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isActive)); +} + +_MTL_INLINE MTL::Attribute* MTL::Attribute::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAttribute)); +} + +_MTL_INLINE NS::UInteger MTL::Attribute::attributeIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(attributeIndex)); +} + +_MTL_INLINE MTL::DataType MTL::Attribute::attributeType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(attributeType)); +} + +_MTL_INLINE MTL::Attribute* MTL::Attribute::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE bool MTL::Attribute::isActive() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isActive)); +} + +_MTL_INLINE bool MTL::Attribute::isPatchControlPointData() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isPatchControlPointData)); +} + +_MTL_INLINE bool MTL::Attribute::isPatchData() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isPatchData)); +} + +_MTL_INLINE NS::String* MTL::Attribute::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE bool MTL::Attribute::patchControlPointData() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isPatchControlPointData)); +} + +_MTL_INLINE bool MTL::Attribute::patchData() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isPatchData)); +} + +_MTL_INLINE MTL::FunctionConstant* MTL::FunctionConstant::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLFunctionConstant)); +} + +_MTL_INLINE NS::UInteger MTL::FunctionConstant::index() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(index)); +} + +_MTL_INLINE MTL::FunctionConstant* MTL::FunctionConstant::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::FunctionConstant::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE bool MTL::FunctionConstant::required() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(required)); +} + +_MTL_INLINE MTL::DataType MTL::FunctionConstant::type() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(type)); +} + +_MTL_INLINE MTL::Device* MTL::Function::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::Dictionary* MTL::Function::functionConstantsDictionary() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionConstantsDictionary)); +} + +_MTL_INLINE MTL::FunctionType MTL::Function::functionType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionType)); +} + +_MTL_INLINE NS::String* MTL::Function::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::String* MTL::Function::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE MTL::ArgumentEncoder* MTL::Function::newArgumentEncoder(NS::UInteger bufferIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newArgumentEncoderWithBufferIndex_), bufferIndex); +} + +_MTL_INLINE MTL::ArgumentEncoder* MTL::Function::newArgumentEncoder(NS::UInteger bufferIndex, const MTL::AutoreleasedArgument* reflection) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newArgumentEncoderWithBufferIndex_reflection_), bufferIndex, reflection); +} + +_MTL_INLINE MTL::FunctionOptions MTL::Function::options() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(options)); +} + +_MTL_INLINE NS::Integer MTL::Function::patchControlPointCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(patchControlPointCount)); +} + +_MTL_INLINE MTL::PatchType MTL::Function::patchType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(patchType)); +} + +_MTL_INLINE void MTL::Function::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE NS::Array* MTL::Function::stageInputAttributes() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stageInputAttributes)); +} + +_MTL_INLINE NS::Array* MTL::Function::vertexAttributes() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexAttributes)); +} + +_MTL_INLINE MTL::CompileOptions* MTL::CompileOptions::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLCompileOptions)); +} + +_MTL_INLINE bool MTL::CompileOptions::allowReferencingUndefinedSymbols() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allowReferencingUndefinedSymbols)); +} + +_MTL_INLINE MTL::CompileSymbolVisibility MTL::CompileOptions::compileSymbolVisibility() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(compileSymbolVisibility)); +} + +_MTL_INLINE bool MTL::CompileOptions::enableLogging() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(enableLogging)); +} + +_MTL_INLINE bool MTL::CompileOptions::fastMathEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fastMathEnabled)); +} + +_MTL_INLINE MTL::CompileOptions* MTL::CompileOptions::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::CompileOptions::installName() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(installName)); +} + +_MTL_INLINE MTL::LanguageVersion MTL::CompileOptions::languageVersion() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(languageVersion)); +} + +_MTL_INLINE NS::Array* MTL::CompileOptions::libraries() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(libraries)); +} + +_MTL_INLINE MTL::LibraryType MTL::CompileOptions::libraryType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(libraryType)); +} + +_MTL_INLINE MTL::MathFloatingPointFunctions MTL::CompileOptions::mathFloatingPointFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(mathFloatingPointFunctions)); +} + +_MTL_INLINE MTL::MathMode MTL::CompileOptions::mathMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(mathMode)); +} + +_MTL_INLINE NS::UInteger MTL::CompileOptions::maxTotalThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerThreadgroup)); +} + +_MTL_INLINE MTL::LibraryOptimizationLevel MTL::CompileOptions::optimizationLevel() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizationLevel)); +} + +_MTL_INLINE NS::Dictionary* MTL::CompileOptions::preprocessorMacros() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(preprocessorMacros)); +} + +_MTL_INLINE bool MTL::CompileOptions::preserveInvariance() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(preserveInvariance)); +} + +_MTL_INLINE MTL::Size MTL::CompileOptions::requiredThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerThreadgroup)); +} + +_MTL_INLINE void MTL::CompileOptions::setAllowReferencingUndefinedSymbols(bool allowReferencingUndefinedSymbols) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAllowReferencingUndefinedSymbols_), allowReferencingUndefinedSymbols); +} + +_MTL_INLINE void MTL::CompileOptions::setCompileSymbolVisibility(MTL::CompileSymbolVisibility compileSymbolVisibility) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCompileSymbolVisibility_), compileSymbolVisibility); +} + +_MTL_INLINE void MTL::CompileOptions::setEnableLogging(bool enableLogging) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setEnableLogging_), enableLogging); +} + +_MTL_INLINE void MTL::CompileOptions::setFastMathEnabled(bool fastMathEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFastMathEnabled_), fastMathEnabled); +} + +_MTL_INLINE void MTL::CompileOptions::setInstallName(const NS::String* installName) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstallName_), installName); +} + +_MTL_INLINE void MTL::CompileOptions::setLanguageVersion(MTL::LanguageVersion languageVersion) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLanguageVersion_), languageVersion); +} + +_MTL_INLINE void MTL::CompileOptions::setLibraries(const NS::Array* libraries) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLibraries_), libraries); +} + +_MTL_INLINE void MTL::CompileOptions::setLibraryType(MTL::LibraryType libraryType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLibraryType_), libraryType); +} + +_MTL_INLINE void MTL::CompileOptions::setMathFloatingPointFunctions(MTL::MathFloatingPointFunctions mathFloatingPointFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMathFloatingPointFunctions_), mathFloatingPointFunctions); +} + +_MTL_INLINE void MTL::CompileOptions::setMathMode(MTL::MathMode mathMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMathMode_), mathMode); +} + +_MTL_INLINE void MTL::CompileOptions::setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadsPerThreadgroup_), maxTotalThreadsPerThreadgroup); +} + +_MTL_INLINE void MTL::CompileOptions::setOptimizationLevel(MTL::LibraryOptimizationLevel optimizationLevel) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOptimizationLevel_), optimizationLevel); +} + +_MTL_INLINE void MTL::CompileOptions::setPreprocessorMacros(const NS::Dictionary* preprocessorMacros) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPreprocessorMacros_), preprocessorMacros); +} + +_MTL_INLINE void MTL::CompileOptions::setPreserveInvariance(bool preserveInvariance) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPreserveInvariance_), preserveInvariance); +} + +_MTL_INLINE void MTL::CompileOptions::setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRequiredThreadsPerThreadgroup_), requiredThreadsPerThreadgroup); +} + +_MTL_INLINE MTL::FunctionReflection* MTL::FunctionReflection::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLFunctionReflection)); +} + +_MTL_INLINE NS::Array* MTL::FunctionReflection::bindings() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bindings)); +} + +_MTL_INLINE MTL::FunctionReflection* MTL::FunctionReflection::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::Device* MTL::Library::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::Array* MTL::Library::functionNames() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionNames)); +} + +_MTL_INLINE NS::String* MTL::Library::installName() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(installName)); +} + +_MTL_INLINE NS::String* MTL::Library::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL::Function* MTL::Library::newFunction(const NS::String* functionName) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newFunctionWithName_), functionName); +} + +_MTL_INLINE MTL::Function* MTL::Library::newFunction(const NS::String* name, const MTL::FunctionConstantValues* constantValues, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newFunctionWithName_constantValues_error_), name, constantValues, error); +} + +_MTL_INLINE void MTL::Library::newFunction(const NS::String* name, const MTL::FunctionConstantValues* constantValues, void (^completionHandler)(MTL::Function*, NS::Error*)) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newFunctionWithName_constantValues_completionHandler_), name, constantValues, completionHandler); +} + +_MTL_INLINE void MTL::Library::newFunction(const MTL::FunctionDescriptor* descriptor, void (^completionHandler)(MTL::Function*, NS::Error*)) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newFunctionWithDescriptor_completionHandler_), descriptor, completionHandler); +} + +_MTL_INLINE MTL::Function* MTL::Library::newFunction(const MTL::FunctionDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newFunctionWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE void MTL::Library::newFunction(const NS::String* pFunctionName, const MTL::FunctionConstantValues* pConstantValues, const MTL::FunctionCompletionHandlerFunction& completionHandler) +{ + __block MTL::FunctionCompletionHandlerFunction blockCompletionHandler = completionHandler; + newFunction(pFunctionName, pConstantValues, ^(MTL::Function* pFunction, NS::Error* pError) { blockCompletionHandler(pFunction, pError); }); +} + +_MTL_INLINE void MTL::Library::newFunction(const MTL::FunctionDescriptor* pDescriptor, const MTL::FunctionCompletionHandlerFunction& completionHandler) +{ + __block MTL::FunctionCompletionHandlerFunction blockCompletionHandler = completionHandler; + newFunction(pDescriptor, ^(MTL::Function* pFunction, NS::Error* pError) { blockCompletionHandler(pFunction, pError); }); +} + +_MTL_INLINE void MTL::Library::newIntersectionFunction(const MTL::IntersectionFunctionDescriptor* descriptor, void (^completionHandler)(MTL::Function*, NS::Error*)) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newIntersectionFunctionWithDescriptor_completionHandler_), descriptor, completionHandler); +} + +_MTL_INLINE MTL::Function* MTL::Library::newIntersectionFunction(const MTL::IntersectionFunctionDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newIntersectionFunctionWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE void MTL::Library::newIntersectionFunction(const MTL::IntersectionFunctionDescriptor* pDescriptor, const MTL::FunctionCompletionHandlerFunction& completionHandler) +{ + __block MTL::FunctionCompletionHandlerFunction blockCompletionHandler = completionHandler; + newIntersectionFunction(pDescriptor, ^(MTL::Function* pFunction, NS::Error* pError) { blockCompletionHandler(pFunction, pError); }); +} + +_MTL_INLINE MTL::FunctionReflection* MTL::Library::reflectionForFunction(const NS::String* functionName) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(reflectionForFunctionWithName_), functionName); +} + +_MTL_INLINE void MTL::Library::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE MTL::LibraryType MTL::Library::type() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(type)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLLinkedFunctions.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLLinkedFunctions.hpp new file mode 100644 index 00000000..4b1bd953 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLLinkedFunctions.hpp @@ -0,0 +1,110 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLLinkedFunctions.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ + +class LinkedFunctions : public NS::Copying +{ +public: + static LinkedFunctions* alloc(); + + NS::Array* binaryFunctions() const; + NS::Array* functions() const; + + NS::Dictionary* groups() const; + + LinkedFunctions* init(); + + static LinkedFunctions* linkedFunctions(); + + NS::Array* privateFunctions() const; + + void setBinaryFunctions(const NS::Array* binaryFunctions); + + void setFunctions(const NS::Array* functions); + + void setGroups(const NS::Dictionary* groups); + + void setPrivateFunctions(const NS::Array* privateFunctions); +}; + +} +_MTL_INLINE MTL::LinkedFunctions* MTL::LinkedFunctions::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLLinkedFunctions)); +} + +_MTL_INLINE NS::Array* MTL::LinkedFunctions::binaryFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(binaryFunctions)); +} + +_MTL_INLINE NS::Array* MTL::LinkedFunctions::functions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functions)); +} + +_MTL_INLINE NS::Dictionary* MTL::LinkedFunctions::groups() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(groups)); +} + +_MTL_INLINE MTL::LinkedFunctions* MTL::LinkedFunctions::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::LinkedFunctions* MTL::LinkedFunctions::linkedFunctions() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLLinkedFunctions), _MTL_PRIVATE_SEL(linkedFunctions)); +} + +_MTL_INLINE NS::Array* MTL::LinkedFunctions::privateFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(privateFunctions)); +} + +_MTL_INLINE void MTL::LinkedFunctions::setBinaryFunctions(const NS::Array* binaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBinaryFunctions_), binaryFunctions); +} + +_MTL_INLINE void MTL::LinkedFunctions::setFunctions(const NS::Array* functions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctions_), functions); +} + +_MTL_INLINE void MTL::LinkedFunctions::setGroups(const NS::Dictionary* groups) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setGroups_), groups); +} + +_MTL_INLINE void MTL::LinkedFunctions::setPrivateFunctions(const NS::Array* privateFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPrivateFunctions_), privateFunctions); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLLogState.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLLogState.hpp new file mode 100644 index 00000000..b802adf3 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLLogState.hpp @@ -0,0 +1,111 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLLogState.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class LogStateDescriptor; +_MTL_ENUM(NS::Integer, LogLevel) { + LogLevelUndefined = 0, + LogLevelDebug = 1, + LogLevelInfo = 2, + LogLevelNotice = 3, + LogLevelError = 4, + LogLevelFault = 5, +}; + +_MTL_ENUM(NS::UInteger, LogStateError) { + LogStateErrorInvalidSize = 1, + LogStateErrorInvalid = 2, +}; + +using LogHandlerFunction = std::function; + +_MTL_CONST(NS::ErrorDomain, LogStateErrorDomain); +class LogState : public NS::Referencing +{ +public: + void addLogHandler(void (^block)(NS::String*, NS::String*, MTL::LogLevel, NS::String*)); + void addLogHandler(const MTL::LogHandlerFunction& handler); +}; +class LogStateDescriptor : public NS::Copying +{ +public: + static LogStateDescriptor* alloc(); + + NS::Integer bufferSize() const; + + LogStateDescriptor* init(); + + LogLevel level() const; + + void setBufferSize(NS::Integer bufferSize); + + void setLevel(MTL::LogLevel level); +}; + +} +_MTL_PRIVATE_DEF_CONST(NS::ErrorDomain, LogStateErrorDomain); +_MTL_INLINE void MTL::LogState::addLogHandler(void (^block)(NS::String*, NS::String*, MTL::LogLevel, NS::String*)) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addLogHandler_), block); +} + +_MTL_INLINE void MTL::LogState::addLogHandler(const MTL::LogHandlerFunction& handler) +{ + __block LogHandlerFunction function = handler; + addLogHandler(^void(NS::String* subsystem, NS::String* category, MTL::LogLevel logLevel, NS::String* message) { function(subsystem, category, logLevel, message); }); +} + +_MTL_INLINE MTL::LogStateDescriptor* MTL::LogStateDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLLogStateDescriptor)); +} + +_MTL_INLINE NS::Integer MTL::LogStateDescriptor::bufferSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferSize)); +} + +_MTL_INLINE MTL::LogStateDescriptor* MTL::LogStateDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::LogLevel MTL::LogStateDescriptor::level() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(level)); +} + +_MTL_INLINE void MTL::LogStateDescriptor::setBufferSize(NS::Integer bufferSize) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBufferSize_), bufferSize); +} + +_MTL_INLINE void MTL::LogStateDescriptor::setLevel(MTL::LogLevel level) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLevel_), level); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLParallelRenderCommandEncoder.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLParallelRenderCommandEncoder.hpp new file mode 100644 index 00000000..8c345126 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLParallelRenderCommandEncoder.hpp @@ -0,0 +1,83 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLParallelRenderCommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLCommandEncoder.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLRenderPass.hpp" + +namespace MTL +{ +class RenderCommandEncoder; + +class ParallelRenderCommandEncoder : public NS::Referencing +{ +public: + RenderCommandEncoder* renderCommandEncoder(); + + void setColorStoreAction(MTL::StoreAction storeAction, NS::UInteger colorAttachmentIndex); + void setColorStoreActionOptions(MTL::StoreActionOptions storeActionOptions, NS::UInteger colorAttachmentIndex); + + void setDepthStoreAction(MTL::StoreAction storeAction); + void setDepthStoreActionOptions(MTL::StoreActionOptions storeActionOptions); + + void setStencilStoreAction(MTL::StoreAction storeAction); + void setStencilStoreActionOptions(MTL::StoreActionOptions storeActionOptions); +}; + +} +_MTL_INLINE MTL::RenderCommandEncoder* MTL::ParallelRenderCommandEncoder::renderCommandEncoder() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(renderCommandEncoder)); +} + +_MTL_INLINE void MTL::ParallelRenderCommandEncoder::setColorStoreAction(MTL::StoreAction storeAction, NS::UInteger colorAttachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setColorStoreAction_atIndex_), storeAction, colorAttachmentIndex); +} + +_MTL_INLINE void MTL::ParallelRenderCommandEncoder::setColorStoreActionOptions(MTL::StoreActionOptions storeActionOptions, NS::UInteger colorAttachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setColorStoreActionOptions_atIndex_), storeActionOptions, colorAttachmentIndex); +} + +_MTL_INLINE void MTL::ParallelRenderCommandEncoder::setDepthStoreAction(MTL::StoreAction storeAction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStoreAction_), storeAction); +} + +_MTL_INLINE void MTL::ParallelRenderCommandEncoder::setDepthStoreActionOptions(MTL::StoreActionOptions storeActionOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStoreActionOptions_), storeActionOptions); +} + +_MTL_INLINE void MTL::ParallelRenderCommandEncoder::setStencilStoreAction(MTL::StoreAction storeAction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilStoreAction_), storeAction); +} + +_MTL_INLINE void MTL::ParallelRenderCommandEncoder::setStencilStoreActionOptions(MTL::StoreActionOptions storeActionOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilStoreActionOptions_), storeActionOptions); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLPipeline.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLPipeline.hpp new file mode 100644 index 00000000..930bb7eb --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLPipeline.hpp @@ -0,0 +1,104 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLPipeline.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class PipelineBufferDescriptor; +class PipelineBufferDescriptorArray; +_MTL_ENUM(NS::UInteger, Mutability) { + MutabilityDefault = 0, + MutabilityMutable = 1, + MutabilityImmutable = 2, +}; + +_MTL_ENUM(NS::Integer, ShaderValidation) { + ShaderValidationDefault = 0, + ShaderValidationEnabled = 1, + ShaderValidationDisabled = 2, +}; + +class PipelineBufferDescriptor : public NS::Copying +{ +public: + static PipelineBufferDescriptor* alloc(); + + PipelineBufferDescriptor* init(); + + Mutability mutability() const; + void setMutability(MTL::Mutability mutability); +}; +class PipelineBufferDescriptorArray : public NS::Referencing +{ +public: + static PipelineBufferDescriptorArray* alloc(); + + PipelineBufferDescriptorArray* init(); + + PipelineBufferDescriptor* object(NS::UInteger bufferIndex); + void setObject(const MTL::PipelineBufferDescriptor* buffer, NS::UInteger bufferIndex); +}; + +} +_MTL_INLINE MTL::PipelineBufferDescriptor* MTL::PipelineBufferDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLPipelineBufferDescriptor)); +} + +_MTL_INLINE MTL::PipelineBufferDescriptor* MTL::PipelineBufferDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::Mutability MTL::PipelineBufferDescriptor::mutability() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(mutability)); +} + +_MTL_INLINE void MTL::PipelineBufferDescriptor::setMutability(MTL::Mutability mutability) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMutability_), mutability); +} + +_MTL_INLINE MTL::PipelineBufferDescriptorArray* MTL::PipelineBufferDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLPipelineBufferDescriptorArray)); +} + +_MTL_INLINE MTL::PipelineBufferDescriptorArray* MTL::PipelineBufferDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::PipelineBufferDescriptor* MTL::PipelineBufferDescriptorArray::object(NS::UInteger bufferIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), bufferIndex); +} + +_MTL_INLINE void MTL::PipelineBufferDescriptorArray::setObject(const MTL::PipelineBufferDescriptor* buffer, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), buffer, bufferIndex); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLPixelFormat.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLPixelFormat.hpp new file mode 100644 index 00000000..6d5d886c --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLPixelFormat.hpp @@ -0,0 +1,173 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLPixelFormat.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +_MTL_ENUM(NS::UInteger, PixelFormat) { + PixelFormatInvalid = 0, + PixelFormatA8Unorm = 1, + PixelFormatR8Unorm = 10, + PixelFormatR8Unorm_sRGB = 11, + PixelFormatR8Snorm = 12, + PixelFormatR8Uint = 13, + PixelFormatR8Sint = 14, + PixelFormatR16Unorm = 20, + PixelFormatR16Snorm = 22, + PixelFormatR16Uint = 23, + PixelFormatR16Sint = 24, + PixelFormatR16Float = 25, + PixelFormatRG8Unorm = 30, + PixelFormatRG8Unorm_sRGB = 31, + PixelFormatRG8Snorm = 32, + PixelFormatRG8Uint = 33, + PixelFormatRG8Sint = 34, + PixelFormatB5G6R5Unorm = 40, + PixelFormatA1BGR5Unorm = 41, + PixelFormatABGR4Unorm = 42, + PixelFormatBGR5A1Unorm = 43, + PixelFormatR32Uint = 53, + PixelFormatR32Sint = 54, + PixelFormatR32Float = 55, + PixelFormatRG16Unorm = 60, + PixelFormatRG16Snorm = 62, + PixelFormatRG16Uint = 63, + PixelFormatRG16Sint = 64, + PixelFormatRG16Float = 65, + PixelFormatRGBA8Unorm = 70, + PixelFormatRGBA8Unorm_sRGB = 71, + PixelFormatRGBA8Snorm = 72, + PixelFormatRGBA8Uint = 73, + PixelFormatRGBA8Sint = 74, + PixelFormatBGRA8Unorm = 80, + PixelFormatBGRA8Unorm_sRGB = 81, + PixelFormatRGB10A2Unorm = 90, + PixelFormatRGB10A2Uint = 91, + PixelFormatRG11B10Float = 92, + PixelFormatRGB9E5Float = 93, + PixelFormatBGR10A2Unorm = 94, + PixelFormatBGR10_XR = 554, + PixelFormatBGR10_XR_sRGB = 555, + PixelFormatRG32Uint = 103, + PixelFormatRG32Sint = 104, + PixelFormatRG32Float = 105, + PixelFormatRGBA16Unorm = 110, + PixelFormatRGBA16Snorm = 112, + PixelFormatRGBA16Uint = 113, + PixelFormatRGBA16Sint = 114, + PixelFormatRGBA16Float = 115, + PixelFormatBGRA10_XR = 552, + PixelFormatBGRA10_XR_sRGB = 553, + PixelFormatRGBA32Uint = 123, + PixelFormatRGBA32Sint = 124, + PixelFormatRGBA32Float = 125, + PixelFormatBC1_RGBA = 130, + PixelFormatBC1_RGBA_sRGB = 131, + PixelFormatBC2_RGBA = 132, + PixelFormatBC2_RGBA_sRGB = 133, + PixelFormatBC3_RGBA = 134, + PixelFormatBC3_RGBA_sRGB = 135, + PixelFormatBC4_RUnorm = 140, + PixelFormatBC4_RSnorm = 141, + PixelFormatBC5_RGUnorm = 142, + PixelFormatBC5_RGSnorm = 143, + PixelFormatBC6H_RGBFloat = 150, + PixelFormatBC6H_RGBUfloat = 151, + PixelFormatBC7_RGBAUnorm = 152, + PixelFormatBC7_RGBAUnorm_sRGB = 153, + PixelFormatPVRTC_RGB_2BPP = 160, + PixelFormatPVRTC_RGB_2BPP_sRGB = 161, + PixelFormatPVRTC_RGB_4BPP = 162, + PixelFormatPVRTC_RGB_4BPP_sRGB = 163, + PixelFormatPVRTC_RGBA_2BPP = 164, + PixelFormatPVRTC_RGBA_2BPP_sRGB = 165, + PixelFormatPVRTC_RGBA_4BPP = 166, + PixelFormatPVRTC_RGBA_4BPP_sRGB = 167, + PixelFormatEAC_R11Unorm = 170, + PixelFormatEAC_R11Snorm = 172, + PixelFormatEAC_RG11Unorm = 174, + PixelFormatEAC_RG11Snorm = 176, + PixelFormatEAC_RGBA8 = 178, + PixelFormatEAC_RGBA8_sRGB = 179, + PixelFormatETC2_RGB8 = 180, + PixelFormatETC2_RGB8_sRGB = 181, + PixelFormatETC2_RGB8A1 = 182, + PixelFormatETC2_RGB8A1_sRGB = 183, + PixelFormatASTC_4x4_sRGB = 186, + PixelFormatASTC_5x4_sRGB = 187, + PixelFormatASTC_5x5_sRGB = 188, + PixelFormatASTC_6x5_sRGB = 189, + PixelFormatASTC_6x6_sRGB = 190, + PixelFormatASTC_8x5_sRGB = 192, + PixelFormatASTC_8x6_sRGB = 193, + PixelFormatASTC_8x8_sRGB = 194, + PixelFormatASTC_10x5_sRGB = 195, + PixelFormatASTC_10x6_sRGB = 196, + PixelFormatASTC_10x8_sRGB = 197, + PixelFormatASTC_10x10_sRGB = 198, + PixelFormatASTC_12x10_sRGB = 199, + PixelFormatASTC_12x12_sRGB = 200, + PixelFormatASTC_4x4_LDR = 204, + PixelFormatASTC_5x4_LDR = 205, + PixelFormatASTC_5x5_LDR = 206, + PixelFormatASTC_6x5_LDR = 207, + PixelFormatASTC_6x6_LDR = 208, + PixelFormatASTC_8x5_LDR = 210, + PixelFormatASTC_8x6_LDR = 211, + PixelFormatASTC_8x8_LDR = 212, + PixelFormatASTC_10x5_LDR = 213, + PixelFormatASTC_10x6_LDR = 214, + PixelFormatASTC_10x8_LDR = 215, + PixelFormatASTC_10x10_LDR = 216, + PixelFormatASTC_12x10_LDR = 217, + PixelFormatASTC_12x12_LDR = 218, + PixelFormatASTC_4x4_HDR = 222, + PixelFormatASTC_5x4_HDR = 223, + PixelFormatASTC_5x5_HDR = 224, + PixelFormatASTC_6x5_HDR = 225, + PixelFormatASTC_6x6_HDR = 226, + PixelFormatASTC_8x5_HDR = 228, + PixelFormatASTC_8x6_HDR = 229, + PixelFormatASTC_8x8_HDR = 230, + PixelFormatASTC_10x5_HDR = 231, + PixelFormatASTC_10x6_HDR = 232, + PixelFormatASTC_10x8_HDR = 233, + PixelFormatASTC_10x10_HDR = 234, + PixelFormatASTC_12x10_HDR = 235, + PixelFormatASTC_12x12_HDR = 236, + PixelFormatGBGR422 = 240, + PixelFormatBGRG422 = 241, + PixelFormatDepth16Unorm = 250, + PixelFormatDepth32Float = 252, + PixelFormatStencil8 = 253, + PixelFormatDepth24Unorm_Stencil8 = 255, + PixelFormatDepth32Float_Stencil8 = 260, + PixelFormatX32_Stencil8 = 261, + PixelFormatX24_Stencil8 = 262, + PixelFormatUnspecialized = 263, +}; + +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLPrivate.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLPrivate.hpp new file mode 100644 index 00000000..41bcaa50 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLPrivate.hpp @@ -0,0 +1,156 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLPrivate.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLDefines.hpp" + +#include + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#define _MTL_PRIVATE_CLS(symbol) (MTL::Private::Class::s_k##symbol) +#define _MTL_PRIVATE_SEL(accessor) (MTL::Private::Selector::s_k##accessor) + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#if defined(MTL_PRIVATE_IMPLEMENTATION) + +#ifdef METALCPP_SYMBOL_VISIBILITY_HIDDEN +#define _MTL_PRIVATE_VISIBILITY __attribute__((visibility("hidden"))) +#else +#define _MTL_PRIVATE_VISIBILITY __attribute__((visibility("default"))) +#endif // METALCPP_SYMBOL_VISIBILITY_HIDDEN + +#define _MTL_PRIVATE_IMPORT __attribute__((weak_import)) + +#ifdef __OBJC__ +#define _MTL_PRIVATE_OBJC_LOOKUP_CLASS(symbol) ((__bridge void*)objc_lookUpClass(#symbol)) +#define _MTL_PRIVATE_OBJC_GET_PROTOCOL(symbol) ((__bridge void*)objc_getProtocol(#symbol)) +#else +#define _MTL_PRIVATE_OBJC_LOOKUP_CLASS(symbol) objc_lookUpClass(#symbol) +#define _MTL_PRIVATE_OBJC_GET_PROTOCOL(symbol) objc_getProtocol(#symbol) +#endif // __OBJC__ + +#define _MTL_PRIVATE_DEF_CLS(symbol) void* s_k##symbol _MTL_PRIVATE_VISIBILITY = _MTL_PRIVATE_OBJC_LOOKUP_CLASS(symbol) +#define _MTL_PRIVATE_DEF_PRO(symbol) void* s_k##symbol _MTL_PRIVATE_VISIBILITY = _MTL_PRIVATE_OBJC_GET_PROTOCOL(symbol) +#define _MTL_PRIVATE_DEF_SEL(accessor, symbol) SEL s_k##accessor _MTL_PRIVATE_VISIBILITY = sel_registerName(symbol) + +#include +#define MTL_DEF_FUNC( name, signature ) \ + using Fn##name = signature; \ + Fn##name name = reinterpret_cast< Fn##name >( dlsym( RTLD_DEFAULT, #name ) ) + +namespace MTL::Private +{ + template + inline _Type const LoadSymbol(const char* pSymbol) + { + const _Type* pAddress = static_cast<_Type*>(dlsym(RTLD_DEFAULT, pSymbol)); + + return pAddress ? *pAddress : nullptr; + } +} // MTL::Private + +#if defined(__MAC_26_0) || defined(__IPHONE_26_0) || defined(__TVOS_26_0) + +#define _MTL_PRIVATE_DEF_STR(type, symbol) \ + _MTL_EXTERN type const MTL##symbol _MTL_PRIVATE_IMPORT; \ + type const MTL::symbol = (nullptr != &MTL##symbol) ? MTL##symbol : nullptr + +#define _MTL_PRIVATE_DEF_CONST(type, symbol) \ + _MTL_EXTERN type const MTL##symbol _MTL_PRIVATE_IMPORT; \ + type const MTL::symbol = (nullptr != &MTL##symbol) ? MTL##symbol : nullptr + +#define _MTL_PRIVATE_DEF_WEAK_CONST(type, symbol) \ + _MTL_EXTERN type const MTL##symbol; \ + type const MTL::symbol = MTL::Private::LoadSymbol("MTL" #symbol) + +#else + +#define _MTL_PRIVATE_DEF_STR(type, symbol) \ + _MTL_EXTERN type const MTL##symbol; \ + type const MTL::symbol = MTL::Private::LoadSymbol("MTL" #symbol) + +#define _MTL_PRIVATE_DEF_CONST(type, symbol) \ + _MTL_EXTERN type const MTL##symbol; \ + type const MTL::symbol = MTL::Private::LoadSymbol("MTL" #symbol) + +#define _MTL_PRIVATE_DEF_WEAK_CONST(type, symbol) _MTL_PRIVATE_DEF_CONST(type, symbol) + +#endif + +#else + +#define _MTL_PRIVATE_DEF_CLS(symbol) extern void* s_k##symbol +#define _MTL_PRIVATE_DEF_PRO(symbol) extern void* s_k##symbol +#define _MTL_PRIVATE_DEF_SEL(accessor, symbol) extern SEL s_k##accessor +#define _MTL_PRIVATE_DEF_STR(type, symbol) extern type const MTL::symbol +#define _MTL_PRIVATE_DEF_CONST(type, symbol) extern type const MTL::symbol +#define _MTL_PRIVATE_DEF_WEAK_CONST(type, symbol) extern type const MTL::symbol + +#endif // MTL_PRIVATE_IMPLEMENTATION + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL +{ +namespace Private +{ + namespace Class + { + + } // Class +} // Private +} // MTL + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL +{ +namespace Private +{ + namespace Protocol + { + + } // Protocol +} // Private +} // MTL + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL +{ +namespace Private +{ + namespace Selector + { + + _MTL_PRIVATE_DEF_SEL(beginScope, + "beginScope"); + _MTL_PRIVATE_DEF_SEL(endScope, + "endScope"); + } // Class +} // Private +} // MTL + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLRasterizationRate.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLRasterizationRate.hpp new file mode 100644 index 00000000..b2804fa8 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLRasterizationRate.hpp @@ -0,0 +1,337 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLRasterizationRate.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLDevice.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class Buffer; +class Device; +class RasterizationRateLayerArray; +class RasterizationRateLayerDescriptor; +class RasterizationRateMapDescriptor; +class RasterizationRateSampleArray; + +class RasterizationRateSampleArray : public NS::Referencing +{ +public: + static RasterizationRateSampleArray* alloc(); + + RasterizationRateSampleArray* init(); + + NS::Number* object(NS::UInteger index); + void setObject(const NS::Number* value, NS::UInteger index); +}; +class RasterizationRateLayerDescriptor : public NS::Copying +{ +public: + static RasterizationRateLayerDescriptor* alloc(); + + RasterizationRateSampleArray* horizontal() const; + float* horizontalSampleStorage() const; + + RasterizationRateLayerDescriptor* init(); + RasterizationRateLayerDescriptor* init(MTL::Size sampleCount); + RasterizationRateLayerDescriptor* init(MTL::Size sampleCount, const float* horizontal, const float* vertical); + + Size maxSampleCount() const; + Size sampleCount() const; + void setSampleCount(MTL::Size sampleCount); + + RasterizationRateSampleArray* vertical() const; + float* verticalSampleStorage() const; +}; +class RasterizationRateLayerArray : public NS::Referencing +{ +public: + static RasterizationRateLayerArray* alloc(); + + RasterizationRateLayerArray* init(); + + RasterizationRateLayerDescriptor* object(NS::UInteger layerIndex); + void setObject(const MTL::RasterizationRateLayerDescriptor* layer, NS::UInteger layerIndex); +}; +class RasterizationRateMapDescriptor : public NS::Copying +{ +public: + static RasterizationRateMapDescriptor* alloc(); + + RasterizationRateMapDescriptor* init(); + + NS::String* label() const; + + RasterizationRateLayerDescriptor* layer(NS::UInteger layerIndex); + NS::UInteger layerCount() const; + + RasterizationRateLayerArray* layers() const; + + static RasterizationRateMapDescriptor* rasterizationRateMapDescriptor(MTL::Size screenSize); + static RasterizationRateMapDescriptor* rasterizationRateMapDescriptor(MTL::Size screenSize, const MTL::RasterizationRateLayerDescriptor* layer); + static RasterizationRateMapDescriptor* rasterizationRateMapDescriptor(MTL::Size screenSize, NS::UInteger layerCount, const MTL::RasterizationRateLayerDescriptor* const* layers); + + Size screenSize() const; + + void setLabel(const NS::String* label); + + void setLayer(const MTL::RasterizationRateLayerDescriptor* layer, NS::UInteger layerIndex); + + void setScreenSize(MTL::Size screenSize); +}; +class RasterizationRateMap : public NS::Referencing +{ +public: + void copyParameterDataToBuffer(const MTL::Buffer* buffer, NS::UInteger offset); + + Device* device() const; + + NS::String* label() const; + + NS::UInteger layerCount() const; + + Coordinate2D mapPhysicalToScreenCoordinates(MTL::Coordinate2D physicalCoordinates, NS::UInteger layerIndex); + + Coordinate2D mapScreenToPhysicalCoordinates(MTL::Coordinate2D screenCoordinates, NS::UInteger layerIndex); + + SizeAndAlign parameterBufferSizeAndAlign() const; + + Size physicalGranularity() const; + + Size physicalSize(NS::UInteger layerIndex); + + Size screenSize() const; +}; + +} +_MTL_INLINE MTL::RasterizationRateSampleArray* MTL::RasterizationRateSampleArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRasterizationRateSampleArray)); +} + +_MTL_INLINE MTL::RasterizationRateSampleArray* MTL::RasterizationRateSampleArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::Number* MTL::RasterizationRateSampleArray::object(NS::UInteger index) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), index); +} + +_MTL_INLINE void MTL::RasterizationRateSampleArray::setObject(const NS::Number* value, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), value, index); +} + +_MTL_INLINE MTL::RasterizationRateLayerDescriptor* MTL::RasterizationRateLayerDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRasterizationRateLayerDescriptor)); +} + +_MTL_INLINE MTL::RasterizationRateSampleArray* MTL::RasterizationRateLayerDescriptor::horizontal() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(horizontal)); +} + +_MTL_INLINE float* MTL::RasterizationRateLayerDescriptor::horizontalSampleStorage() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(horizontalSampleStorage)); +} + +_MTL_INLINE MTL::RasterizationRateLayerDescriptor* MTL::RasterizationRateLayerDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::RasterizationRateLayerDescriptor* MTL::RasterizationRateLayerDescriptor::init(MTL::Size sampleCount) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(initWithSampleCount_), sampleCount); +} + +_MTL_INLINE MTL::RasterizationRateLayerDescriptor* MTL::RasterizationRateLayerDescriptor::init(MTL::Size sampleCount, const float* horizontal, const float* vertical) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(initWithSampleCount_horizontal_vertical_), sampleCount, horizontal, vertical); +} + +_MTL_INLINE MTL::Size MTL::RasterizationRateLayerDescriptor::maxSampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxSampleCount)); +} + +_MTL_INLINE MTL::Size MTL::RasterizationRateLayerDescriptor::sampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleCount)); +} + +_MTL_INLINE void MTL::RasterizationRateLayerDescriptor::setSampleCount(MTL::Size sampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSampleCount_), sampleCount); +} + +_MTL_INLINE MTL::RasterizationRateSampleArray* MTL::RasterizationRateLayerDescriptor::vertical() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertical)); +} + +_MTL_INLINE float* MTL::RasterizationRateLayerDescriptor::verticalSampleStorage() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(verticalSampleStorage)); +} + +_MTL_INLINE MTL::RasterizationRateLayerArray* MTL::RasterizationRateLayerArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRasterizationRateLayerArray)); +} + +_MTL_INLINE MTL::RasterizationRateLayerArray* MTL::RasterizationRateLayerArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::RasterizationRateLayerDescriptor* MTL::RasterizationRateLayerArray::object(NS::UInteger layerIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), layerIndex); +} + +_MTL_INLINE void MTL::RasterizationRateLayerArray::setObject(const MTL::RasterizationRateLayerDescriptor* layer, NS::UInteger layerIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), layer, layerIndex); +} + +_MTL_INLINE MTL::RasterizationRateMapDescriptor* MTL::RasterizationRateMapDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRasterizationRateMapDescriptor)); +} + +_MTL_INLINE MTL::RasterizationRateMapDescriptor* MTL::RasterizationRateMapDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::RasterizationRateMapDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL::RasterizationRateLayerDescriptor* MTL::RasterizationRateMapDescriptor::layer(NS::UInteger layerIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(layerAtIndex_), layerIndex); +} + +_MTL_INLINE NS::UInteger MTL::RasterizationRateMapDescriptor::layerCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(layerCount)); +} + +_MTL_INLINE MTL::RasterizationRateLayerArray* MTL::RasterizationRateMapDescriptor::layers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(layers)); +} + +_MTL_INLINE MTL::RasterizationRateMapDescriptor* MTL::RasterizationRateMapDescriptor::rasterizationRateMapDescriptor(MTL::Size screenSize) +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLRasterizationRateMapDescriptor), _MTL_PRIVATE_SEL(rasterizationRateMapDescriptorWithScreenSize_), screenSize); +} + +_MTL_INLINE MTL::RasterizationRateMapDescriptor* MTL::RasterizationRateMapDescriptor::rasterizationRateMapDescriptor(MTL::Size screenSize, const MTL::RasterizationRateLayerDescriptor* layer) +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLRasterizationRateMapDescriptor), _MTL_PRIVATE_SEL(rasterizationRateMapDescriptorWithScreenSize_layer_), screenSize, layer); +} + +_MTL_INLINE MTL::RasterizationRateMapDescriptor* MTL::RasterizationRateMapDescriptor::rasterizationRateMapDescriptor(MTL::Size screenSize, NS::UInteger layerCount, const MTL::RasterizationRateLayerDescriptor* const* layers) +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLRasterizationRateMapDescriptor), _MTL_PRIVATE_SEL(rasterizationRateMapDescriptorWithScreenSize_layerCount_layers_), screenSize, layerCount, layers); +} + +_MTL_INLINE MTL::Size MTL::RasterizationRateMapDescriptor::screenSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(screenSize)); +} + +_MTL_INLINE void MTL::RasterizationRateMapDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::RasterizationRateMapDescriptor::setLayer(const MTL::RasterizationRateLayerDescriptor* layer, NS::UInteger layerIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLayer_atIndex_), layer, layerIndex); +} + +_MTL_INLINE void MTL::RasterizationRateMapDescriptor::setScreenSize(MTL::Size screenSize) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setScreenSize_), screenSize); +} + +_MTL_INLINE void MTL::RasterizationRateMap::copyParameterDataToBuffer(const MTL::Buffer* buffer, NS::UInteger offset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyParameterDataToBuffer_offset_), buffer, offset); +} + +_MTL_INLINE MTL::Device* MTL::RasterizationRateMap::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL::RasterizationRateMap::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::UInteger MTL::RasterizationRateMap::layerCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(layerCount)); +} + +_MTL_INLINE MTL::Coordinate2D MTL::RasterizationRateMap::mapPhysicalToScreenCoordinates(MTL::Coordinate2D physicalCoordinates, NS::UInteger layerIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(mapPhysicalToScreenCoordinates_forLayer_), physicalCoordinates, layerIndex); +} + +_MTL_INLINE MTL::Coordinate2D MTL::RasterizationRateMap::mapScreenToPhysicalCoordinates(MTL::Coordinate2D screenCoordinates, NS::UInteger layerIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(mapScreenToPhysicalCoordinates_forLayer_), screenCoordinates, layerIndex); +} + +_MTL_INLINE MTL::SizeAndAlign MTL::RasterizationRateMap::parameterBufferSizeAndAlign() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(parameterBufferSizeAndAlign)); +} + +_MTL_INLINE MTL::Size MTL::RasterizationRateMap::physicalGranularity() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(physicalGranularity)); +} + +_MTL_INLINE MTL::Size MTL::RasterizationRateMap::physicalSize(NS::UInteger layerIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(physicalSizeForLayer_), layerIndex); +} + +_MTL_INLINE MTL::Size MTL::RasterizationRateMap::screenSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(screenSize)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLRenderCommandEncoder.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLRenderCommandEncoder.hpp new file mode 100644 index 00000000..b2667b77 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLRenderCommandEncoder.hpp @@ -0,0 +1,1019 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLResourceStatePass.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLArgument.hpp" +#include "MTLCommandEncoder.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLRenderPass.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL +{ +class AccelerationStructure; +class Buffer; +class CounterSampleBuffer; +class DepthStencilState; +class Fence; +class Heap; +class IndirectCommandBuffer; +class IntersectionFunctionTable; +class LogicalToPhysicalColorAttachmentMap; +class RenderPipelineState; +class Resource; +class SamplerState; +struct ScissorRect; +class Texture; +struct VertexAmplificationViewMapping; +struct Viewport; +class VisibleFunctionTable; +_MTL_ENUM(NS::UInteger, PrimitiveType) { + PrimitiveTypePoint = 0, + PrimitiveTypeLine = 1, + PrimitiveTypeLineStrip = 2, + PrimitiveTypeTriangle = 3, + PrimitiveTypeTriangleStrip = 4, +}; + +_MTL_ENUM(NS::UInteger, VisibilityResultMode) { + VisibilityResultModeDisabled = 0, + VisibilityResultModeBoolean = 1, + VisibilityResultModeCounting = 2, +}; + +_MTL_ENUM(NS::UInteger, CullMode) { + CullModeNone = 0, + CullModeFront = 1, + CullModeBack = 2, +}; + +_MTL_ENUM(NS::UInteger, Winding) { + WindingClockwise = 0, + WindingCounterClockwise = 1, +}; + +_MTL_ENUM(NS::UInteger, DepthClipMode) { + DepthClipModeClip = 0, + DepthClipModeClamp = 1, +}; + +_MTL_ENUM(NS::UInteger, TriangleFillMode) { + TriangleFillModeFill = 0, + TriangleFillModeLines = 1, +}; + +_MTL_OPTIONS(NS::UInteger, RenderStages) { + RenderStageVertex = 1, + RenderStageFragment = 1 << 1, + RenderStageTile = 1 << 2, + RenderStageObject = 1 << 3, + RenderStageMesh = 1 << 4, +}; + +struct ScissorRect +{ + NS::UInteger x; + NS::UInteger y; + NS::UInteger width; + NS::UInteger height; +} _MTL_PACKED; + +struct Viewport +{ + double originX; + double originY; + double width; + double height; + double znear; + double zfar; +} _MTL_PACKED; + +struct DrawPrimitivesIndirectArguments +{ + uint32_t vertexCount; + uint32_t instanceCount; + uint32_t vertexStart; + uint32_t baseInstance; +} _MTL_PACKED; + +struct DrawIndexedPrimitivesIndirectArguments +{ + uint32_t indexCount; + uint32_t instanceCount; + uint32_t indexStart; + int32_t baseVertex; + uint32_t baseInstance; +} _MTL_PACKED; + +struct VertexAmplificationViewMapping +{ + uint32_t viewportArrayIndexOffset; + uint32_t renderTargetArrayIndexOffset; +} _MTL_PACKED; + +struct DrawPatchIndirectArguments +{ + uint32_t patchCount; + uint32_t instanceCount; + uint32_t patchStart; + uint32_t baseInstance; +} _MTL_PACKED; + +struct QuadTessellationFactorsHalf +{ + uint16_t edgeTessellationFactor[4]; + uint16_t insideTessellationFactor[2]; +} _MTL_PACKED; + +struct TriangleTessellationFactorsHalf +{ + uint16_t edgeTessellationFactor[3]; + uint16_t insideTessellationFactor; +} _MTL_PACKED; + +class RenderCommandEncoder : public NS::Referencing +{ +public: + void dispatchThreadsPerTile(MTL::Size threadsPerTile); + + void drawIndexedPatches(NS::UInteger numberOfPatchControlPoints, NS::UInteger patchStart, NS::UInteger patchCount, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, const MTL::Buffer* controlPointIndexBuffer, NS::UInteger controlPointIndexBufferOffset, NS::UInteger instanceCount, NS::UInteger baseInstance); + void drawIndexedPatches(NS::UInteger numberOfPatchControlPoints, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, const MTL::Buffer* controlPointIndexBuffer, NS::UInteger controlPointIndexBufferOffset, const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset); + + void drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, const MTL::Buffer* indexBuffer, NS::UInteger indexBufferOffset, NS::UInteger instanceCount); + void drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, const MTL::Buffer* indexBuffer, NS::UInteger indexBufferOffset); + void drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, const MTL::Buffer* indexBuffer, NS::UInteger indexBufferOffset, NS::UInteger instanceCount, NS::Integer baseVertex, NS::UInteger baseInstance); + void drawIndexedPrimitives(MTL::PrimitiveType primitiveType, MTL::IndexType indexType, const MTL::Buffer* indexBuffer, NS::UInteger indexBufferOffset, const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset); + + void drawMeshThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup); + void drawMeshThreadgroups(const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup); + + void drawMeshThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup); + + void drawPatches(NS::UInteger numberOfPatchControlPoints, NS::UInteger patchStart, NS::UInteger patchCount, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, NS::UInteger instanceCount, NS::UInteger baseInstance); + void drawPatches(NS::UInteger numberOfPatchControlPoints, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset); + + void drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount, NS::UInteger instanceCount); + void drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount); + void drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount, NS::UInteger instanceCount, NS::UInteger baseInstance); + void drawPrimitives(MTL::PrimitiveType primitiveType, const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset); + + void executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range executionRange); + void executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandbuffer, const MTL::Buffer* indirectRangeBuffer, NS::UInteger indirectBufferOffset); + + void memoryBarrier(MTL::BarrierScope scope, MTL::RenderStages after, MTL::RenderStages before); + void memoryBarrier(const MTL::Resource* const resources[], NS::UInteger count, MTL::RenderStages after, MTL::RenderStages before); + + void sampleCountersInBuffer(const MTL::CounterSampleBuffer* sampleBuffer, NS::UInteger sampleIndex, bool barrier); + + void setBlendColor(float red, float green, float blue, float alpha); + + void setColorAttachmentMap(const MTL::LogicalToPhysicalColorAttachmentMap* mapping); + + void setColorStoreAction(MTL::StoreAction storeAction, NS::UInteger colorAttachmentIndex); + void setColorStoreActionOptions(MTL::StoreActionOptions storeActionOptions, NS::UInteger colorAttachmentIndex); + + void setCullMode(MTL::CullMode cullMode); + + void setDepthBias(float depthBias, float slopeScale, float clamp); + + void setDepthClipMode(MTL::DepthClipMode depthClipMode); + + void setDepthStencilState(const MTL::DepthStencilState* depthStencilState); + + void setDepthStoreAction(MTL::StoreAction storeAction); + void setDepthStoreActionOptions(MTL::StoreActionOptions storeActionOptions); + + void setDepthTestBounds(float minBound, float maxBound); + + void setFragmentAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, NS::UInteger bufferIndex); + + void setFragmentBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + void setFragmentBufferOffset(NS::UInteger offset, NS::UInteger index); + + void setFragmentBuffers(const MTL::Buffer* const buffers[], const NS::UInteger offsets[], NS::Range range); + + void setFragmentBytes(const void* bytes, NS::UInteger length, NS::UInteger index); + + void setFragmentIntersectionFunctionTable(const MTL::IntersectionFunctionTable* intersectionFunctionTable, NS::UInteger bufferIndex); + void setFragmentIntersectionFunctionTables(const MTL::IntersectionFunctionTable* const intersectionFunctionTables[], NS::Range range); + + void setFragmentSamplerState(const MTL::SamplerState* sampler, NS::UInteger index); + void setFragmentSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index); + void setFragmentSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range); + void setFragmentSamplerStates(const MTL::SamplerState* const samplers[], const float lodMinClamps[], const float lodMaxClamps[], NS::Range range); + + void setFragmentTexture(const MTL::Texture* texture, NS::UInteger index); + void setFragmentTextures(const MTL::Texture* const textures[], NS::Range range); + + void setFragmentVisibleFunctionTable(const MTL::VisibleFunctionTable* functionTable, NS::UInteger bufferIndex); + void setFragmentVisibleFunctionTables(const MTL::VisibleFunctionTable* const functionTables[], NS::Range range); + + void setFrontFacingWinding(MTL::Winding frontFacingWinding); + + void setMeshBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + void setMeshBufferOffset(NS::UInteger offset, NS::UInteger index); + + void setMeshBuffers(const MTL::Buffer* const buffers[], const NS::UInteger* offsets, NS::Range range); + + void setMeshBytes(const void* bytes, NS::UInteger length, NS::UInteger index); + + void setMeshSamplerState(const MTL::SamplerState* sampler, NS::UInteger index); + void setMeshSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index); + void setMeshSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range); + void setMeshSamplerStates(const MTL::SamplerState* const samplers[], const float* lodMinClamps, const float* lodMaxClamps, NS::Range range); + + void setMeshTexture(const MTL::Texture* texture, NS::UInteger index); + void setMeshTextures(const MTL::Texture* const textures[], NS::Range range); + + void setObjectBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + void setObjectBufferOffset(NS::UInteger offset, NS::UInteger index); + + void setObjectBuffers(const MTL::Buffer* const buffers[], const NS::UInteger* offsets, NS::Range range); + + void setObjectBytes(const void* bytes, NS::UInteger length, NS::UInteger index); + + void setObjectSamplerState(const MTL::SamplerState* sampler, NS::UInteger index); + void setObjectSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index); + void setObjectSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range); + void setObjectSamplerStates(const MTL::SamplerState* const samplers[], const float* lodMinClamps, const float* lodMaxClamps, NS::Range range); + + void setObjectTexture(const MTL::Texture* texture, NS::UInteger index); + void setObjectTextures(const MTL::Texture* const textures[], NS::Range range); + + void setObjectThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index); + + void setRenderPipelineState(const MTL::RenderPipelineState* pipelineState); + + void setScissorRect(MTL::ScissorRect rect); + void setScissorRects(const MTL::ScissorRect* scissorRects, NS::UInteger count); + + void setStencilReferenceValue(uint32_t referenceValue); + void setStencilReferenceValues(uint32_t frontReferenceValue, uint32_t backReferenceValue); + + void setStencilStoreAction(MTL::StoreAction storeAction); + void setStencilStoreActionOptions(MTL::StoreActionOptions storeActionOptions); + + void setTessellationFactorBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger instanceStride); + + void setTessellationFactorScale(float scale); + + void setThreadgroupMemoryLength(NS::UInteger length, NS::UInteger offset, NS::UInteger index); + + void setTileAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, NS::UInteger bufferIndex); + + void setTileBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + void setTileBufferOffset(NS::UInteger offset, NS::UInteger index); + + void setTileBuffers(const MTL::Buffer* const buffers[], const NS::UInteger* offsets, NS::Range range); + + void setTileBytes(const void* bytes, NS::UInteger length, NS::UInteger index); + + void setTileIntersectionFunctionTable(const MTL::IntersectionFunctionTable* intersectionFunctionTable, NS::UInteger bufferIndex); + void setTileIntersectionFunctionTables(const MTL::IntersectionFunctionTable* const intersectionFunctionTables[], NS::Range range); + + void setTileSamplerState(const MTL::SamplerState* sampler, NS::UInteger index); + void setTileSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index); + void setTileSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range); + void setTileSamplerStates(const MTL::SamplerState* const samplers[], const float lodMinClamps[], const float lodMaxClamps[], NS::Range range); + + void setTileTexture(const MTL::Texture* texture, NS::UInteger index); + void setTileTextures(const MTL::Texture* const textures[], NS::Range range); + + void setTileVisibleFunctionTable(const MTL::VisibleFunctionTable* functionTable, NS::UInteger bufferIndex); + void setTileVisibleFunctionTables(const MTL::VisibleFunctionTable* const functionTables[], NS::Range range); + + void setTriangleFillMode(MTL::TriangleFillMode fillMode); + + void setVertexAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, NS::UInteger bufferIndex); + + void setVertexAmplificationCount(NS::UInteger count, const MTL::VertexAmplificationViewMapping* viewMappings); + + void setVertexBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + void setVertexBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger stride, NS::UInteger index); + void setVertexBufferOffset(NS::UInteger offset, NS::UInteger index); + void setVertexBufferOffset(NS::UInteger offset, NS::UInteger stride, NS::UInteger index); + + void setVertexBuffers(const MTL::Buffer* const buffers[], const NS::UInteger offsets[], NS::Range range); + void setVertexBuffers(const MTL::Buffer* const buffers[], const NS::UInteger* offsets, const NS::UInteger* strides, NS::Range range); + + void setVertexBytes(const void* bytes, NS::UInteger length, NS::UInteger index); + void setVertexBytes(const void* bytes, NS::UInteger length, NS::UInteger stride, NS::UInteger index); + + void setVertexIntersectionFunctionTable(const MTL::IntersectionFunctionTable* intersectionFunctionTable, NS::UInteger bufferIndex); + void setVertexIntersectionFunctionTables(const MTL::IntersectionFunctionTable* const intersectionFunctionTables[], NS::Range range); + + void setVertexSamplerState(const MTL::SamplerState* sampler, NS::UInteger index); + void setVertexSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index); + void setVertexSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range); + void setVertexSamplerStates(const MTL::SamplerState* const samplers[], const float lodMinClamps[], const float lodMaxClamps[], NS::Range range); + + void setVertexTexture(const MTL::Texture* texture, NS::UInteger index); + void setVertexTextures(const MTL::Texture* const textures[], NS::Range range); + + void setVertexVisibleFunctionTable(const MTL::VisibleFunctionTable* functionTable, NS::UInteger bufferIndex); + void setVertexVisibleFunctionTables(const MTL::VisibleFunctionTable* const functionTables[], NS::Range range); + + void setViewport(MTL::Viewport viewport); + void setViewports(const MTL::Viewport* viewports, NS::UInteger count); + + void setVisibilityResultMode(MTL::VisibilityResultMode mode, NS::UInteger offset); + + void textureBarrier(); + + NS::UInteger tileHeight() const; + + NS::UInteger tileWidth() const; + + void updateFence(const MTL::Fence* fence, MTL::RenderStages stages); + + void useHeap(const MTL::Heap* heap); + void useHeap(const MTL::Heap* heap, MTL::RenderStages stages); + void useHeaps(const MTL::Heap* const heaps[], NS::UInteger count); + void useHeaps(const MTL::Heap* const heaps[], NS::UInteger count, MTL::RenderStages stages); + + void useResource(const MTL::Resource* resource, MTL::ResourceUsage usage); + void useResource(const MTL::Resource* resource, MTL::ResourceUsage usage, MTL::RenderStages stages); + void useResources(const MTL::Resource* const resources[], NS::UInteger count, MTL::ResourceUsage usage); + void useResources(const MTL::Resource* const resources[], NS::UInteger count, MTL::ResourceUsage usage, MTL::RenderStages stages); + + void waitForFence(const MTL::Fence* fence, MTL::RenderStages stages); +}; + +} + +_MTL_INLINE void MTL::RenderCommandEncoder::dispatchThreadsPerTile(MTL::Size threadsPerTile) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchThreadsPerTile_), threadsPerTile); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawIndexedPatches(NS::UInteger numberOfPatchControlPoints, NS::UInteger patchStart, NS::UInteger patchCount, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, const MTL::Buffer* controlPointIndexBuffer, NS::UInteger controlPointIndexBufferOffset, NS::UInteger instanceCount, NS::UInteger baseInstance) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPatches_patchStart_patchCount_patchIndexBuffer_patchIndexBufferOffset_controlPointIndexBuffer_controlPointIndexBufferOffset_instanceCount_baseInstance_), numberOfPatchControlPoints, patchStart, patchCount, patchIndexBuffer, patchIndexBufferOffset, controlPointIndexBuffer, controlPointIndexBufferOffset, instanceCount, baseInstance); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawIndexedPatches(NS::UInteger numberOfPatchControlPoints, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, const MTL::Buffer* controlPointIndexBuffer, NS::UInteger controlPointIndexBufferOffset, const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPatches_patchIndexBuffer_patchIndexBufferOffset_controlPointIndexBuffer_controlPointIndexBufferOffset_indirectBuffer_indirectBufferOffset_), numberOfPatchControlPoints, patchIndexBuffer, patchIndexBufferOffset, controlPointIndexBuffer, controlPointIndexBufferOffset, indirectBuffer, indirectBufferOffset); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, const MTL::Buffer* indexBuffer, NS::UInteger indexBufferOffset, NS::UInteger instanceCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferOffset_instanceCount_), primitiveType, indexCount, indexType, indexBuffer, indexBufferOffset, instanceCount); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, const MTL::Buffer* indexBuffer, NS::UInteger indexBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferOffset_), primitiveType, indexCount, indexType, indexBuffer, indexBufferOffset); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, const MTL::Buffer* indexBuffer, NS::UInteger indexBufferOffset, NS::UInteger instanceCount, NS::Integer baseVertex, NS::UInteger baseInstance) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferOffset_instanceCount_baseVertex_baseInstance_), primitiveType, indexCount, indexType, indexBuffer, indexBufferOffset, instanceCount, baseVertex, baseInstance); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawIndexedPrimitives(MTL::PrimitiveType primitiveType, MTL::IndexType indexType, const MTL::Buffer* indexBuffer, NS::UInteger indexBufferOffset, const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPrimitives_indexType_indexBuffer_indexBufferOffset_indirectBuffer_indirectBufferOffset_), primitiveType, indexType, indexBuffer, indexBufferOffset, indirectBuffer, indirectBufferOffset); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawMeshThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawMeshThreadgroups_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_), threadgroupsPerGrid, threadsPerObjectThreadgroup, threadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawMeshThreadgroups(const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawMeshThreadgroupsWithIndirectBuffer_indirectBufferOffset_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_), indirectBuffer, indirectBufferOffset, threadsPerObjectThreadgroup, threadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawMeshThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawMeshThreads_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_), threadsPerGrid, threadsPerObjectThreadgroup, threadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawPatches(NS::UInteger numberOfPatchControlPoints, NS::UInteger patchStart, NS::UInteger patchCount, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, NS::UInteger instanceCount, NS::UInteger baseInstance) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPatches_patchStart_patchCount_patchIndexBuffer_patchIndexBufferOffset_instanceCount_baseInstance_), numberOfPatchControlPoints, patchStart, patchCount, patchIndexBuffer, patchIndexBufferOffset, instanceCount, baseInstance); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawPatches(NS::UInteger numberOfPatchControlPoints, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPatches_patchIndexBuffer_patchIndexBufferOffset_indirectBuffer_indirectBufferOffset_), numberOfPatchControlPoints, patchIndexBuffer, patchIndexBufferOffset, indirectBuffer, indirectBufferOffset); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount, NS::UInteger instanceCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPrimitives_vertexStart_vertexCount_instanceCount_), primitiveType, vertexStart, vertexCount, instanceCount); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPrimitives_vertexStart_vertexCount_), primitiveType, vertexStart, vertexCount); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount, NS::UInteger instanceCount, NS::UInteger baseInstance) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPrimitives_vertexStart_vertexCount_instanceCount_baseInstance_), primitiveType, vertexStart, vertexCount, instanceCount, baseInstance); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawPrimitives(MTL::PrimitiveType primitiveType, const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPrimitives_indirectBuffer_indirectBufferOffset_), primitiveType, indirectBuffer, indirectBufferOffset); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range executionRange) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(executeCommandsInBuffer_withRange_), indirectCommandBuffer, executionRange); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandbuffer, const MTL::Buffer* indirectRangeBuffer, NS::UInteger indirectBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(executeCommandsInBuffer_indirectBuffer_indirectBufferOffset_), indirectCommandbuffer, indirectRangeBuffer, indirectBufferOffset); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::memoryBarrier(MTL::BarrierScope scope, MTL::RenderStages after, MTL::RenderStages before) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(memoryBarrierWithScope_afterStages_beforeStages_), scope, after, before); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::memoryBarrier(const MTL::Resource* const resources[], NS::UInteger count, MTL::RenderStages after, MTL::RenderStages before) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(memoryBarrierWithResources_count_afterStages_beforeStages_), resources, count, after, before); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::sampleCountersInBuffer(const MTL::CounterSampleBuffer* sampleBuffer, NS::UInteger sampleIndex, bool barrier) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleCountersInBuffer_atSampleIndex_withBarrier_), sampleBuffer, sampleIndex, barrier); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setBlendColor(float red, float green, float blue, float alpha) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBlendColorRed_green_blue_alpha_), red, green, blue, alpha); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setColorAttachmentMap(const MTL::LogicalToPhysicalColorAttachmentMap* mapping) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setColorAttachmentMap_), mapping); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setColorStoreAction(MTL::StoreAction storeAction, NS::UInteger colorAttachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setColorStoreAction_atIndex_), storeAction, colorAttachmentIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setColorStoreActionOptions(MTL::StoreActionOptions storeActionOptions, NS::UInteger colorAttachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setColorStoreActionOptions_atIndex_), storeActionOptions, colorAttachmentIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setCullMode(MTL::CullMode cullMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCullMode_), cullMode); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setDepthBias(float depthBias, float slopeScale, float clamp) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthBias_slopeScale_clamp_), depthBias, slopeScale, clamp); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setDepthClipMode(MTL::DepthClipMode depthClipMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthClipMode_), depthClipMode); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setDepthStencilState(const MTL::DepthStencilState* depthStencilState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStencilState_), depthStencilState); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setDepthStoreAction(MTL::StoreAction storeAction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStoreAction_), storeAction); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setDepthStoreActionOptions(MTL::StoreActionOptions storeActionOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStoreActionOptions_), storeActionOptions); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setDepthTestBounds(float minBound, float maxBound) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthTestMinBound_maxBound_), minBound, maxBound); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentAccelerationStructure_atBufferIndex_), accelerationStructure, bufferIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentBufferOffset(NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentBufferOffset_atIndex_), offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentBuffers(const MTL::Buffer* const buffers[], const NS::UInteger offsets[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentBuffers_offsets_withRange_), buffers, offsets, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentBytes(const void* bytes, NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentBytes_length_atIndex_), bytes, length, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentIntersectionFunctionTable(const MTL::IntersectionFunctionTable* intersectionFunctionTable, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentIntersectionFunctionTable_atBufferIndex_), intersectionFunctionTable, bufferIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentIntersectionFunctionTables(const MTL::IntersectionFunctionTable* const intersectionFunctionTables[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentIntersectionFunctionTables_withBufferRange_), intersectionFunctionTables, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentSamplerState(const MTL::SamplerState* sampler, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentSamplerState_atIndex_), sampler, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentSamplerState_lodMinClamp_lodMaxClamp_atIndex_), sampler, lodMinClamp, lodMaxClamp, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentSamplerStates_withRange_), samplers, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentSamplerStates(const MTL::SamplerState* const samplers[], const float lodMinClamps[], const float lodMaxClamps[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentSamplerStates_lodMinClamps_lodMaxClamps_withRange_), samplers, lodMinClamps, lodMaxClamps, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentTexture(const MTL::Texture* texture, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentTexture_atIndex_), texture, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentTextures(const MTL::Texture* const textures[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentTextures_withRange_), textures, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentVisibleFunctionTable(const MTL::VisibleFunctionTable* functionTable, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentVisibleFunctionTable_atBufferIndex_), functionTable, bufferIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentVisibleFunctionTables(const MTL::VisibleFunctionTable* const functionTables[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentVisibleFunctionTables_withBufferRange_), functionTables, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFrontFacingWinding(MTL::Winding frontFacingWinding) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFrontFacingWinding_), frontFacingWinding); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setMeshBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setMeshBufferOffset(NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshBufferOffset_atIndex_), offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setMeshBuffers(const MTL::Buffer* const buffers[], const NS::UInteger* offsets, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshBuffers_offsets_withRange_), buffers, offsets, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setMeshBytes(const void* bytes, NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshBytes_length_atIndex_), bytes, length, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setMeshSamplerState(const MTL::SamplerState* sampler, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshSamplerState_atIndex_), sampler, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setMeshSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshSamplerState_lodMinClamp_lodMaxClamp_atIndex_), sampler, lodMinClamp, lodMaxClamp, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setMeshSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshSamplerStates_withRange_), samplers, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setMeshSamplerStates(const MTL::SamplerState* const samplers[], const float* lodMinClamps, const float* lodMaxClamps, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshSamplerStates_lodMinClamps_lodMaxClamps_withRange_), samplers, lodMinClamps, lodMaxClamps, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setMeshTexture(const MTL::Texture* texture, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshTexture_atIndex_), texture, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setMeshTextures(const MTL::Texture* const textures[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshTextures_withRange_), textures, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectBufferOffset(NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectBufferOffset_atIndex_), offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectBuffers(const MTL::Buffer* const buffers[], const NS::UInteger* offsets, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectBuffers_offsets_withRange_), buffers, offsets, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectBytes(const void* bytes, NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectBytes_length_atIndex_), bytes, length, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectSamplerState(const MTL::SamplerState* sampler, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectSamplerState_atIndex_), sampler, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectSamplerState_lodMinClamp_lodMaxClamp_atIndex_), sampler, lodMinClamp, lodMaxClamp, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectSamplerStates_withRange_), samplers, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectSamplerStates(const MTL::SamplerState* const samplers[], const float* lodMinClamps, const float* lodMaxClamps, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectSamplerStates_lodMinClamps_lodMaxClamps_withRange_), samplers, lodMinClamps, lodMaxClamps, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectTexture(const MTL::Texture* texture, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectTexture_atIndex_), texture, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectTextures(const MTL::Texture* const textures[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectTextures_withRange_), textures, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectThreadgroupMemoryLength_atIndex_), length, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setRenderPipelineState(const MTL::RenderPipelineState* pipelineState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRenderPipelineState_), pipelineState); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setScissorRect(MTL::ScissorRect rect) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setScissorRect_), rect); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setScissorRects(const MTL::ScissorRect* scissorRects, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setScissorRects_count_), scissorRects, count); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setStencilReferenceValue(uint32_t referenceValue) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilReferenceValue_), referenceValue); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setStencilReferenceValues(uint32_t frontReferenceValue, uint32_t backReferenceValue) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilFrontReferenceValue_backReferenceValue_), frontReferenceValue, backReferenceValue); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setStencilStoreAction(MTL::StoreAction storeAction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilStoreAction_), storeAction); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setStencilStoreActionOptions(MTL::StoreActionOptions storeActionOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilStoreActionOptions_), storeActionOptions); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTessellationFactorBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger instanceStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTessellationFactorBuffer_offset_instanceStride_), buffer, offset, instanceStride); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTessellationFactorScale(float scale) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTessellationFactorScale_), scale); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setThreadgroupMemoryLength(NS::UInteger length, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadgroupMemoryLength_offset_atIndex_), length, offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileAccelerationStructure_atBufferIndex_), accelerationStructure, bufferIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileBufferOffset(NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileBufferOffset_atIndex_), offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileBuffers(const MTL::Buffer* const buffers[], const NS::UInteger* offsets, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileBuffers_offsets_withRange_), buffers, offsets, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileBytes(const void* bytes, NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileBytes_length_atIndex_), bytes, length, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileIntersectionFunctionTable(const MTL::IntersectionFunctionTable* intersectionFunctionTable, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileIntersectionFunctionTable_atBufferIndex_), intersectionFunctionTable, bufferIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileIntersectionFunctionTables(const MTL::IntersectionFunctionTable* const intersectionFunctionTables[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileIntersectionFunctionTables_withBufferRange_), intersectionFunctionTables, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileSamplerState(const MTL::SamplerState* sampler, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileSamplerState_atIndex_), sampler, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileSamplerState_lodMinClamp_lodMaxClamp_atIndex_), sampler, lodMinClamp, lodMaxClamp, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileSamplerStates_withRange_), samplers, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileSamplerStates(const MTL::SamplerState* const samplers[], const float lodMinClamps[], const float lodMaxClamps[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileSamplerStates_lodMinClamps_lodMaxClamps_withRange_), samplers, lodMinClamps, lodMaxClamps, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileTexture(const MTL::Texture* texture, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileTexture_atIndex_), texture, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileTextures(const MTL::Texture* const textures[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileTextures_withRange_), textures, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileVisibleFunctionTable(const MTL::VisibleFunctionTable* functionTable, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileVisibleFunctionTable_atBufferIndex_), functionTable, bufferIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileVisibleFunctionTables(const MTL::VisibleFunctionTable* const functionTables[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileVisibleFunctionTables_withBufferRange_), functionTables, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTriangleFillMode(MTL::TriangleFillMode fillMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTriangleFillMode_), fillMode); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexAccelerationStructure_atBufferIndex_), accelerationStructure, bufferIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexAmplificationCount(NS::UInteger count, const MTL::VertexAmplificationViewMapping* viewMappings) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexAmplificationCount_viewMappings_), count, viewMappings); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger stride, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBuffer_offset_attributeStride_atIndex_), buffer, offset, stride, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexBufferOffset(NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBufferOffset_atIndex_), offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexBufferOffset(NS::UInteger offset, NS::UInteger stride, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBufferOffset_attributeStride_atIndex_), offset, stride, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexBuffers(const MTL::Buffer* const buffers[], const NS::UInteger offsets[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBuffers_offsets_withRange_), buffers, offsets, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexBuffers(const MTL::Buffer* const buffers[], const NS::UInteger* offsets, const NS::UInteger* strides, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBuffers_offsets_attributeStrides_withRange_), buffers, offsets, strides, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexBytes(const void* bytes, NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBytes_length_atIndex_), bytes, length, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexBytes(const void* bytes, NS::UInteger length, NS::UInteger stride, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBytes_length_attributeStride_atIndex_), bytes, length, stride, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexIntersectionFunctionTable(const MTL::IntersectionFunctionTable* intersectionFunctionTable, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexIntersectionFunctionTable_atBufferIndex_), intersectionFunctionTable, bufferIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexIntersectionFunctionTables(const MTL::IntersectionFunctionTable* const intersectionFunctionTables[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexIntersectionFunctionTables_withBufferRange_), intersectionFunctionTables, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexSamplerState(const MTL::SamplerState* sampler, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexSamplerState_atIndex_), sampler, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexSamplerState_lodMinClamp_lodMaxClamp_atIndex_), sampler, lodMinClamp, lodMaxClamp, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexSamplerStates_withRange_), samplers, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexSamplerStates(const MTL::SamplerState* const samplers[], const float lodMinClamps[], const float lodMaxClamps[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexSamplerStates_lodMinClamps_lodMaxClamps_withRange_), samplers, lodMinClamps, lodMaxClamps, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexTexture(const MTL::Texture* texture, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexTexture_atIndex_), texture, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexTextures(const MTL::Texture* const textures[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexTextures_withRange_), textures, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexVisibleFunctionTable(const MTL::VisibleFunctionTable* functionTable, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexVisibleFunctionTable_atBufferIndex_), functionTable, bufferIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexVisibleFunctionTables(const MTL::VisibleFunctionTable* const functionTables[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexVisibleFunctionTables_withBufferRange_), functionTables, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setViewport(MTL::Viewport viewport) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setViewport_), viewport); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setViewports(const MTL::Viewport* viewports, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setViewports_count_), viewports, count); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVisibilityResultMode(MTL::VisibilityResultMode mode, NS::UInteger offset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibilityResultMode_offset_), mode, offset); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::textureBarrier() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(textureBarrier)); +} + +_MTL_INLINE NS::UInteger MTL::RenderCommandEncoder::tileHeight() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileHeight)); +} + +_MTL_INLINE NS::UInteger MTL::RenderCommandEncoder::tileWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileWidth)); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::updateFence(const MTL::Fence* fence, MTL::RenderStages stages) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateFence_afterStages_), fence, stages); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::useHeap(const MTL::Heap* heap) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useHeap_), heap); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::useHeap(const MTL::Heap* heap, MTL::RenderStages stages) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useHeap_stages_), heap, stages); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::useHeaps(const MTL::Heap* const heaps[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useHeaps_count_), heaps, count); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::useHeaps(const MTL::Heap* const heaps[], NS::UInteger count, MTL::RenderStages stages) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useHeaps_count_stages_), heaps, count, stages); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::useResource(const MTL::Resource* resource, MTL::ResourceUsage usage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResource_usage_), resource, usage); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::useResource(const MTL::Resource* resource, MTL::ResourceUsage usage, MTL::RenderStages stages) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResource_usage_stages_), resource, usage, stages); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::useResources(const MTL::Resource* const resources[], NS::UInteger count, MTL::ResourceUsage usage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResources_count_usage_), resources, count, usage); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::useResources(const MTL::Resource* const resources[], NS::UInteger count, MTL::ResourceUsage usage, MTL::RenderStages stages) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResources_count_usage_stages_), resources, count, usage, stages); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::waitForFence(const MTL::Fence* fence, MTL::RenderStages stages) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitForFence_beforeStages_), fence, stages); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLRenderPass.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLRenderPass.hpp new file mode 100644 index 00000000..ed2172d7 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLRenderPass.hpp @@ -0,0 +1,792 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLRenderPass.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include + +namespace MTL +{ +class Buffer; +class CounterSampleBuffer; +class RasterizationRateMap; +class RenderPassAttachmentDescriptor; +class RenderPassColorAttachmentDescriptor; +class RenderPassColorAttachmentDescriptorArray; +class RenderPassDepthAttachmentDescriptor; +class RenderPassDescriptor; +class RenderPassSampleBufferAttachmentDescriptor; +class RenderPassSampleBufferAttachmentDescriptorArray; +class RenderPassStencilAttachmentDescriptor; +struct SamplePosition; +class Texture; +_MTL_ENUM(NS::UInteger, LoadAction) { + LoadActionDontCare = 0, + LoadActionLoad = 1, + LoadActionClear = 2, +}; + +_MTL_ENUM(NS::UInteger, StoreAction) { + StoreActionDontCare = 0, + StoreActionStore = 1, + StoreActionMultisampleResolve = 2, + StoreActionStoreAndMultisampleResolve = 3, + StoreActionUnknown = 4, + StoreActionCustomSampleDepthStore = 5, +}; + +_MTL_ENUM(NS::Integer, VisibilityResultType) { + VisibilityResultTypeReset = 0, + VisibilityResultTypeAccumulate = 1, +}; + +_MTL_ENUM(NS::UInteger, MultisampleDepthResolveFilter) { + MultisampleDepthResolveFilterSample0 = 0, + MultisampleDepthResolveFilterMin = 1, + MultisampleDepthResolveFilterMax = 2, +}; + +_MTL_ENUM(NS::UInteger, MultisampleStencilResolveFilter) { + MultisampleStencilResolveFilterSample0 = 0, + MultisampleStencilResolveFilterDepthResolvedSample = 1, +}; + +_MTL_OPTIONS(NS::UInteger, StoreActionOptions) { + StoreActionOptionNone = 0, + StoreActionOptionCustomSamplePositions = 1, + StoreActionOptionValidMask = 1, +}; + +struct ClearColor +{ + ClearColor() = default; + + ClearColor(double red, double green, double blue, double alpha); + + static ClearColor Make(double red, double green, double blue, double alpha); + + double red; + double green; + double blue; + double alpha; +} _MTL_PACKED; + +class RenderPassAttachmentDescriptor : public NS::Copying +{ +public: + static RenderPassAttachmentDescriptor* alloc(); + + NS::UInteger depthPlane() const; + + RenderPassAttachmentDescriptor* init(); + + NS::UInteger level() const; + + LoadAction loadAction() const; + + NS::UInteger resolveDepthPlane() const; + + NS::UInteger resolveLevel() const; + + NS::UInteger resolveSlice() const; + + Texture* resolveTexture() const; + + void setDepthPlane(NS::UInteger depthPlane); + + void setLevel(NS::UInteger level); + + void setLoadAction(MTL::LoadAction loadAction); + + void setResolveDepthPlane(NS::UInteger resolveDepthPlane); + + void setResolveLevel(NS::UInteger resolveLevel); + + void setResolveSlice(NS::UInteger resolveSlice); + + void setResolveTexture(const MTL::Texture* resolveTexture); + + void setSlice(NS::UInteger slice); + + void setStoreAction(MTL::StoreAction storeAction); + void setStoreActionOptions(MTL::StoreActionOptions storeActionOptions); + + void setTexture(const MTL::Texture* texture); + + NS::UInteger slice() const; + + StoreAction storeAction() const; + StoreActionOptions storeActionOptions() const; + + Texture* texture() const; +}; +class RenderPassColorAttachmentDescriptor : public NS::Copying +{ +public: + static RenderPassColorAttachmentDescriptor* alloc(); + + ClearColor clearColor() const; + + RenderPassColorAttachmentDescriptor* init(); + + void setClearColor(MTL::ClearColor clearColor); +}; +class RenderPassDepthAttachmentDescriptor : public NS::Copying +{ +public: + static RenderPassDepthAttachmentDescriptor* alloc(); + + double clearDepth() const; + + MultisampleDepthResolveFilter depthResolveFilter() const; + + RenderPassDepthAttachmentDescriptor* init(); + + void setClearDepth(double clearDepth); + + void setDepthResolveFilter(MTL::MultisampleDepthResolveFilter depthResolveFilter); +}; +class RenderPassStencilAttachmentDescriptor : public NS::Copying +{ +public: + static RenderPassStencilAttachmentDescriptor* alloc(); + + uint32_t clearStencil() const; + + RenderPassStencilAttachmentDescriptor* init(); + + void setClearStencil(uint32_t clearStencil); + + void setStencilResolveFilter(MTL::MultisampleStencilResolveFilter stencilResolveFilter); + MultisampleStencilResolveFilter stencilResolveFilter() const; +}; +class RenderPassColorAttachmentDescriptorArray : public NS::Referencing +{ +public: + static RenderPassColorAttachmentDescriptorArray* alloc(); + + RenderPassColorAttachmentDescriptorArray* init(); + + RenderPassColorAttachmentDescriptor* object(NS::UInteger attachmentIndex); + void setObject(const MTL::RenderPassColorAttachmentDescriptor* attachment, NS::UInteger attachmentIndex); +}; +class RenderPassSampleBufferAttachmentDescriptor : public NS::Copying +{ +public: + static RenderPassSampleBufferAttachmentDescriptor* alloc(); + + NS::UInteger endOfFragmentSampleIndex() const; + + NS::UInteger endOfVertexSampleIndex() const; + + RenderPassSampleBufferAttachmentDescriptor* init(); + + CounterSampleBuffer* sampleBuffer() const; + + void setEndOfFragmentSampleIndex(NS::UInteger endOfFragmentSampleIndex); + + void setEndOfVertexSampleIndex(NS::UInteger endOfVertexSampleIndex); + + void setSampleBuffer(const MTL::CounterSampleBuffer* sampleBuffer); + + void setStartOfFragmentSampleIndex(NS::UInteger startOfFragmentSampleIndex); + + void setStartOfVertexSampleIndex(NS::UInteger startOfVertexSampleIndex); + + NS::UInteger startOfFragmentSampleIndex() const; + + NS::UInteger startOfVertexSampleIndex() const; +}; +class RenderPassSampleBufferAttachmentDescriptorArray : public NS::Referencing +{ +public: + static RenderPassSampleBufferAttachmentDescriptorArray* alloc(); + + RenderPassSampleBufferAttachmentDescriptorArray* init(); + + RenderPassSampleBufferAttachmentDescriptor* object(NS::UInteger attachmentIndex); + void setObject(const MTL::RenderPassSampleBufferAttachmentDescriptor* attachment, NS::UInteger attachmentIndex); +}; +class RenderPassDescriptor : public NS::Copying +{ +public: + static RenderPassDescriptor* alloc(); + + RenderPassColorAttachmentDescriptorArray* colorAttachments() const; + + NS::UInteger defaultRasterSampleCount() const; + + RenderPassDepthAttachmentDescriptor* depthAttachment() const; + + NS::UInteger getSamplePositions(MTL::SamplePosition* positions, NS::UInteger count); + + NS::UInteger imageblockSampleLength() const; + + RenderPassDescriptor* init(); + + RasterizationRateMap* rasterizationRateMap() const; + + static RenderPassDescriptor* renderPassDescriptor(); + + NS::UInteger renderTargetArrayLength() const; + + NS::UInteger renderTargetHeight() const; + + NS::UInteger renderTargetWidth() const; + + RenderPassSampleBufferAttachmentDescriptorArray* sampleBufferAttachments() const; + + void setDefaultRasterSampleCount(NS::UInteger defaultRasterSampleCount); + + void setDepthAttachment(const MTL::RenderPassDepthAttachmentDescriptor* depthAttachment); + + void setImageblockSampleLength(NS::UInteger imageblockSampleLength); + + void setRasterizationRateMap(const MTL::RasterizationRateMap* rasterizationRateMap); + + void setRenderTargetArrayLength(NS::UInteger renderTargetArrayLength); + + void setRenderTargetHeight(NS::UInteger renderTargetHeight); + + void setRenderTargetWidth(NS::UInteger renderTargetWidth); + + void setSamplePositions(const MTL::SamplePosition* positions, NS::UInteger count); + + void setStencilAttachment(const MTL::RenderPassStencilAttachmentDescriptor* stencilAttachment); + + void setSupportColorAttachmentMapping(bool supportColorAttachmentMapping); + + void setThreadgroupMemoryLength(NS::UInteger threadgroupMemoryLength); + + void setTileHeight(NS::UInteger tileHeight); + + void setTileWidth(NS::UInteger tileWidth); + + void setVisibilityResultBuffer(const MTL::Buffer* visibilityResultBuffer); + + void setVisibilityResultType(MTL::VisibilityResultType visibilityResultType); + + RenderPassStencilAttachmentDescriptor* stencilAttachment() const; + + bool supportColorAttachmentMapping() const; + + NS::UInteger threadgroupMemoryLength() const; + + NS::UInteger tileHeight() const; + + NS::UInteger tileWidth() const; + + Buffer* visibilityResultBuffer() const; + + VisibilityResultType visibilityResultType() const; +}; + +} +_MTL_INLINE MTL::ClearColor::ClearColor(double red, double green, double blue, double alpha) + : red(red) + , green(green) + , blue(blue) + , alpha(alpha) +{ +} + +_MTL_INLINE MTL::ClearColor MTL::ClearColor::Make(double red, double green, double blue, double alpha) +{ + return ClearColor(red, green, blue, alpha); +} + +_MTL_INLINE MTL::RenderPassAttachmentDescriptor* MTL::RenderPassAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPassAttachmentDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassAttachmentDescriptor::depthPlane() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(depthPlane)); +} + +_MTL_INLINE MTL::RenderPassAttachmentDescriptor* MTL::RenderPassAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassAttachmentDescriptor::level() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(level)); +} + +_MTL_INLINE MTL::LoadAction MTL::RenderPassAttachmentDescriptor::loadAction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(loadAction)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassAttachmentDescriptor::resolveDepthPlane() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resolveDepthPlane)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassAttachmentDescriptor::resolveLevel() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resolveLevel)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassAttachmentDescriptor::resolveSlice() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resolveSlice)); +} + +_MTL_INLINE MTL::Texture* MTL::RenderPassAttachmentDescriptor::resolveTexture() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resolveTexture)); +} + +_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setDepthPlane(NS::UInteger depthPlane) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthPlane_), depthPlane); +} + +_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setLevel(NS::UInteger level) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLevel_), level); +} + +_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setLoadAction(MTL::LoadAction loadAction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLoadAction_), loadAction); +} + +_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setResolveDepthPlane(NS::UInteger resolveDepthPlane) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setResolveDepthPlane_), resolveDepthPlane); +} + +_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setResolveLevel(NS::UInteger resolveLevel) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setResolveLevel_), resolveLevel); +} + +_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setResolveSlice(NS::UInteger resolveSlice) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setResolveSlice_), resolveSlice); +} + +_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setResolveTexture(const MTL::Texture* resolveTexture) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setResolveTexture_), resolveTexture); +} + +_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setSlice(NS::UInteger slice) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSlice_), slice); +} + +_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setStoreAction(MTL::StoreAction storeAction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStoreAction_), storeAction); +} + +_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setStoreActionOptions(MTL::StoreActionOptions storeActionOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStoreActionOptions_), storeActionOptions); +} + +_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setTexture(const MTL::Texture* texture) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTexture_), texture); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassAttachmentDescriptor::slice() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(slice)); +} + +_MTL_INLINE MTL::StoreAction MTL::RenderPassAttachmentDescriptor::storeAction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(storeAction)); +} + +_MTL_INLINE MTL::StoreActionOptions MTL::RenderPassAttachmentDescriptor::storeActionOptions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(storeActionOptions)); +} + +_MTL_INLINE MTL::Texture* MTL::RenderPassAttachmentDescriptor::texture() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(texture)); +} + +_MTL_INLINE MTL::RenderPassColorAttachmentDescriptor* MTL::RenderPassColorAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPassColorAttachmentDescriptor)); +} + +_MTL_INLINE MTL::ClearColor MTL::RenderPassColorAttachmentDescriptor::clearColor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(clearColor)); +} + +_MTL_INLINE MTL::RenderPassColorAttachmentDescriptor* MTL::RenderPassColorAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::RenderPassColorAttachmentDescriptor::setClearColor(MTL::ClearColor clearColor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setClearColor_), clearColor); +} + +_MTL_INLINE MTL::RenderPassDepthAttachmentDescriptor* MTL::RenderPassDepthAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPassDepthAttachmentDescriptor)); +} + +_MTL_INLINE double MTL::RenderPassDepthAttachmentDescriptor::clearDepth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(clearDepth)); +} + +_MTL_INLINE MTL::MultisampleDepthResolveFilter MTL::RenderPassDepthAttachmentDescriptor::depthResolveFilter() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(depthResolveFilter)); +} + +_MTL_INLINE MTL::RenderPassDepthAttachmentDescriptor* MTL::RenderPassDepthAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::RenderPassDepthAttachmentDescriptor::setClearDepth(double clearDepth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setClearDepth_), clearDepth); +} + +_MTL_INLINE void MTL::RenderPassDepthAttachmentDescriptor::setDepthResolveFilter(MTL::MultisampleDepthResolveFilter depthResolveFilter) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthResolveFilter_), depthResolveFilter); +} + +_MTL_INLINE MTL::RenderPassStencilAttachmentDescriptor* MTL::RenderPassStencilAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPassStencilAttachmentDescriptor)); +} + +_MTL_INLINE uint32_t MTL::RenderPassStencilAttachmentDescriptor::clearStencil() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(clearStencil)); +} + +_MTL_INLINE MTL::RenderPassStencilAttachmentDescriptor* MTL::RenderPassStencilAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::RenderPassStencilAttachmentDescriptor::setClearStencil(uint32_t clearStencil) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setClearStencil_), clearStencil); +} + +_MTL_INLINE void MTL::RenderPassStencilAttachmentDescriptor::setStencilResolveFilter(MTL::MultisampleStencilResolveFilter stencilResolveFilter) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilResolveFilter_), stencilResolveFilter); +} + +_MTL_INLINE MTL::MultisampleStencilResolveFilter MTL::RenderPassStencilAttachmentDescriptor::stencilResolveFilter() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stencilResolveFilter)); +} + +_MTL_INLINE MTL::RenderPassColorAttachmentDescriptorArray* MTL::RenderPassColorAttachmentDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPassColorAttachmentDescriptorArray)); +} + +_MTL_INLINE MTL::RenderPassColorAttachmentDescriptorArray* MTL::RenderPassColorAttachmentDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::RenderPassColorAttachmentDescriptor* MTL::RenderPassColorAttachmentDescriptorArray::object(NS::UInteger attachmentIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), attachmentIndex); +} + +_MTL_INLINE void MTL::RenderPassColorAttachmentDescriptorArray::setObject(const MTL::RenderPassColorAttachmentDescriptor* attachment, NS::UInteger attachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), attachment, attachmentIndex); +} + +_MTL_INLINE MTL::RenderPassSampleBufferAttachmentDescriptor* MTL::RenderPassSampleBufferAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPassSampleBufferAttachmentDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassSampleBufferAttachmentDescriptor::endOfFragmentSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(endOfFragmentSampleIndex)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassSampleBufferAttachmentDescriptor::endOfVertexSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(endOfVertexSampleIndex)); +} + +_MTL_INLINE MTL::RenderPassSampleBufferAttachmentDescriptor* MTL::RenderPassSampleBufferAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::CounterSampleBuffer* MTL::RenderPassSampleBufferAttachmentDescriptor::sampleBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleBuffer)); +} + +_MTL_INLINE void MTL::RenderPassSampleBufferAttachmentDescriptor::setEndOfFragmentSampleIndex(NS::UInteger endOfFragmentSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setEndOfFragmentSampleIndex_), endOfFragmentSampleIndex); +} + +_MTL_INLINE void MTL::RenderPassSampleBufferAttachmentDescriptor::setEndOfVertexSampleIndex(NS::UInteger endOfVertexSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setEndOfVertexSampleIndex_), endOfVertexSampleIndex); +} + +_MTL_INLINE void MTL::RenderPassSampleBufferAttachmentDescriptor::setSampleBuffer(const MTL::CounterSampleBuffer* sampleBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSampleBuffer_), sampleBuffer); +} + +_MTL_INLINE void MTL::RenderPassSampleBufferAttachmentDescriptor::setStartOfFragmentSampleIndex(NS::UInteger startOfFragmentSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStartOfFragmentSampleIndex_), startOfFragmentSampleIndex); +} + +_MTL_INLINE void MTL::RenderPassSampleBufferAttachmentDescriptor::setStartOfVertexSampleIndex(NS::UInteger startOfVertexSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStartOfVertexSampleIndex_), startOfVertexSampleIndex); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassSampleBufferAttachmentDescriptor::startOfFragmentSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(startOfFragmentSampleIndex)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassSampleBufferAttachmentDescriptor::startOfVertexSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(startOfVertexSampleIndex)); +} + +_MTL_INLINE MTL::RenderPassSampleBufferAttachmentDescriptorArray* MTL::RenderPassSampleBufferAttachmentDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPassSampleBufferAttachmentDescriptorArray)); +} + +_MTL_INLINE MTL::RenderPassSampleBufferAttachmentDescriptorArray* MTL::RenderPassSampleBufferAttachmentDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::RenderPassSampleBufferAttachmentDescriptor* MTL::RenderPassSampleBufferAttachmentDescriptorArray::object(NS::UInteger attachmentIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), attachmentIndex); +} + +_MTL_INLINE void MTL::RenderPassSampleBufferAttachmentDescriptorArray::setObject(const MTL::RenderPassSampleBufferAttachmentDescriptor* attachment, NS::UInteger attachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), attachment, attachmentIndex); +} + +_MTL_INLINE MTL::RenderPassDescriptor* MTL::RenderPassDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPassDescriptor)); +} + +_MTL_INLINE MTL::RenderPassColorAttachmentDescriptorArray* MTL::RenderPassDescriptor::colorAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(colorAttachments)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassDescriptor::defaultRasterSampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(defaultRasterSampleCount)); +} + +_MTL_INLINE MTL::RenderPassDepthAttachmentDescriptor* MTL::RenderPassDescriptor::depthAttachment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(depthAttachment)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassDescriptor::getSamplePositions(MTL::SamplePosition* positions, NS::UInteger count) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(getSamplePositions_count_), positions, count); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassDescriptor::imageblockSampleLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(imageblockSampleLength)); +} + +_MTL_INLINE MTL::RenderPassDescriptor* MTL::RenderPassDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::RasterizationRateMap* MTL::RenderPassDescriptor::rasterizationRateMap() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rasterizationRateMap)); +} + +_MTL_INLINE MTL::RenderPassDescriptor* MTL::RenderPassDescriptor::renderPassDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLRenderPassDescriptor), _MTL_PRIVATE_SEL(renderPassDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassDescriptor::renderTargetArrayLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(renderTargetArrayLength)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassDescriptor::renderTargetHeight() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(renderTargetHeight)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassDescriptor::renderTargetWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(renderTargetWidth)); +} + +_MTL_INLINE MTL::RenderPassSampleBufferAttachmentDescriptorArray* MTL::RenderPassDescriptor::sampleBufferAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleBufferAttachments)); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setDefaultRasterSampleCount(NS::UInteger defaultRasterSampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDefaultRasterSampleCount_), defaultRasterSampleCount); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setDepthAttachment(const MTL::RenderPassDepthAttachmentDescriptor* depthAttachment) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthAttachment_), depthAttachment); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setImageblockSampleLength(NS::UInteger imageblockSampleLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setImageblockSampleLength_), imageblockSampleLength); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setRasterizationRateMap(const MTL::RasterizationRateMap* rasterizationRateMap) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterizationRateMap_), rasterizationRateMap); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setRenderTargetArrayLength(NS::UInteger renderTargetArrayLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRenderTargetArrayLength_), renderTargetArrayLength); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setRenderTargetHeight(NS::UInteger renderTargetHeight) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRenderTargetHeight_), renderTargetHeight); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setRenderTargetWidth(NS::UInteger renderTargetWidth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRenderTargetWidth_), renderTargetWidth); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setSamplePositions(const MTL::SamplePosition* positions, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSamplePositions_count_), positions, count); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setStencilAttachment(const MTL::RenderPassStencilAttachmentDescriptor* stencilAttachment) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilAttachment_), stencilAttachment); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setSupportColorAttachmentMapping(bool supportColorAttachmentMapping) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportColorAttachmentMapping_), supportColorAttachmentMapping); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setThreadgroupMemoryLength(NS::UInteger threadgroupMemoryLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadgroupMemoryLength_), threadgroupMemoryLength); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setTileHeight(NS::UInteger tileHeight) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileHeight_), tileHeight); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setTileWidth(NS::UInteger tileWidth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileWidth_), tileWidth); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setVisibilityResultBuffer(const MTL::Buffer* visibilityResultBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibilityResultBuffer_), visibilityResultBuffer); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setVisibilityResultType(MTL::VisibilityResultType visibilityResultType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibilityResultType_), visibilityResultType); +} + +_MTL_INLINE MTL::RenderPassStencilAttachmentDescriptor* MTL::RenderPassDescriptor::stencilAttachment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stencilAttachment)); +} + +_MTL_INLINE bool MTL::RenderPassDescriptor::supportColorAttachmentMapping() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportColorAttachmentMapping)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassDescriptor::threadgroupMemoryLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadgroupMemoryLength)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassDescriptor::tileHeight() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileHeight)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassDescriptor::tileWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileWidth)); +} + +_MTL_INLINE MTL::Buffer* MTL::RenderPassDescriptor::visibilityResultBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(visibilityResultBuffer)); +} + +_MTL_INLINE MTL::VisibilityResultType MTL::RenderPassDescriptor::visibilityResultType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(visibilityResultType)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLRenderPipeline.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLRenderPipeline.hpp new file mode 100644 index 00000000..aaa9cdad --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLRenderPipeline.hpp @@ -0,0 +1,1876 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLRenderPipeline.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLAllocation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPipeline.hpp" +#include "MTLPixelFormat.hpp" +#include "MTLPrivate.hpp" +#include "MTLRenderCommandEncoder.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class Device; +class Function; +class FunctionHandle; +class IntersectionFunctionTable; +class IntersectionFunctionTableDescriptor; +class LinkedFunctions; +class LogicalToPhysicalColorAttachmentMap; +class MeshRenderPipelineDescriptor; +class PipelineBufferDescriptorArray; +class RenderPipelineColorAttachmentDescriptor; +class RenderPipelineColorAttachmentDescriptorArray; +class RenderPipelineDescriptor; +class RenderPipelineFunctionsDescriptor; +class RenderPipelineReflection; +class RenderPipelineState; +class TileRenderPipelineColorAttachmentDescriptor; +class TileRenderPipelineColorAttachmentDescriptorArray; +class TileRenderPipelineDescriptor; +class VertexDescriptor; +class VisibleFunctionTable; +class VisibleFunctionTableDescriptor; + +} +namespace MTL4 +{ +class BinaryFunction; +class PipelineDescriptor; +class RenderPipelineBinaryFunctionsDescriptor; + +} +namespace MTL +{ +_MTL_ENUM(NS::UInteger, BlendFactor) { + BlendFactorZero = 0, + BlendFactorOne = 1, + BlendFactorSourceColor = 2, + BlendFactorOneMinusSourceColor = 3, + BlendFactorSourceAlpha = 4, + BlendFactorOneMinusSourceAlpha = 5, + BlendFactorDestinationColor = 6, + BlendFactorOneMinusDestinationColor = 7, + BlendFactorDestinationAlpha = 8, + BlendFactorOneMinusDestinationAlpha = 9, + BlendFactorSourceAlphaSaturated = 10, + BlendFactorBlendColor = 11, + BlendFactorOneMinusBlendColor = 12, + BlendFactorBlendAlpha = 13, + BlendFactorOneMinusBlendAlpha = 14, + BlendFactorSource1Color = 15, + BlendFactorOneMinusSource1Color = 16, + BlendFactorSource1Alpha = 17, + BlendFactorOneMinusSource1Alpha = 18, + BlendFactorUnspecialized = 19, +}; + +_MTL_ENUM(NS::UInteger, BlendOperation) { + BlendOperationAdd = 0, + BlendOperationSubtract = 1, + BlendOperationReverseSubtract = 2, + BlendOperationMin = 3, + BlendOperationMax = 4, + BlendOperationUnspecialized = 5, +}; + +_MTL_ENUM(NS::UInteger, PrimitiveTopologyClass) { + PrimitiveTopologyClassUnspecified = 0, + PrimitiveTopologyClassPoint = 1, + PrimitiveTopologyClassLine = 2, + PrimitiveTopologyClassTriangle = 3, +}; + +_MTL_ENUM(NS::UInteger, TessellationPartitionMode) { + TessellationPartitionModePow2 = 0, + TessellationPartitionModeInteger = 1, + TessellationPartitionModeFractionalOdd = 2, + TessellationPartitionModeFractionalEven = 3, +}; + +_MTL_ENUM(NS::UInteger, TessellationFactorStepFunction) { + TessellationFactorStepFunctionConstant = 0, + TessellationFactorStepFunctionPerPatch = 1, + TessellationFactorStepFunctionPerInstance = 2, + TessellationFactorStepFunctionPerPatchAndPerInstance = 3, +}; + +_MTL_ENUM(NS::UInteger, TessellationFactorFormat) { + TessellationFactorFormatHalf = 0, +}; + +_MTL_ENUM(NS::UInteger, TessellationControlPointIndexType) { + TessellationControlPointIndexTypeNone = 0, + TessellationControlPointIndexTypeUInt16 = 1, + TessellationControlPointIndexTypeUInt32 = 2, +}; + +_MTL_OPTIONS(NS::UInteger, ColorWriteMask) { + ColorWriteMaskNone = 0, + ColorWriteMaskRed = 1 << 3, + ColorWriteMaskGreen = 1 << 2, + ColorWriteMaskBlue = 1 << 1, + ColorWriteMaskAlpha = 1, + ColorWriteMaskAll = 15, + ColorWriteMaskUnspecialized = 1 << 4, +}; + +class RenderPipelineColorAttachmentDescriptor : public NS::Copying +{ +public: + static RenderPipelineColorAttachmentDescriptor* alloc(); + + BlendOperation alphaBlendOperation() const; + + [[deprecated("please use isBlendingEnabled instead")]] + bool blendingEnabled() const; + + BlendFactor destinationAlphaBlendFactor() const; + + BlendFactor destinationRGBBlendFactor() const; + + RenderPipelineColorAttachmentDescriptor* init(); + + bool isBlendingEnabled() const; + + PixelFormat pixelFormat() const; + + BlendOperation rgbBlendOperation() const; + + void setAlphaBlendOperation(MTL::BlendOperation alphaBlendOperation); + + void setBlendingEnabled(bool blendingEnabled); + + void setDestinationAlphaBlendFactor(MTL::BlendFactor destinationAlphaBlendFactor); + + void setDestinationRGBBlendFactor(MTL::BlendFactor destinationRGBBlendFactor); + + void setPixelFormat(MTL::PixelFormat pixelFormat); + + void setRgbBlendOperation(MTL::BlendOperation rgbBlendOperation); + + void setSourceAlphaBlendFactor(MTL::BlendFactor sourceAlphaBlendFactor); + + void setSourceRGBBlendFactor(MTL::BlendFactor sourceRGBBlendFactor); + + void setWriteMask(MTL::ColorWriteMask writeMask); + + BlendFactor sourceAlphaBlendFactor() const; + + BlendFactor sourceRGBBlendFactor() const; + + ColorWriteMask writeMask() const; +}; +class LogicalToPhysicalColorAttachmentMap : public NS::Copying +{ +public: + static LogicalToPhysicalColorAttachmentMap* alloc(); + + NS::UInteger getPhysicalIndex(NS::UInteger logicalIndex); + + LogicalToPhysicalColorAttachmentMap* init(); + + void reset(); + + void setPhysicalIndex(NS::UInteger physicalIndex, NS::UInteger logicalIndex); +}; +class RenderPipelineReflection : public NS::Referencing +{ +public: + static RenderPipelineReflection* alloc(); + + NS::Array* fragmentArguments() const; + + NS::Array* fragmentBindings() const; + + RenderPipelineReflection* init(); + + NS::Array* meshBindings() const; + + NS::Array* objectBindings() const; + + NS::Array* tileArguments() const; + + NS::Array* tileBindings() const; + + NS::Array* vertexArguments() const; + + NS::Array* vertexBindings() const; +}; +class RenderPipelineDescriptor : public NS::Copying +{ +public: + static RenderPipelineDescriptor* alloc(); + + [[deprecated("please use isAlphaToCoverageEnabled instead")]] + bool alphaToCoverageEnabled() const; + + [[deprecated("please use isAlphaToOneEnabled instead")]] + bool alphaToOneEnabled() const; + + NS::Array* binaryArchives() const; + + RenderPipelineColorAttachmentDescriptorArray* colorAttachments() const; + + PixelFormat depthAttachmentPixelFormat() const; + + PipelineBufferDescriptorArray* fragmentBuffers() const; + + Function* fragmentFunction() const; + + LinkedFunctions* fragmentLinkedFunctions() const; + + NS::Array* fragmentPreloadedLibraries() const; + + RenderPipelineDescriptor* init(); + + PrimitiveTopologyClass inputPrimitiveTopology() const; + + bool isAlphaToCoverageEnabled() const; + + bool isAlphaToOneEnabled() const; + + bool isRasterizationEnabled() const; + + bool isTessellationFactorScaleEnabled() const; + + NS::String* label() const; + + NS::UInteger maxFragmentCallStackDepth() const; + + NS::UInteger maxTessellationFactor() const; + + NS::UInteger maxVertexAmplificationCount() const; + + NS::UInteger maxVertexCallStackDepth() const; + + NS::UInteger rasterSampleCount() const; + + [[deprecated("please use isRasterizationEnabled instead")]] + bool rasterizationEnabled() const; + + void reset(); + + NS::UInteger sampleCount() const; + + void setAlphaToCoverageEnabled(bool alphaToCoverageEnabled); + + void setAlphaToOneEnabled(bool alphaToOneEnabled); + + void setBinaryArchives(const NS::Array* binaryArchives); + + void setDepthAttachmentPixelFormat(MTL::PixelFormat depthAttachmentPixelFormat); + + void setFragmentFunction(const MTL::Function* fragmentFunction); + + void setFragmentLinkedFunctions(const MTL::LinkedFunctions* fragmentLinkedFunctions); + + void setFragmentPreloadedLibraries(const NS::Array* fragmentPreloadedLibraries); + + void setInputPrimitiveTopology(MTL::PrimitiveTopologyClass inputPrimitiveTopology); + + void setLabel(const NS::String* label); + + void setMaxFragmentCallStackDepth(NS::UInteger maxFragmentCallStackDepth); + + void setMaxTessellationFactor(NS::UInteger maxTessellationFactor); + + void setMaxVertexAmplificationCount(NS::UInteger maxVertexAmplificationCount); + + void setMaxVertexCallStackDepth(NS::UInteger maxVertexCallStackDepth); + + void setRasterSampleCount(NS::UInteger rasterSampleCount); + + void setRasterizationEnabled(bool rasterizationEnabled); + + void setSampleCount(NS::UInteger sampleCount); + + void setShaderValidation(MTL::ShaderValidation shaderValidation); + + void setStencilAttachmentPixelFormat(MTL::PixelFormat stencilAttachmentPixelFormat); + + void setSupportAddingFragmentBinaryFunctions(bool supportAddingFragmentBinaryFunctions); + + void setSupportAddingVertexBinaryFunctions(bool supportAddingVertexBinaryFunctions); + + void setSupportIndirectCommandBuffers(bool supportIndirectCommandBuffers); + + void setTessellationControlPointIndexType(MTL::TessellationControlPointIndexType tessellationControlPointIndexType); + + void setTessellationFactorFormat(MTL::TessellationFactorFormat tessellationFactorFormat); + + void setTessellationFactorScaleEnabled(bool tessellationFactorScaleEnabled); + + void setTessellationFactorStepFunction(MTL::TessellationFactorStepFunction tessellationFactorStepFunction); + + void setTessellationOutputWindingOrder(MTL::Winding tessellationOutputWindingOrder); + + void setTessellationPartitionMode(MTL::TessellationPartitionMode tessellationPartitionMode); + + void setVertexDescriptor(const MTL::VertexDescriptor* vertexDescriptor); + + void setVertexFunction(const MTL::Function* vertexFunction); + + void setVertexLinkedFunctions(const MTL::LinkedFunctions* vertexLinkedFunctions); + + void setVertexPreloadedLibraries(const NS::Array* vertexPreloadedLibraries); + + ShaderValidation shaderValidation() const; + + PixelFormat stencilAttachmentPixelFormat() const; + + bool supportAddingFragmentBinaryFunctions() const; + + bool supportAddingVertexBinaryFunctions() const; + + bool supportIndirectCommandBuffers() const; + + TessellationControlPointIndexType tessellationControlPointIndexType() const; + + TessellationFactorFormat tessellationFactorFormat() const; + + [[deprecated("please use isTessellationFactorScaleEnabled instead")]] + bool tessellationFactorScaleEnabled() const; + + TessellationFactorStepFunction tessellationFactorStepFunction() const; + + Winding tessellationOutputWindingOrder() const; + + TessellationPartitionMode tessellationPartitionMode() const; + + PipelineBufferDescriptorArray* vertexBuffers() const; + + VertexDescriptor* vertexDescriptor() const; + + Function* vertexFunction() const; + + LinkedFunctions* vertexLinkedFunctions() const; + + NS::Array* vertexPreloadedLibraries() const; +}; +class RenderPipelineFunctionsDescriptor : public NS::Copying +{ +public: + static RenderPipelineFunctionsDescriptor* alloc(); + + NS::Array* fragmentAdditionalBinaryFunctions() const; + + RenderPipelineFunctionsDescriptor* init(); + + void setFragmentAdditionalBinaryFunctions(const NS::Array* fragmentAdditionalBinaryFunctions); + + void setTileAdditionalBinaryFunctions(const NS::Array* tileAdditionalBinaryFunctions); + + void setVertexAdditionalBinaryFunctions(const NS::Array* vertexAdditionalBinaryFunctions); + + NS::Array* tileAdditionalBinaryFunctions() const; + + NS::Array* vertexAdditionalBinaryFunctions() const; +}; +class RenderPipelineState : public NS::Referencing +{ +public: + Device* device() const; + + FunctionHandle* functionHandle(const NS::String* name, MTL::RenderStages stage); + FunctionHandle* functionHandle(const MTL4::BinaryFunction* function, MTL::RenderStages stage); + FunctionHandle* functionHandle(const MTL::Function* function, MTL::RenderStages stage); + + ResourceID gpuResourceID() const; + + NS::UInteger imageblockMemoryLength(MTL::Size imageblockDimensions); + + NS::UInteger imageblockSampleLength() const; + + NS::String* label() const; + + NS::UInteger maxTotalThreadgroupsPerMeshGrid() const; + + NS::UInteger maxTotalThreadsPerMeshThreadgroup() const; + + NS::UInteger maxTotalThreadsPerObjectThreadgroup() const; + + NS::UInteger maxTotalThreadsPerThreadgroup() const; + + NS::UInteger meshThreadExecutionWidth() const; + + IntersectionFunctionTable* newIntersectionFunctionTable(const MTL::IntersectionFunctionTableDescriptor* descriptor, MTL::RenderStages stage); + + MTL4::PipelineDescriptor* newRenderPipelineDescriptor(); + + RenderPipelineState* newRenderPipelineState(const MTL4::RenderPipelineBinaryFunctionsDescriptor* binaryFunctionsDescriptor, NS::Error** error); + RenderPipelineState* newRenderPipelineState(const MTL::RenderPipelineFunctionsDescriptor* additionalBinaryFunctions, NS::Error** error); + + VisibleFunctionTable* newVisibleFunctionTable(const MTL::VisibleFunctionTableDescriptor* descriptor, MTL::RenderStages stage); + + NS::UInteger objectThreadExecutionWidth() const; + + RenderPipelineReflection* reflection() const; + + Size requiredThreadsPerMeshThreadgroup() const; + + Size requiredThreadsPerObjectThreadgroup() const; + + Size requiredThreadsPerTileThreadgroup() const; + + ShaderValidation shaderValidation() const; + + bool supportIndirectCommandBuffers() const; + + bool threadgroupSizeMatchesTileSize() const; +}; +class RenderPipelineColorAttachmentDescriptorArray : public NS::Referencing +{ +public: + static RenderPipelineColorAttachmentDescriptorArray* alloc(); + + RenderPipelineColorAttachmentDescriptorArray* init(); + + RenderPipelineColorAttachmentDescriptor* object(NS::UInteger attachmentIndex); + void setObject(const MTL::RenderPipelineColorAttachmentDescriptor* attachment, NS::UInteger attachmentIndex); +}; +class TileRenderPipelineColorAttachmentDescriptor : public NS::Copying +{ +public: + static TileRenderPipelineColorAttachmentDescriptor* alloc(); + + TileRenderPipelineColorAttachmentDescriptor* init(); + + PixelFormat pixelFormat() const; + void setPixelFormat(MTL::PixelFormat pixelFormat); +}; +class TileRenderPipelineColorAttachmentDescriptorArray : public NS::Referencing +{ +public: + static TileRenderPipelineColorAttachmentDescriptorArray* alloc(); + + TileRenderPipelineColorAttachmentDescriptorArray* init(); + + TileRenderPipelineColorAttachmentDescriptor* object(NS::UInteger attachmentIndex); + void setObject(const MTL::TileRenderPipelineColorAttachmentDescriptor* attachment, NS::UInteger attachmentIndex); +}; +class TileRenderPipelineDescriptor : public NS::Copying +{ +public: + static TileRenderPipelineDescriptor* alloc(); + + NS::Array* binaryArchives() const; + + TileRenderPipelineColorAttachmentDescriptorArray* colorAttachments() const; + + TileRenderPipelineDescriptor* init(); + + NS::String* label() const; + + LinkedFunctions* linkedFunctions() const; + + NS::UInteger maxCallStackDepth() const; + + NS::UInteger maxTotalThreadsPerThreadgroup() const; + + NS::Array* preloadedLibraries() const; + + NS::UInteger rasterSampleCount() const; + + Size requiredThreadsPerThreadgroup() const; + + void reset(); + + void setBinaryArchives(const NS::Array* binaryArchives); + + void setLabel(const NS::String* label); + + void setLinkedFunctions(const MTL::LinkedFunctions* linkedFunctions); + + void setMaxCallStackDepth(NS::UInteger maxCallStackDepth); + + void setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup); + + void setPreloadedLibraries(const NS::Array* preloadedLibraries); + + void setRasterSampleCount(NS::UInteger rasterSampleCount); + + void setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup); + + void setShaderValidation(MTL::ShaderValidation shaderValidation); + + void setSupportAddingBinaryFunctions(bool supportAddingBinaryFunctions); + + void setThreadgroupSizeMatchesTileSize(bool threadgroupSizeMatchesTileSize); + + void setTileFunction(const MTL::Function* tileFunction); + + ShaderValidation shaderValidation() const; + + bool supportAddingBinaryFunctions() const; + + bool threadgroupSizeMatchesTileSize() const; + + PipelineBufferDescriptorArray* tileBuffers() const; + + Function* tileFunction() const; +}; +class MeshRenderPipelineDescriptor : public NS::Copying +{ +public: + static MeshRenderPipelineDescriptor* alloc(); + + [[deprecated("please use isAlphaToCoverageEnabled instead")]] + bool alphaToCoverageEnabled() const; + + [[deprecated("please use isAlphaToOneEnabled instead")]] + bool alphaToOneEnabled() const; + + NS::Array* binaryArchives() const; + + RenderPipelineColorAttachmentDescriptorArray* colorAttachments() const; + + PixelFormat depthAttachmentPixelFormat() const; + + PipelineBufferDescriptorArray* fragmentBuffers() const; + + Function* fragmentFunction() const; + + LinkedFunctions* fragmentLinkedFunctions() const; + + MeshRenderPipelineDescriptor* init(); + + bool isAlphaToCoverageEnabled() const; + + bool isAlphaToOneEnabled() const; + + bool isRasterizationEnabled() const; + + NS::String* label() const; + + NS::UInteger maxTotalThreadgroupsPerMeshGrid() const; + + NS::UInteger maxTotalThreadsPerMeshThreadgroup() const; + + NS::UInteger maxTotalThreadsPerObjectThreadgroup() const; + + NS::UInteger maxVertexAmplificationCount() const; + + PipelineBufferDescriptorArray* meshBuffers() const; + + Function* meshFunction() const; + + LinkedFunctions* meshLinkedFunctions() const; + + bool meshThreadgroupSizeIsMultipleOfThreadExecutionWidth() const; + + PipelineBufferDescriptorArray* objectBuffers() const; + + Function* objectFunction() const; + + LinkedFunctions* objectLinkedFunctions() const; + + bool objectThreadgroupSizeIsMultipleOfThreadExecutionWidth() const; + + NS::UInteger payloadMemoryLength() const; + + NS::UInteger rasterSampleCount() const; + + [[deprecated("please use isRasterizationEnabled instead")]] + bool rasterizationEnabled() const; + + Size requiredThreadsPerMeshThreadgroup() const; + + Size requiredThreadsPerObjectThreadgroup() const; + + void reset(); + + void setAlphaToCoverageEnabled(bool alphaToCoverageEnabled); + + void setAlphaToOneEnabled(bool alphaToOneEnabled); + + void setBinaryArchives(const NS::Array* binaryArchives); + + void setDepthAttachmentPixelFormat(MTL::PixelFormat depthAttachmentPixelFormat); + + void setFragmentFunction(const MTL::Function* fragmentFunction); + + void setFragmentLinkedFunctions(const MTL::LinkedFunctions* fragmentLinkedFunctions); + + void setLabel(const NS::String* label); + + void setMaxTotalThreadgroupsPerMeshGrid(NS::UInteger maxTotalThreadgroupsPerMeshGrid); + + void setMaxTotalThreadsPerMeshThreadgroup(NS::UInteger maxTotalThreadsPerMeshThreadgroup); + + void setMaxTotalThreadsPerObjectThreadgroup(NS::UInteger maxTotalThreadsPerObjectThreadgroup); + + void setMaxVertexAmplificationCount(NS::UInteger maxVertexAmplificationCount); + + void setMeshFunction(const MTL::Function* meshFunction); + + void setMeshLinkedFunctions(const MTL::LinkedFunctions* meshLinkedFunctions); + + void setMeshThreadgroupSizeIsMultipleOfThreadExecutionWidth(bool meshThreadgroupSizeIsMultipleOfThreadExecutionWidth); + + void setObjectFunction(const MTL::Function* objectFunction); + + void setObjectLinkedFunctions(const MTL::LinkedFunctions* objectLinkedFunctions); + + void setObjectThreadgroupSizeIsMultipleOfThreadExecutionWidth(bool objectThreadgroupSizeIsMultipleOfThreadExecutionWidth); + + void setPayloadMemoryLength(NS::UInteger payloadMemoryLength); + + void setRasterSampleCount(NS::UInteger rasterSampleCount); + + void setRasterizationEnabled(bool rasterizationEnabled); + + void setRequiredThreadsPerMeshThreadgroup(MTL::Size requiredThreadsPerMeshThreadgroup); + + void setRequiredThreadsPerObjectThreadgroup(MTL::Size requiredThreadsPerObjectThreadgroup); + + void setShaderValidation(MTL::ShaderValidation shaderValidation); + + void setStencilAttachmentPixelFormat(MTL::PixelFormat stencilAttachmentPixelFormat); + + void setSupportIndirectCommandBuffers(bool supportIndirectCommandBuffers); + + ShaderValidation shaderValidation() const; + + PixelFormat stencilAttachmentPixelFormat() const; + + bool supportIndirectCommandBuffers() const; +}; + +} +_MTL_INLINE MTL::RenderPipelineColorAttachmentDescriptor* MTL::RenderPipelineColorAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPipelineColorAttachmentDescriptor)); +} + +_MTL_INLINE MTL::BlendOperation MTL::RenderPipelineColorAttachmentDescriptor::alphaBlendOperation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(alphaBlendOperation)); +} + +_MTL_INLINE bool MTL::RenderPipelineColorAttachmentDescriptor::blendingEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isBlendingEnabled)); +} + +_MTL_INLINE MTL::BlendFactor MTL::RenderPipelineColorAttachmentDescriptor::destinationAlphaBlendFactor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(destinationAlphaBlendFactor)); +} + +_MTL_INLINE MTL::BlendFactor MTL::RenderPipelineColorAttachmentDescriptor::destinationRGBBlendFactor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(destinationRGBBlendFactor)); +} + +_MTL_INLINE MTL::RenderPipelineColorAttachmentDescriptor* MTL::RenderPipelineColorAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE bool MTL::RenderPipelineColorAttachmentDescriptor::isBlendingEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isBlendingEnabled)); +} + +_MTL_INLINE MTL::PixelFormat MTL::RenderPipelineColorAttachmentDescriptor::pixelFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(pixelFormat)); +} + +_MTL_INLINE MTL::BlendOperation MTL::RenderPipelineColorAttachmentDescriptor::rgbBlendOperation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rgbBlendOperation)); +} + +_MTL_INLINE void MTL::RenderPipelineColorAttachmentDescriptor::setAlphaBlendOperation(MTL::BlendOperation alphaBlendOperation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAlphaBlendOperation_), alphaBlendOperation); +} + +_MTL_INLINE void MTL::RenderPipelineColorAttachmentDescriptor::setBlendingEnabled(bool blendingEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBlendingEnabled_), blendingEnabled); +} + +_MTL_INLINE void MTL::RenderPipelineColorAttachmentDescriptor::setDestinationAlphaBlendFactor(MTL::BlendFactor destinationAlphaBlendFactor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDestinationAlphaBlendFactor_), destinationAlphaBlendFactor); +} + +_MTL_INLINE void MTL::RenderPipelineColorAttachmentDescriptor::setDestinationRGBBlendFactor(MTL::BlendFactor destinationRGBBlendFactor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDestinationRGBBlendFactor_), destinationRGBBlendFactor); +} + +_MTL_INLINE void MTL::RenderPipelineColorAttachmentDescriptor::setPixelFormat(MTL::PixelFormat pixelFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPixelFormat_), pixelFormat); +} + +_MTL_INLINE void MTL::RenderPipelineColorAttachmentDescriptor::setRgbBlendOperation(MTL::BlendOperation rgbBlendOperation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRgbBlendOperation_), rgbBlendOperation); +} + +_MTL_INLINE void MTL::RenderPipelineColorAttachmentDescriptor::setSourceAlphaBlendFactor(MTL::BlendFactor sourceAlphaBlendFactor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSourceAlphaBlendFactor_), sourceAlphaBlendFactor); +} + +_MTL_INLINE void MTL::RenderPipelineColorAttachmentDescriptor::setSourceRGBBlendFactor(MTL::BlendFactor sourceRGBBlendFactor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSourceRGBBlendFactor_), sourceRGBBlendFactor); +} + +_MTL_INLINE void MTL::RenderPipelineColorAttachmentDescriptor::setWriteMask(MTL::ColorWriteMask writeMask) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setWriteMask_), writeMask); +} + +_MTL_INLINE MTL::BlendFactor MTL::RenderPipelineColorAttachmentDescriptor::sourceAlphaBlendFactor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sourceAlphaBlendFactor)); +} + +_MTL_INLINE MTL::BlendFactor MTL::RenderPipelineColorAttachmentDescriptor::sourceRGBBlendFactor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sourceRGBBlendFactor)); +} + +_MTL_INLINE MTL::ColorWriteMask MTL::RenderPipelineColorAttachmentDescriptor::writeMask() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(writeMask)); +} + +_MTL_INLINE MTL::LogicalToPhysicalColorAttachmentMap* MTL::LogicalToPhysicalColorAttachmentMap::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLLogicalToPhysicalColorAttachmentMap)); +} + +_MTL_INLINE NS::UInteger MTL::LogicalToPhysicalColorAttachmentMap::getPhysicalIndex(NS::UInteger logicalIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(getPhysicalIndexForLogicalIndex_), logicalIndex); +} + +_MTL_INLINE MTL::LogicalToPhysicalColorAttachmentMap* MTL::LogicalToPhysicalColorAttachmentMap::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::LogicalToPhysicalColorAttachmentMap::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL::LogicalToPhysicalColorAttachmentMap::setPhysicalIndex(NS::UInteger physicalIndex, NS::UInteger logicalIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPhysicalIndex_forLogicalIndex_), physicalIndex, logicalIndex); +} + +_MTL_INLINE MTL::RenderPipelineReflection* MTL::RenderPipelineReflection::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPipelineReflection)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineReflection::fragmentArguments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentArguments)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineReflection::fragmentBindings() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentBindings)); +} + +_MTL_INLINE MTL::RenderPipelineReflection* MTL::RenderPipelineReflection::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineReflection::meshBindings() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshBindings)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineReflection::objectBindings() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectBindings)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineReflection::tileArguments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileArguments)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineReflection::tileBindings() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileBindings)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineReflection::vertexArguments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexArguments)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineReflection::vertexBindings() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexBindings)); +} + +_MTL_INLINE MTL::RenderPipelineDescriptor* MTL::RenderPipelineDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPipelineDescriptor)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::alphaToCoverageEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isAlphaToCoverageEnabled)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::alphaToOneEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isAlphaToOneEnabled)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineDescriptor::binaryArchives() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(binaryArchives)); +} + +_MTL_INLINE MTL::RenderPipelineColorAttachmentDescriptorArray* MTL::RenderPipelineDescriptor::colorAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(colorAttachments)); +} + +_MTL_INLINE MTL::PixelFormat MTL::RenderPipelineDescriptor::depthAttachmentPixelFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(depthAttachmentPixelFormat)); +} + +_MTL_INLINE MTL::PipelineBufferDescriptorArray* MTL::RenderPipelineDescriptor::fragmentBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentBuffers)); +} + +_MTL_INLINE MTL::Function* MTL::RenderPipelineDescriptor::fragmentFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentFunction)); +} + +_MTL_INLINE MTL::LinkedFunctions* MTL::RenderPipelineDescriptor::fragmentLinkedFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentLinkedFunctions)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineDescriptor::fragmentPreloadedLibraries() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentPreloadedLibraries)); +} + +_MTL_INLINE MTL::RenderPipelineDescriptor* MTL::RenderPipelineDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::PrimitiveTopologyClass MTL::RenderPipelineDescriptor::inputPrimitiveTopology() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inputPrimitiveTopology)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::isAlphaToCoverageEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isAlphaToCoverageEnabled)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::isAlphaToOneEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isAlphaToOneEnabled)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::isRasterizationEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isRasterizationEnabled)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::isTessellationFactorScaleEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isTessellationFactorScaleEnabled)); +} + +_MTL_INLINE NS::String* MTL::RenderPipelineDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineDescriptor::maxFragmentCallStackDepth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxFragmentCallStackDepth)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineDescriptor::maxTessellationFactor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTessellationFactor)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineDescriptor::maxVertexAmplificationCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxVertexAmplificationCount)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineDescriptor::maxVertexCallStackDepth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxVertexCallStackDepth)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineDescriptor::rasterSampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rasterSampleCount)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::rasterizationEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isRasterizationEnabled)); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineDescriptor::sampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleCount)); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setAlphaToCoverageEnabled(bool alphaToCoverageEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAlphaToCoverageEnabled_), alphaToCoverageEnabled); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setAlphaToOneEnabled(bool alphaToOneEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAlphaToOneEnabled_), alphaToOneEnabled); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setBinaryArchives(const NS::Array* binaryArchives) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBinaryArchives_), binaryArchives); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setDepthAttachmentPixelFormat(MTL::PixelFormat depthAttachmentPixelFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthAttachmentPixelFormat_), depthAttachmentPixelFormat); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setFragmentFunction(const MTL::Function* fragmentFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentFunction_), fragmentFunction); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setFragmentLinkedFunctions(const MTL::LinkedFunctions* fragmentLinkedFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentLinkedFunctions_), fragmentLinkedFunctions); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setFragmentPreloadedLibraries(const NS::Array* fragmentPreloadedLibraries) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentPreloadedLibraries_), fragmentPreloadedLibraries); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setInputPrimitiveTopology(MTL::PrimitiveTopologyClass inputPrimitiveTopology) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInputPrimitiveTopology_), inputPrimitiveTopology); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setMaxFragmentCallStackDepth(NS::UInteger maxFragmentCallStackDepth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxFragmentCallStackDepth_), maxFragmentCallStackDepth); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setMaxTessellationFactor(NS::UInteger maxTessellationFactor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTessellationFactor_), maxTessellationFactor); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setMaxVertexAmplificationCount(NS::UInteger maxVertexAmplificationCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxVertexAmplificationCount_), maxVertexAmplificationCount); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setMaxVertexCallStackDepth(NS::UInteger maxVertexCallStackDepth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxVertexCallStackDepth_), maxVertexCallStackDepth); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setRasterSampleCount(NS::UInteger rasterSampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterSampleCount_), rasterSampleCount); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setRasterizationEnabled(bool rasterizationEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterizationEnabled_), rasterizationEnabled); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setSampleCount(NS::UInteger sampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSampleCount_), sampleCount); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setShaderValidation(MTL::ShaderValidation shaderValidation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setShaderValidation_), shaderValidation); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setStencilAttachmentPixelFormat(MTL::PixelFormat stencilAttachmentPixelFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilAttachmentPixelFormat_), stencilAttachmentPixelFormat); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setSupportAddingFragmentBinaryFunctions(bool supportAddingFragmentBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportAddingFragmentBinaryFunctions_), supportAddingFragmentBinaryFunctions); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setSupportAddingVertexBinaryFunctions(bool supportAddingVertexBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportAddingVertexBinaryFunctions_), supportAddingVertexBinaryFunctions); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setSupportIndirectCommandBuffers(bool supportIndirectCommandBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportIndirectCommandBuffers_), supportIndirectCommandBuffers); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setTessellationControlPointIndexType(MTL::TessellationControlPointIndexType tessellationControlPointIndexType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTessellationControlPointIndexType_), tessellationControlPointIndexType); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setTessellationFactorFormat(MTL::TessellationFactorFormat tessellationFactorFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTessellationFactorFormat_), tessellationFactorFormat); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setTessellationFactorScaleEnabled(bool tessellationFactorScaleEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTessellationFactorScaleEnabled_), tessellationFactorScaleEnabled); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setTessellationFactorStepFunction(MTL::TessellationFactorStepFunction tessellationFactorStepFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTessellationFactorStepFunction_), tessellationFactorStepFunction); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setTessellationOutputWindingOrder(MTL::Winding tessellationOutputWindingOrder) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTessellationOutputWindingOrder_), tessellationOutputWindingOrder); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setTessellationPartitionMode(MTL::TessellationPartitionMode tessellationPartitionMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTessellationPartitionMode_), tessellationPartitionMode); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setVertexDescriptor(const MTL::VertexDescriptor* vertexDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexDescriptor_), vertexDescriptor); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setVertexFunction(const MTL::Function* vertexFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexFunction_), vertexFunction); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setVertexLinkedFunctions(const MTL::LinkedFunctions* vertexLinkedFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexLinkedFunctions_), vertexLinkedFunctions); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setVertexPreloadedLibraries(const NS::Array* vertexPreloadedLibraries) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexPreloadedLibraries_), vertexPreloadedLibraries); +} + +_MTL_INLINE MTL::ShaderValidation MTL::RenderPipelineDescriptor::shaderValidation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderValidation)); +} + +_MTL_INLINE MTL::PixelFormat MTL::RenderPipelineDescriptor::stencilAttachmentPixelFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stencilAttachmentPixelFormat)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::supportAddingFragmentBinaryFunctions() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportAddingFragmentBinaryFunctions)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::supportAddingVertexBinaryFunctions() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportAddingVertexBinaryFunctions)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::supportIndirectCommandBuffers() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportIndirectCommandBuffers)); +} + +_MTL_INLINE MTL::TessellationControlPointIndexType MTL::RenderPipelineDescriptor::tessellationControlPointIndexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tessellationControlPointIndexType)); +} + +_MTL_INLINE MTL::TessellationFactorFormat MTL::RenderPipelineDescriptor::tessellationFactorFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tessellationFactorFormat)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::tessellationFactorScaleEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isTessellationFactorScaleEnabled)); +} + +_MTL_INLINE MTL::TessellationFactorStepFunction MTL::RenderPipelineDescriptor::tessellationFactorStepFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tessellationFactorStepFunction)); +} + +_MTL_INLINE MTL::Winding MTL::RenderPipelineDescriptor::tessellationOutputWindingOrder() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tessellationOutputWindingOrder)); +} + +_MTL_INLINE MTL::TessellationPartitionMode MTL::RenderPipelineDescriptor::tessellationPartitionMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tessellationPartitionMode)); +} + +_MTL_INLINE MTL::PipelineBufferDescriptorArray* MTL::RenderPipelineDescriptor::vertexBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexBuffers)); +} + +_MTL_INLINE MTL::VertexDescriptor* MTL::RenderPipelineDescriptor::vertexDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexDescriptor)); +} + +_MTL_INLINE MTL::Function* MTL::RenderPipelineDescriptor::vertexFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexFunction)); +} + +_MTL_INLINE MTL::LinkedFunctions* MTL::RenderPipelineDescriptor::vertexLinkedFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexLinkedFunctions)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineDescriptor::vertexPreloadedLibraries() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexPreloadedLibraries)); +} + +_MTL_INLINE MTL::RenderPipelineFunctionsDescriptor* MTL::RenderPipelineFunctionsDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPipelineFunctionsDescriptor)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineFunctionsDescriptor::fragmentAdditionalBinaryFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentAdditionalBinaryFunctions)); +} + +_MTL_INLINE MTL::RenderPipelineFunctionsDescriptor* MTL::RenderPipelineFunctionsDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::RenderPipelineFunctionsDescriptor::setFragmentAdditionalBinaryFunctions(const NS::Array* fragmentAdditionalBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentAdditionalBinaryFunctions_), fragmentAdditionalBinaryFunctions); +} + +_MTL_INLINE void MTL::RenderPipelineFunctionsDescriptor::setTileAdditionalBinaryFunctions(const NS::Array* tileAdditionalBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileAdditionalBinaryFunctions_), tileAdditionalBinaryFunctions); +} + +_MTL_INLINE void MTL::RenderPipelineFunctionsDescriptor::setVertexAdditionalBinaryFunctions(const NS::Array* vertexAdditionalBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexAdditionalBinaryFunctions_), vertexAdditionalBinaryFunctions); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineFunctionsDescriptor::tileAdditionalBinaryFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileAdditionalBinaryFunctions)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineFunctionsDescriptor::vertexAdditionalBinaryFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexAdditionalBinaryFunctions)); +} + +_MTL_INLINE MTL::Device* MTL::RenderPipelineState::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE MTL::FunctionHandle* MTL::RenderPipelineState::functionHandle(const NS::String* name, MTL::RenderStages stage) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionHandleWithName_stage_), name, stage); +} + +_MTL_INLINE MTL::FunctionHandle* MTL::RenderPipelineState::functionHandle(const MTL4::BinaryFunction* function, MTL::RenderStages stage) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionHandleWithBinaryFunction_stage_), function, stage); +} + +_MTL_INLINE MTL::FunctionHandle* MTL::RenderPipelineState::functionHandle(const MTL::Function* function, MTL::RenderStages stage) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionHandleWithFunction_stage_), function, stage); +} + +_MTL_INLINE MTL::ResourceID MTL::RenderPipelineState::gpuResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuResourceID)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineState::imageblockMemoryLength(MTL::Size imageblockDimensions) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(imageblockMemoryLengthForDimensions_), imageblockDimensions); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineState::imageblockSampleLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(imageblockSampleLength)); +} + +_MTL_INLINE NS::String* MTL::RenderPipelineState::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineState::maxTotalThreadgroupsPerMeshGrid() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadgroupsPerMeshGrid)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineState::maxTotalThreadsPerMeshThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerMeshThreadgroup)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineState::maxTotalThreadsPerObjectThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerObjectThreadgroup)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineState::maxTotalThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerThreadgroup)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineState::meshThreadExecutionWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshThreadExecutionWidth)); +} + +_MTL_INLINE MTL::IntersectionFunctionTable* MTL::RenderPipelineState::newIntersectionFunctionTable(const MTL::IntersectionFunctionTableDescriptor* descriptor, MTL::RenderStages stage) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newIntersectionFunctionTableWithDescriptor_stage_), descriptor, stage); +} + +_MTL_INLINE MTL4::PipelineDescriptor* MTL::RenderPipelineState::newRenderPipelineDescriptor() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineDescriptorForSpecialization)); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL::RenderPipelineState::newRenderPipelineState(const MTL4::RenderPipelineBinaryFunctionsDescriptor* binaryFunctionsDescriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithBinaryFunctions_error_), binaryFunctionsDescriptor, error); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL::RenderPipelineState::newRenderPipelineState(const MTL::RenderPipelineFunctionsDescriptor* additionalBinaryFunctions, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithAdditionalBinaryFunctions_error_), additionalBinaryFunctions, error); +} + +_MTL_INLINE MTL::VisibleFunctionTable* MTL::RenderPipelineState::newVisibleFunctionTable(const MTL::VisibleFunctionTableDescriptor* descriptor, MTL::RenderStages stage) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newVisibleFunctionTableWithDescriptor_stage_), descriptor, stage); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineState::objectThreadExecutionWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectThreadExecutionWidth)); +} + +_MTL_INLINE MTL::RenderPipelineReflection* MTL::RenderPipelineState::reflection() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(reflection)); +} + +_MTL_INLINE MTL::Size MTL::RenderPipelineState::requiredThreadsPerMeshThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerMeshThreadgroup)); +} + +_MTL_INLINE MTL::Size MTL::RenderPipelineState::requiredThreadsPerObjectThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerObjectThreadgroup)); +} + +_MTL_INLINE MTL::Size MTL::RenderPipelineState::requiredThreadsPerTileThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerTileThreadgroup)); +} + +_MTL_INLINE MTL::ShaderValidation MTL::RenderPipelineState::shaderValidation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderValidation)); +} + +_MTL_INLINE bool MTL::RenderPipelineState::supportIndirectCommandBuffers() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportIndirectCommandBuffers)); +} + +_MTL_INLINE bool MTL::RenderPipelineState::threadgroupSizeMatchesTileSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadgroupSizeMatchesTileSize)); +} + +_MTL_INLINE MTL::RenderPipelineColorAttachmentDescriptorArray* MTL::RenderPipelineColorAttachmentDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPipelineColorAttachmentDescriptorArray)); +} + +_MTL_INLINE MTL::RenderPipelineColorAttachmentDescriptorArray* MTL::RenderPipelineColorAttachmentDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::RenderPipelineColorAttachmentDescriptor* MTL::RenderPipelineColorAttachmentDescriptorArray::object(NS::UInteger attachmentIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), attachmentIndex); +} + +_MTL_INLINE void MTL::RenderPipelineColorAttachmentDescriptorArray::setObject(const MTL::RenderPipelineColorAttachmentDescriptor* attachment, NS::UInteger attachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), attachment, attachmentIndex); +} + +_MTL_INLINE MTL::TileRenderPipelineColorAttachmentDescriptor* MTL::TileRenderPipelineColorAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLTileRenderPipelineColorAttachmentDescriptor)); +} + +_MTL_INLINE MTL::TileRenderPipelineColorAttachmentDescriptor* MTL::TileRenderPipelineColorAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::PixelFormat MTL::TileRenderPipelineColorAttachmentDescriptor::pixelFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(pixelFormat)); +} + +_MTL_INLINE void MTL::TileRenderPipelineColorAttachmentDescriptor::setPixelFormat(MTL::PixelFormat pixelFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPixelFormat_), pixelFormat); +} + +_MTL_INLINE MTL::TileRenderPipelineColorAttachmentDescriptorArray* MTL::TileRenderPipelineColorAttachmentDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLTileRenderPipelineColorAttachmentDescriptorArray)); +} + +_MTL_INLINE MTL::TileRenderPipelineColorAttachmentDescriptorArray* MTL::TileRenderPipelineColorAttachmentDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::TileRenderPipelineColorAttachmentDescriptor* MTL::TileRenderPipelineColorAttachmentDescriptorArray::object(NS::UInteger attachmentIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), attachmentIndex); +} + +_MTL_INLINE void MTL::TileRenderPipelineColorAttachmentDescriptorArray::setObject(const MTL::TileRenderPipelineColorAttachmentDescriptor* attachment, NS::UInteger attachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), attachment, attachmentIndex); +} + +_MTL_INLINE MTL::TileRenderPipelineDescriptor* MTL::TileRenderPipelineDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLTileRenderPipelineDescriptor)); +} + +_MTL_INLINE NS::Array* MTL::TileRenderPipelineDescriptor::binaryArchives() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(binaryArchives)); +} + +_MTL_INLINE MTL::TileRenderPipelineColorAttachmentDescriptorArray* MTL::TileRenderPipelineDescriptor::colorAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(colorAttachments)); +} + +_MTL_INLINE MTL::TileRenderPipelineDescriptor* MTL::TileRenderPipelineDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::TileRenderPipelineDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL::LinkedFunctions* MTL::TileRenderPipelineDescriptor::linkedFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(linkedFunctions)); +} + +_MTL_INLINE NS::UInteger MTL::TileRenderPipelineDescriptor::maxCallStackDepth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxCallStackDepth)); +} + +_MTL_INLINE NS::UInteger MTL::TileRenderPipelineDescriptor::maxTotalThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerThreadgroup)); +} + +_MTL_INLINE NS::Array* MTL::TileRenderPipelineDescriptor::preloadedLibraries() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(preloadedLibraries)); +} + +_MTL_INLINE NS::UInteger MTL::TileRenderPipelineDescriptor::rasterSampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rasterSampleCount)); +} + +_MTL_INLINE MTL::Size MTL::TileRenderPipelineDescriptor::requiredThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerThreadgroup)); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setBinaryArchives(const NS::Array* binaryArchives) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBinaryArchives_), binaryArchives); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setLinkedFunctions(const MTL::LinkedFunctions* linkedFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLinkedFunctions_), linkedFunctions); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setMaxCallStackDepth(NS::UInteger maxCallStackDepth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxCallStackDepth_), maxCallStackDepth); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadsPerThreadgroup_), maxTotalThreadsPerThreadgroup); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setPreloadedLibraries(const NS::Array* preloadedLibraries) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPreloadedLibraries_), preloadedLibraries); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setRasterSampleCount(NS::UInteger rasterSampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterSampleCount_), rasterSampleCount); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRequiredThreadsPerThreadgroup_), requiredThreadsPerThreadgroup); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setShaderValidation(MTL::ShaderValidation shaderValidation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setShaderValidation_), shaderValidation); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setSupportAddingBinaryFunctions(bool supportAddingBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportAddingBinaryFunctions_), supportAddingBinaryFunctions); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setThreadgroupSizeMatchesTileSize(bool threadgroupSizeMatchesTileSize) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadgroupSizeMatchesTileSize_), threadgroupSizeMatchesTileSize); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setTileFunction(const MTL::Function* tileFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileFunction_), tileFunction); +} + +_MTL_INLINE MTL::ShaderValidation MTL::TileRenderPipelineDescriptor::shaderValidation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderValidation)); +} + +_MTL_INLINE bool MTL::TileRenderPipelineDescriptor::supportAddingBinaryFunctions() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportAddingBinaryFunctions)); +} + +_MTL_INLINE bool MTL::TileRenderPipelineDescriptor::threadgroupSizeMatchesTileSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadgroupSizeMatchesTileSize)); +} + +_MTL_INLINE MTL::PipelineBufferDescriptorArray* MTL::TileRenderPipelineDescriptor::tileBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileBuffers)); +} + +_MTL_INLINE MTL::Function* MTL::TileRenderPipelineDescriptor::tileFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileFunction)); +} + +_MTL_INLINE MTL::MeshRenderPipelineDescriptor* MTL::MeshRenderPipelineDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLMeshRenderPipelineDescriptor)); +} + +_MTL_INLINE bool MTL::MeshRenderPipelineDescriptor::alphaToCoverageEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isAlphaToCoverageEnabled)); +} + +_MTL_INLINE bool MTL::MeshRenderPipelineDescriptor::alphaToOneEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isAlphaToOneEnabled)); +} + +_MTL_INLINE NS::Array* MTL::MeshRenderPipelineDescriptor::binaryArchives() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(binaryArchives)); +} + +_MTL_INLINE MTL::RenderPipelineColorAttachmentDescriptorArray* MTL::MeshRenderPipelineDescriptor::colorAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(colorAttachments)); +} + +_MTL_INLINE MTL::PixelFormat MTL::MeshRenderPipelineDescriptor::depthAttachmentPixelFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(depthAttachmentPixelFormat)); +} + +_MTL_INLINE MTL::PipelineBufferDescriptorArray* MTL::MeshRenderPipelineDescriptor::fragmentBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentBuffers)); +} + +_MTL_INLINE MTL::Function* MTL::MeshRenderPipelineDescriptor::fragmentFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentFunction)); +} + +_MTL_INLINE MTL::LinkedFunctions* MTL::MeshRenderPipelineDescriptor::fragmentLinkedFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentLinkedFunctions)); +} + +_MTL_INLINE MTL::MeshRenderPipelineDescriptor* MTL::MeshRenderPipelineDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE bool MTL::MeshRenderPipelineDescriptor::isAlphaToCoverageEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isAlphaToCoverageEnabled)); +} + +_MTL_INLINE bool MTL::MeshRenderPipelineDescriptor::isAlphaToOneEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isAlphaToOneEnabled)); +} + +_MTL_INLINE bool MTL::MeshRenderPipelineDescriptor::isRasterizationEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isRasterizationEnabled)); +} + +_MTL_INLINE NS::String* MTL::MeshRenderPipelineDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::UInteger MTL::MeshRenderPipelineDescriptor::maxTotalThreadgroupsPerMeshGrid() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadgroupsPerMeshGrid)); +} + +_MTL_INLINE NS::UInteger MTL::MeshRenderPipelineDescriptor::maxTotalThreadsPerMeshThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerMeshThreadgroup)); +} + +_MTL_INLINE NS::UInteger MTL::MeshRenderPipelineDescriptor::maxTotalThreadsPerObjectThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerObjectThreadgroup)); +} + +_MTL_INLINE NS::UInteger MTL::MeshRenderPipelineDescriptor::maxVertexAmplificationCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxVertexAmplificationCount)); +} + +_MTL_INLINE MTL::PipelineBufferDescriptorArray* MTL::MeshRenderPipelineDescriptor::meshBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshBuffers)); +} + +_MTL_INLINE MTL::Function* MTL::MeshRenderPipelineDescriptor::meshFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshFunction)); +} + +_MTL_INLINE MTL::LinkedFunctions* MTL::MeshRenderPipelineDescriptor::meshLinkedFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshLinkedFunctions)); +} + +_MTL_INLINE bool MTL::MeshRenderPipelineDescriptor::meshThreadgroupSizeIsMultipleOfThreadExecutionWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshThreadgroupSizeIsMultipleOfThreadExecutionWidth)); +} + +_MTL_INLINE MTL::PipelineBufferDescriptorArray* MTL::MeshRenderPipelineDescriptor::objectBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectBuffers)); +} + +_MTL_INLINE MTL::Function* MTL::MeshRenderPipelineDescriptor::objectFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectFunction)); +} + +_MTL_INLINE MTL::LinkedFunctions* MTL::MeshRenderPipelineDescriptor::objectLinkedFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectLinkedFunctions)); +} + +_MTL_INLINE bool MTL::MeshRenderPipelineDescriptor::objectThreadgroupSizeIsMultipleOfThreadExecutionWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectThreadgroupSizeIsMultipleOfThreadExecutionWidth)); +} + +_MTL_INLINE NS::UInteger MTL::MeshRenderPipelineDescriptor::payloadMemoryLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(payloadMemoryLength)); +} + +_MTL_INLINE NS::UInteger MTL::MeshRenderPipelineDescriptor::rasterSampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rasterSampleCount)); +} + +_MTL_INLINE bool MTL::MeshRenderPipelineDescriptor::rasterizationEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isRasterizationEnabled)); +} + +_MTL_INLINE MTL::Size MTL::MeshRenderPipelineDescriptor::requiredThreadsPerMeshThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerMeshThreadgroup)); +} + +_MTL_INLINE MTL::Size MTL::MeshRenderPipelineDescriptor::requiredThreadsPerObjectThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerObjectThreadgroup)); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setAlphaToCoverageEnabled(bool alphaToCoverageEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAlphaToCoverageEnabled_), alphaToCoverageEnabled); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setAlphaToOneEnabled(bool alphaToOneEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAlphaToOneEnabled_), alphaToOneEnabled); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setBinaryArchives(const NS::Array* binaryArchives) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBinaryArchives_), binaryArchives); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setDepthAttachmentPixelFormat(MTL::PixelFormat depthAttachmentPixelFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthAttachmentPixelFormat_), depthAttachmentPixelFormat); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setFragmentFunction(const MTL::Function* fragmentFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentFunction_), fragmentFunction); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setFragmentLinkedFunctions(const MTL::LinkedFunctions* fragmentLinkedFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentLinkedFunctions_), fragmentLinkedFunctions); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setMaxTotalThreadgroupsPerMeshGrid(NS::UInteger maxTotalThreadgroupsPerMeshGrid) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadgroupsPerMeshGrid_), maxTotalThreadgroupsPerMeshGrid); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setMaxTotalThreadsPerMeshThreadgroup(NS::UInteger maxTotalThreadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadsPerMeshThreadgroup_), maxTotalThreadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setMaxTotalThreadsPerObjectThreadgroup(NS::UInteger maxTotalThreadsPerObjectThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadsPerObjectThreadgroup_), maxTotalThreadsPerObjectThreadgroup); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setMaxVertexAmplificationCount(NS::UInteger maxVertexAmplificationCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxVertexAmplificationCount_), maxVertexAmplificationCount); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setMeshFunction(const MTL::Function* meshFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshFunction_), meshFunction); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setMeshLinkedFunctions(const MTL::LinkedFunctions* meshLinkedFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshLinkedFunctions_), meshLinkedFunctions); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setMeshThreadgroupSizeIsMultipleOfThreadExecutionWidth(bool meshThreadgroupSizeIsMultipleOfThreadExecutionWidth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshThreadgroupSizeIsMultipleOfThreadExecutionWidth_), meshThreadgroupSizeIsMultipleOfThreadExecutionWidth); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setObjectFunction(const MTL::Function* objectFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectFunction_), objectFunction); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setObjectLinkedFunctions(const MTL::LinkedFunctions* objectLinkedFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectLinkedFunctions_), objectLinkedFunctions); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setObjectThreadgroupSizeIsMultipleOfThreadExecutionWidth(bool objectThreadgroupSizeIsMultipleOfThreadExecutionWidth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectThreadgroupSizeIsMultipleOfThreadExecutionWidth_), objectThreadgroupSizeIsMultipleOfThreadExecutionWidth); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setPayloadMemoryLength(NS::UInteger payloadMemoryLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPayloadMemoryLength_), payloadMemoryLength); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setRasterSampleCount(NS::UInteger rasterSampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterSampleCount_), rasterSampleCount); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setRasterizationEnabled(bool rasterizationEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterizationEnabled_), rasterizationEnabled); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setRequiredThreadsPerMeshThreadgroup(MTL::Size requiredThreadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRequiredThreadsPerMeshThreadgroup_), requiredThreadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setRequiredThreadsPerObjectThreadgroup(MTL::Size requiredThreadsPerObjectThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRequiredThreadsPerObjectThreadgroup_), requiredThreadsPerObjectThreadgroup); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setShaderValidation(MTL::ShaderValidation shaderValidation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setShaderValidation_), shaderValidation); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setStencilAttachmentPixelFormat(MTL::PixelFormat stencilAttachmentPixelFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilAttachmentPixelFormat_), stencilAttachmentPixelFormat); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setSupportIndirectCommandBuffers(bool supportIndirectCommandBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportIndirectCommandBuffers_), supportIndirectCommandBuffers); +} + +_MTL_INLINE MTL::ShaderValidation MTL::MeshRenderPipelineDescriptor::shaderValidation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderValidation)); +} + +_MTL_INLINE MTL::PixelFormat MTL::MeshRenderPipelineDescriptor::stencilAttachmentPixelFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stencilAttachmentPixelFormat)); +} + +_MTL_INLINE bool MTL::MeshRenderPipelineDescriptor::supportIndirectCommandBuffers() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportIndirectCommandBuffers)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLResidencySet.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLResidencySet.hpp new file mode 100644 index 00000000..d073972d --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLResidencySet.hpp @@ -0,0 +1,178 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLResidencySet.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include + +namespace MTL +{ +class Allocation; +class Device; +class ResidencySetDescriptor; + +class ResidencySetDescriptor : public NS::Copying +{ +public: + static ResidencySetDescriptor* alloc(); + + ResidencySetDescriptor* init(); + NS::UInteger initialCapacity() const; + + NS::String* label() const; + + void setInitialCapacity(NS::UInteger initialCapacity); + + void setLabel(const NS::String* label); +}; +class ResidencySet : public NS::Referencing +{ +public: + void addAllocation(const MTL::Allocation* allocation); + void addAllocations(const MTL::Allocation* const allocations[], NS::UInteger count); + + NS::Array* allAllocations() const; + + uint64_t allocatedSize() const; + + NS::UInteger allocationCount() const; + + void commit(); + + bool containsAllocation(const MTL::Allocation* anAllocation); + + Device* device() const; + + void endResidency(); + + NS::String* label() const; + + void removeAllAllocations(); + + void removeAllocation(const MTL::Allocation* allocation); + void removeAllocations(const MTL::Allocation* const allocations[], NS::UInteger count); + + void requestResidency(); +}; + +} +_MTL_INLINE MTL::ResidencySetDescriptor* MTL::ResidencySetDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLResidencySetDescriptor)); +} + +_MTL_INLINE MTL::ResidencySetDescriptor* MTL::ResidencySetDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL::ResidencySetDescriptor::initialCapacity() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(initialCapacity)); +} + +_MTL_INLINE NS::String* MTL::ResidencySetDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::ResidencySetDescriptor::setInitialCapacity(NS::UInteger initialCapacity) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInitialCapacity_), initialCapacity); +} + +_MTL_INLINE void MTL::ResidencySetDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::ResidencySet::addAllocation(const MTL::Allocation* allocation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addAllocation_), allocation); +} + +_MTL_INLINE void MTL::ResidencySet::addAllocations(const MTL::Allocation* const allocations[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addAllocations_count_), allocations, count); +} + +_MTL_INLINE NS::Array* MTL::ResidencySet::allAllocations() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allAllocations)); +} + +_MTL_INLINE uint64_t MTL::ResidencySet::allocatedSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allocatedSize)); +} + +_MTL_INLINE NS::UInteger MTL::ResidencySet::allocationCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allocationCount)); +} + +_MTL_INLINE void MTL::ResidencySet::commit() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(commit)); +} + +_MTL_INLINE bool MTL::ResidencySet::containsAllocation(const MTL::Allocation* anAllocation) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(containsAllocation_), anAllocation); +} + +_MTL_INLINE MTL::Device* MTL::ResidencySet::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE void MTL::ResidencySet::endResidency() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(endResidency)); +} + +_MTL_INLINE NS::String* MTL::ResidencySet::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::ResidencySet::removeAllAllocations() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(removeAllAllocations)); +} + +_MTL_INLINE void MTL::ResidencySet::removeAllocation(const MTL::Allocation* allocation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(removeAllocation_), allocation); +} + +_MTL_INLINE void MTL::ResidencySet::removeAllocations(const MTL::Allocation* const allocations[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(removeAllocations_count_), allocations, count); +} + +_MTL_INLINE void MTL::ResidencySet::requestResidency() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(requestResidency)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLResource.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLResource.hpp new file mode 100644 index 00000000..21e49bb9 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLResource.hpp @@ -0,0 +1,190 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLResource.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLAllocation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include + +namespace MTL +{ +class Device; +class Heap; +_MTL_ENUM(NS::UInteger, PurgeableState) { + PurgeableStateKeepCurrent = 1, + PurgeableStateNonVolatile = 2, + PurgeableStateVolatile = 3, + PurgeableStateEmpty = 4, +}; + +_MTL_ENUM(NS::UInteger, CPUCacheMode) { + CPUCacheModeDefaultCache = 0, + CPUCacheModeWriteCombined = 1, +}; + +_MTL_ENUM(NS::UInteger, StorageMode) { + StorageModeShared = 0, + StorageModeManaged = 1, + StorageModePrivate = 2, + StorageModeMemoryless = 3, +}; + +_MTL_ENUM(NS::UInteger, HazardTrackingMode) { + HazardTrackingModeDefault = 0, + HazardTrackingModeUntracked = 1, + HazardTrackingModeTracked = 2, +}; + +_MTL_ENUM(NS::Integer, SparsePageSize) { + SparsePageSize16 = 101, + SparsePageSize64 = 102, + SparsePageSize256 = 103, +}; + +_MTL_ENUM(NS::Integer, BufferSparseTier) { + BufferSparseTierNone = 0, + BufferSparseTier1 = 1, +}; + +_MTL_ENUM(NS::Integer, TextureSparseTier) { + TextureSparseTierNone = 0, + TextureSparseTier1 = 1, + TextureSparseTier2 = 2, +}; + +_MTL_OPTIONS(NS::UInteger, ResourceOptions) { + ResourceCPUCacheModeDefaultCache = 0, + ResourceCPUCacheModeWriteCombined = 1, + ResourceStorageModeShared = 0, + ResourceStorageModeManaged = 1 << 4, + ResourceStorageModePrivate = 1 << 5, + ResourceStorageModeMemoryless = 1 << 5, + ResourceHazardTrackingModeDefault = 0, + ResourceHazardTrackingModeUntracked = 1 << 8, + ResourceHazardTrackingModeTracked = 1 << 9, + ResourceOptionCPUCacheModeDefault = 0, + ResourceOptionCPUCacheModeWriteCombined = 1, +}; + +class Resource : public NS::Referencing +{ +public: + NS::UInteger allocatedSize() const; + + CPUCacheMode cpuCacheMode() const; + + Device* device() const; + + HazardTrackingMode hazardTrackingMode() const; + + Heap* heap() const; + NS::UInteger heapOffset() const; + + bool isAliasable(); + + NS::String* label() const; + + void makeAliasable(); + + ResourceOptions resourceOptions() const; + + void setLabel(const NS::String* label); + + kern_return_t setOwner(task_id_token_t task_id_token); + + PurgeableState setPurgeableState(MTL::PurgeableState state); + + StorageMode storageMode() const; +}; + +} +_MTL_INLINE NS::UInteger MTL::Resource::allocatedSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allocatedSize)); +} + +_MTL_INLINE MTL::CPUCacheMode MTL::Resource::cpuCacheMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(cpuCacheMode)); +} + +_MTL_INLINE MTL::Device* MTL::Resource::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE MTL::HazardTrackingMode MTL::Resource::hazardTrackingMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(hazardTrackingMode)); +} + +_MTL_INLINE MTL::Heap* MTL::Resource::heap() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(heap)); +} + +_MTL_INLINE NS::UInteger MTL::Resource::heapOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(heapOffset)); +} + +_MTL_INLINE bool MTL::Resource::isAliasable() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isAliasable)); +} + +_MTL_INLINE NS::String* MTL::Resource::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::Resource::makeAliasable() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(makeAliasable)); +} + +_MTL_INLINE MTL::ResourceOptions MTL::Resource::resourceOptions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resourceOptions)); +} + +_MTL_INLINE void MTL::Resource::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE kern_return_t MTL::Resource::setOwner(task_id_token_t task_id_token) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(setOwnerWithIdentity_), task_id_token); +} + +_MTL_INLINE MTL::PurgeableState MTL::Resource::setPurgeableState(MTL::PurgeableState state) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(setPurgeableState_), state); +} + +_MTL_INLINE MTL::StorageMode MTL::Resource::storageMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(storageMode)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLResourceStateCommandEncoder.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLResourceStateCommandEncoder.hpp new file mode 100644 index 00000000..3f565c30 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLResourceStateCommandEncoder.hpp @@ -0,0 +1,98 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLResourceStateCommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLCommandEncoder.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL +{ +class Buffer; +class Fence; +struct Region; +class Texture; +_MTL_ENUM(NS::UInteger, SparseTextureMappingMode) { + SparseTextureMappingModeMap = 0, + SparseTextureMappingModeUnmap = 1, +}; + +struct MapIndirectArguments +{ + uint32_t regionOriginX; + uint32_t regionOriginY; + uint32_t regionOriginZ; + uint32_t regionSizeWidth; + uint32_t regionSizeHeight; + uint32_t regionSizeDepth; + uint32_t mipMapLevel; + uint32_t sliceId; +} _MTL_PACKED; + +class ResourceStateCommandEncoder : public NS::Referencing +{ +public: + void moveTextureMappingsFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin); + + void updateFence(const MTL::Fence* fence); + + void updateTextureMapping(const MTL::Texture* texture, const MTL::SparseTextureMappingMode mode, const MTL::Region region, const NS::UInteger mipLevel, const NS::UInteger slice); + void updateTextureMapping(const MTL::Texture* texture, const MTL::SparseTextureMappingMode mode, const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset); + void updateTextureMappings(const MTL::Texture* texture, const MTL::SparseTextureMappingMode mode, const MTL::Region* regions, const NS::UInteger* mipLevels, const NS::UInteger* slices, NS::UInteger numRegions); + + void waitForFence(const MTL::Fence* fence); +}; + +} + +_MTL_INLINE void MTL::ResourceStateCommandEncoder::moveTextureMappingsFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(moveTextureMappingsFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_), sourceTexture, sourceSlice, sourceLevel, sourceOrigin, sourceSize, destinationTexture, destinationSlice, destinationLevel, destinationOrigin); +} + +_MTL_INLINE void MTL::ResourceStateCommandEncoder::updateFence(const MTL::Fence* fence) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateFence_), fence); +} + +_MTL_INLINE void MTL::ResourceStateCommandEncoder::updateTextureMapping(const MTL::Texture* texture, const MTL::SparseTextureMappingMode mode, const MTL::Region region, const NS::UInteger mipLevel, const NS::UInteger slice) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateTextureMapping_mode_region_mipLevel_slice_), texture, mode, region, mipLevel, slice); +} + +_MTL_INLINE void MTL::ResourceStateCommandEncoder::updateTextureMapping(const MTL::Texture* texture, const MTL::SparseTextureMappingMode mode, const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateTextureMapping_mode_indirectBuffer_indirectBufferOffset_), texture, mode, indirectBuffer, indirectBufferOffset); +} + +_MTL_INLINE void MTL::ResourceStateCommandEncoder::updateTextureMappings(const MTL::Texture* texture, const MTL::SparseTextureMappingMode mode, const MTL::Region* regions, const NS::UInteger* mipLevels, const NS::UInteger* slices, NS::UInteger numRegions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateTextureMappings_mode_regions_mipLevels_slices_numRegions_), texture, mode, regions, mipLevels, slices, numRegions); +} + +_MTL_INLINE void MTL::ResourceStateCommandEncoder::waitForFence(const MTL::Fence* fence) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitForFence_), fence); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLResourceStatePass.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLResourceStatePass.hpp new file mode 100644 index 00000000..f3689012 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLResourceStatePass.hpp @@ -0,0 +1,154 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLResourceStatePass.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class CounterSampleBuffer; +class ResourceStatePassDescriptor; +class ResourceStatePassSampleBufferAttachmentDescriptor; +class ResourceStatePassSampleBufferAttachmentDescriptorArray; + +class ResourceStatePassSampleBufferAttachmentDescriptor : public NS::Copying +{ +public: + static ResourceStatePassSampleBufferAttachmentDescriptor* alloc(); + + NS::UInteger endOfEncoderSampleIndex() const; + + ResourceStatePassSampleBufferAttachmentDescriptor* init(); + + CounterSampleBuffer* sampleBuffer() const; + + void setEndOfEncoderSampleIndex(NS::UInteger endOfEncoderSampleIndex); + + void setSampleBuffer(const MTL::CounterSampleBuffer* sampleBuffer); + + void setStartOfEncoderSampleIndex(NS::UInteger startOfEncoderSampleIndex); + NS::UInteger startOfEncoderSampleIndex() const; +}; +class ResourceStatePassSampleBufferAttachmentDescriptorArray : public NS::Referencing +{ +public: + static ResourceStatePassSampleBufferAttachmentDescriptorArray* alloc(); + + ResourceStatePassSampleBufferAttachmentDescriptorArray* init(); + + ResourceStatePassSampleBufferAttachmentDescriptor* object(NS::UInteger attachmentIndex); + void setObject(const MTL::ResourceStatePassSampleBufferAttachmentDescriptor* attachment, NS::UInteger attachmentIndex); +}; +class ResourceStatePassDescriptor : public NS::Copying +{ +public: + static ResourceStatePassDescriptor* alloc(); + + ResourceStatePassDescriptor* init(); + + static ResourceStatePassDescriptor* resourceStatePassDescriptor(); + + ResourceStatePassSampleBufferAttachmentDescriptorArray* sampleBufferAttachments() const; +}; + +} +_MTL_INLINE MTL::ResourceStatePassSampleBufferAttachmentDescriptor* MTL::ResourceStatePassSampleBufferAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLResourceStatePassSampleBufferAttachmentDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::ResourceStatePassSampleBufferAttachmentDescriptor::endOfEncoderSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(endOfEncoderSampleIndex)); +} + +_MTL_INLINE MTL::ResourceStatePassSampleBufferAttachmentDescriptor* MTL::ResourceStatePassSampleBufferAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::CounterSampleBuffer* MTL::ResourceStatePassSampleBufferAttachmentDescriptor::sampleBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleBuffer)); +} + +_MTL_INLINE void MTL::ResourceStatePassSampleBufferAttachmentDescriptor::setEndOfEncoderSampleIndex(NS::UInteger endOfEncoderSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setEndOfEncoderSampleIndex_), endOfEncoderSampleIndex); +} + +_MTL_INLINE void MTL::ResourceStatePassSampleBufferAttachmentDescriptor::setSampleBuffer(const MTL::CounterSampleBuffer* sampleBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSampleBuffer_), sampleBuffer); +} + +_MTL_INLINE void MTL::ResourceStatePassSampleBufferAttachmentDescriptor::setStartOfEncoderSampleIndex(NS::UInteger startOfEncoderSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStartOfEncoderSampleIndex_), startOfEncoderSampleIndex); +} + +_MTL_INLINE NS::UInteger MTL::ResourceStatePassSampleBufferAttachmentDescriptor::startOfEncoderSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(startOfEncoderSampleIndex)); +} + +_MTL_INLINE MTL::ResourceStatePassSampleBufferAttachmentDescriptorArray* MTL::ResourceStatePassSampleBufferAttachmentDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLResourceStatePassSampleBufferAttachmentDescriptorArray)); +} + +_MTL_INLINE MTL::ResourceStatePassSampleBufferAttachmentDescriptorArray* MTL::ResourceStatePassSampleBufferAttachmentDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::ResourceStatePassSampleBufferAttachmentDescriptor* MTL::ResourceStatePassSampleBufferAttachmentDescriptorArray::object(NS::UInteger attachmentIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), attachmentIndex); +} + +_MTL_INLINE void MTL::ResourceStatePassSampleBufferAttachmentDescriptorArray::setObject(const MTL::ResourceStatePassSampleBufferAttachmentDescriptor* attachment, NS::UInteger attachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), attachment, attachmentIndex); +} + +_MTL_INLINE MTL::ResourceStatePassDescriptor* MTL::ResourceStatePassDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLResourceStatePassDescriptor)); +} + +_MTL_INLINE MTL::ResourceStatePassDescriptor* MTL::ResourceStatePassDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::ResourceStatePassDescriptor* MTL::ResourceStatePassDescriptor::resourceStatePassDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLResourceStatePassDescriptor), _MTL_PRIVATE_SEL(resourceStatePassDescriptor)); +} + +_MTL_INLINE MTL::ResourceStatePassSampleBufferAttachmentDescriptorArray* MTL::ResourceStatePassDescriptor::sampleBufferAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleBufferAttachments)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLResourceViewPool.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLResourceViewPool.hpp new file mode 100644 index 00000000..aa8bfda3 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLResourceViewPool.hpp @@ -0,0 +1,118 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLResourceViewPool.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class Device; +class ResourceViewPool; +class ResourceViewPoolDescriptor; + +class ResourceViewPoolDescriptor : public NS::Copying +{ +public: + static ResourceViewPoolDescriptor* alloc(); + + ResourceViewPoolDescriptor* init(); + + NS::String* label() const; + + NS::UInteger resourceViewCount() const; + + void setLabel(const NS::String* label); + + void setResourceViewCount(NS::UInteger resourceViewCount); +}; +class ResourceViewPool : public NS::Referencing +{ +public: + ResourceID baseResourceID() const; + + ResourceID copyResourceViewsFromPool(const MTL::ResourceViewPool* sourcePool, NS::Range sourceRange, NS::UInteger destinationIndex); + + Device* device() const; + + NS::String* label() const; + + NS::UInteger resourceViewCount() const; +}; + +} +_MTL_INLINE MTL::ResourceViewPoolDescriptor* MTL::ResourceViewPoolDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLResourceViewPoolDescriptor)); +} + +_MTL_INLINE MTL::ResourceViewPoolDescriptor* MTL::ResourceViewPoolDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::ResourceViewPoolDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::UInteger MTL::ResourceViewPoolDescriptor::resourceViewCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resourceViewCount)); +} + +_MTL_INLINE void MTL::ResourceViewPoolDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::ResourceViewPoolDescriptor::setResourceViewCount(NS::UInteger resourceViewCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setResourceViewCount_), resourceViewCount); +} + +_MTL_INLINE MTL::ResourceID MTL::ResourceViewPool::baseResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(baseResourceID)); +} + +_MTL_INLINE MTL::ResourceID MTL::ResourceViewPool::copyResourceViewsFromPool(const MTL::ResourceViewPool* sourcePool, NS::Range sourceRange, NS::UInteger destinationIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(copyResourceViewsFromPool_sourceRange_destinationIndex_), sourcePool, sourceRange, destinationIndex); +} + +_MTL_INLINE MTL::Device* MTL::ResourceViewPool::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL::ResourceViewPool::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::UInteger MTL::ResourceViewPool::resourceViewCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resourceViewCount)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLSampler.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLSampler.hpp new file mode 100644 index 00000000..f2286656 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLSampler.hpp @@ -0,0 +1,345 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLSampler.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLDepthStencil.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class Device; +class SamplerDescriptor; +_MTL_ENUM(NS::UInteger, SamplerMinMagFilter) { + SamplerMinMagFilterNearest = 0, + SamplerMinMagFilterLinear = 1, +}; + +_MTL_ENUM(NS::UInteger, SamplerMipFilter) { + SamplerMipFilterNotMipmapped = 0, + SamplerMipFilterNearest = 1, + SamplerMipFilterLinear = 2, +}; + +_MTL_ENUM(NS::UInteger, SamplerAddressMode) { + SamplerAddressModeClampToEdge = 0, + SamplerAddressModeMirrorClampToEdge = 1, + SamplerAddressModeRepeat = 2, + SamplerAddressModeMirrorRepeat = 3, + SamplerAddressModeClampToZero = 4, + SamplerAddressModeClampToBorderColor = 5, +}; + +_MTL_ENUM(NS::UInteger, SamplerBorderColor) { + SamplerBorderColorTransparentBlack = 0, + SamplerBorderColorOpaqueBlack = 1, + SamplerBorderColorOpaqueWhite = 2, +}; + +_MTL_ENUM(NS::UInteger, SamplerReductionMode) { + SamplerReductionModeWeightedAverage = 0, + SamplerReductionModeMinimum = 1, + SamplerReductionModeMaximum = 2, +}; + +class SamplerDescriptor : public NS::Copying +{ +public: + static SamplerDescriptor* alloc(); + + SamplerBorderColor borderColor() const; + + CompareFunction compareFunction() const; + + SamplerDescriptor* init(); + + NS::String* label() const; + + bool lodAverage() const; + + float lodBias() const; + + float lodMaxClamp() const; + + float lodMinClamp() const; + + SamplerMinMagFilter magFilter() const; + + NS::UInteger maxAnisotropy() const; + + SamplerMinMagFilter minFilter() const; + + SamplerMipFilter mipFilter() const; + + bool normalizedCoordinates() const; + + SamplerAddressMode rAddressMode() const; + + SamplerReductionMode reductionMode() const; + + SamplerAddressMode sAddressMode() const; + + void setBorderColor(MTL::SamplerBorderColor borderColor); + + void setCompareFunction(MTL::CompareFunction compareFunction); + + void setLabel(const NS::String* label); + + void setLodAverage(bool lodAverage); + + void setLodBias(float lodBias); + + void setLodMaxClamp(float lodMaxClamp); + + void setLodMinClamp(float lodMinClamp); + + void setMagFilter(MTL::SamplerMinMagFilter magFilter); + + void setMaxAnisotropy(NS::UInteger maxAnisotropy); + + void setMinFilter(MTL::SamplerMinMagFilter minFilter); + + void setMipFilter(MTL::SamplerMipFilter mipFilter); + + void setNormalizedCoordinates(bool normalizedCoordinates); + + void setRAddressMode(MTL::SamplerAddressMode rAddressMode); + + void setReductionMode(MTL::SamplerReductionMode reductionMode); + + void setSAddressMode(MTL::SamplerAddressMode sAddressMode); + + void setSupportArgumentBuffers(bool supportArgumentBuffers); + + void setTAddressMode(MTL::SamplerAddressMode tAddressMode); + + bool supportArgumentBuffers() const; + + SamplerAddressMode tAddressMode() const; +}; +class SamplerState : public NS::Referencing +{ +public: + Device* device() const; + + ResourceID gpuResourceID() const; + + NS::String* label() const; +}; + +} +_MTL_INLINE MTL::SamplerDescriptor* MTL::SamplerDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLSamplerDescriptor)); +} + +_MTL_INLINE MTL::SamplerBorderColor MTL::SamplerDescriptor::borderColor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(borderColor)); +} + +_MTL_INLINE MTL::CompareFunction MTL::SamplerDescriptor::compareFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(compareFunction)); +} + +_MTL_INLINE MTL::SamplerDescriptor* MTL::SamplerDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::SamplerDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE bool MTL::SamplerDescriptor::lodAverage() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(lodAverage)); +} + +_MTL_INLINE float MTL::SamplerDescriptor::lodBias() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(lodBias)); +} + +_MTL_INLINE float MTL::SamplerDescriptor::lodMaxClamp() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(lodMaxClamp)); +} + +_MTL_INLINE float MTL::SamplerDescriptor::lodMinClamp() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(lodMinClamp)); +} + +_MTL_INLINE MTL::SamplerMinMagFilter MTL::SamplerDescriptor::magFilter() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(magFilter)); +} + +_MTL_INLINE NS::UInteger MTL::SamplerDescriptor::maxAnisotropy() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxAnisotropy)); +} + +_MTL_INLINE MTL::SamplerMinMagFilter MTL::SamplerDescriptor::minFilter() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(minFilter)); +} + +_MTL_INLINE MTL::SamplerMipFilter MTL::SamplerDescriptor::mipFilter() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(mipFilter)); +} + +_MTL_INLINE bool MTL::SamplerDescriptor::normalizedCoordinates() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(normalizedCoordinates)); +} + +_MTL_INLINE MTL::SamplerAddressMode MTL::SamplerDescriptor::rAddressMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rAddressMode)); +} + +_MTL_INLINE MTL::SamplerReductionMode MTL::SamplerDescriptor::reductionMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(reductionMode)); +} + +_MTL_INLINE MTL::SamplerAddressMode MTL::SamplerDescriptor::sAddressMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sAddressMode)); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setBorderColor(MTL::SamplerBorderColor borderColor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBorderColor_), borderColor); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setCompareFunction(MTL::CompareFunction compareFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCompareFunction_), compareFunction); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setLodAverage(bool lodAverage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLodAverage_), lodAverage); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setLodBias(float lodBias) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLodBias_), lodBias); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setLodMaxClamp(float lodMaxClamp) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLodMaxClamp_), lodMaxClamp); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setLodMinClamp(float lodMinClamp) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLodMinClamp_), lodMinClamp); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setMagFilter(MTL::SamplerMinMagFilter magFilter) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMagFilter_), magFilter); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setMaxAnisotropy(NS::UInteger maxAnisotropy) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxAnisotropy_), maxAnisotropy); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setMinFilter(MTL::SamplerMinMagFilter minFilter) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMinFilter_), minFilter); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setMipFilter(MTL::SamplerMipFilter mipFilter) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMipFilter_), mipFilter); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setNormalizedCoordinates(bool normalizedCoordinates) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setNormalizedCoordinates_), normalizedCoordinates); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setRAddressMode(MTL::SamplerAddressMode rAddressMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRAddressMode_), rAddressMode); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setReductionMode(MTL::SamplerReductionMode reductionMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setReductionMode_), reductionMode); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setSAddressMode(MTL::SamplerAddressMode sAddressMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSAddressMode_), sAddressMode); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setSupportArgumentBuffers(bool supportArgumentBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportArgumentBuffers_), supportArgumentBuffers); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setTAddressMode(MTL::SamplerAddressMode tAddressMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTAddressMode_), tAddressMode); +} + +_MTL_INLINE bool MTL::SamplerDescriptor::supportArgumentBuffers() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportArgumentBuffers)); +} + +_MTL_INLINE MTL::SamplerAddressMode MTL::SamplerDescriptor::tAddressMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tAddressMode)); +} + +_MTL_INLINE MTL::Device* MTL::SamplerState::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE MTL::ResourceID MTL::SamplerState::gpuResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuResourceID)); +} + +_MTL_INLINE NS::String* MTL::SamplerState::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLStageInputOutputDescriptor.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLStageInputOutputDescriptor.hpp new file mode 100644 index 00000000..b9a7a483 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLStageInputOutputDescriptor.hpp @@ -0,0 +1,356 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLStageInputOutputDescriptor.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLArgument.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class AttributeDescriptor; +class AttributeDescriptorArray; +class BufferLayoutDescriptor; +class BufferLayoutDescriptorArray; +class StageInputOutputDescriptor; +_MTL_ENUM(NS::UInteger, AttributeFormat) { + AttributeFormatInvalid = 0, + AttributeFormatUChar2 = 1, + AttributeFormatUChar3 = 2, + AttributeFormatUChar4 = 3, + AttributeFormatChar2 = 4, + AttributeFormatChar3 = 5, + AttributeFormatChar4 = 6, + AttributeFormatUChar2Normalized = 7, + AttributeFormatUChar3Normalized = 8, + AttributeFormatUChar4Normalized = 9, + AttributeFormatChar2Normalized = 10, + AttributeFormatChar3Normalized = 11, + AttributeFormatChar4Normalized = 12, + AttributeFormatUShort2 = 13, + AttributeFormatUShort3 = 14, + AttributeFormatUShort4 = 15, + AttributeFormatShort2 = 16, + AttributeFormatShort3 = 17, + AttributeFormatShort4 = 18, + AttributeFormatUShort2Normalized = 19, + AttributeFormatUShort3Normalized = 20, + AttributeFormatUShort4Normalized = 21, + AttributeFormatShort2Normalized = 22, + AttributeFormatShort3Normalized = 23, + AttributeFormatShort4Normalized = 24, + AttributeFormatHalf2 = 25, + AttributeFormatHalf3 = 26, + AttributeFormatHalf4 = 27, + AttributeFormatFloat = 28, + AttributeFormatFloat2 = 29, + AttributeFormatFloat3 = 30, + AttributeFormatFloat4 = 31, + AttributeFormatInt = 32, + AttributeFormatInt2 = 33, + AttributeFormatInt3 = 34, + AttributeFormatInt4 = 35, + AttributeFormatUInt = 36, + AttributeFormatUInt2 = 37, + AttributeFormatUInt3 = 38, + AttributeFormatUInt4 = 39, + AttributeFormatInt1010102Normalized = 40, + AttributeFormatUInt1010102Normalized = 41, + AttributeFormatUChar4Normalized_BGRA = 42, + AttributeFormatUChar = 45, + AttributeFormatChar = 46, + AttributeFormatUCharNormalized = 47, + AttributeFormatCharNormalized = 48, + AttributeFormatUShort = 49, + AttributeFormatShort = 50, + AttributeFormatUShortNormalized = 51, + AttributeFormatShortNormalized = 52, + AttributeFormatHalf = 53, + AttributeFormatFloatRG11B10 = 54, + AttributeFormatFloatRGB9E5 = 55, +}; + +_MTL_ENUM(NS::UInteger, StepFunction) { + StepFunctionConstant = 0, + StepFunctionPerVertex = 1, + StepFunctionPerInstance = 2, + StepFunctionPerPatch = 3, + StepFunctionPerPatchControlPoint = 4, + StepFunctionThreadPositionInGridX = 5, + StepFunctionThreadPositionInGridY = 6, + StepFunctionThreadPositionInGridXIndexed = 7, + StepFunctionThreadPositionInGridYIndexed = 8, +}; + +class BufferLayoutDescriptor : public NS::Copying +{ +public: + static BufferLayoutDescriptor* alloc(); + + BufferLayoutDescriptor* init(); + + void setStepFunction(MTL::StepFunction stepFunction); + + void setStepRate(NS::UInteger stepRate); + + void setStride(NS::UInteger stride); + + StepFunction stepFunction() const; + + NS::UInteger stepRate() const; + + NS::UInteger stride() const; +}; +class BufferLayoutDescriptorArray : public NS::Referencing +{ +public: + static BufferLayoutDescriptorArray* alloc(); + + BufferLayoutDescriptorArray* init(); + + BufferLayoutDescriptor* object(NS::UInteger index); + void setObject(const MTL::BufferLayoutDescriptor* bufferDesc, NS::UInteger index); +}; +class AttributeDescriptor : public NS::Copying +{ +public: + static AttributeDescriptor* alloc(); + + NS::UInteger bufferIndex() const; + + AttributeFormat format() const; + + AttributeDescriptor* init(); + + NS::UInteger offset() const; + + void setBufferIndex(NS::UInteger bufferIndex); + + void setFormat(MTL::AttributeFormat format); + + void setOffset(NS::UInteger offset); +}; +class AttributeDescriptorArray : public NS::Referencing +{ +public: + static AttributeDescriptorArray* alloc(); + + AttributeDescriptorArray* init(); + + AttributeDescriptor* object(NS::UInteger index); + void setObject(const MTL::AttributeDescriptor* attributeDesc, NS::UInteger index); +}; +class StageInputOutputDescriptor : public NS::Copying +{ +public: + static StageInputOutputDescriptor* alloc(); + + AttributeDescriptorArray* attributes() const; + + NS::UInteger indexBufferIndex() const; + + IndexType indexType() const; + + StageInputOutputDescriptor* init(); + + BufferLayoutDescriptorArray* layouts() const; + + void reset(); + + void setIndexBufferIndex(NS::UInteger indexBufferIndex); + + void setIndexType(MTL::IndexType indexType); + + static StageInputOutputDescriptor* stageInputOutputDescriptor(); +}; + +} +_MTL_INLINE MTL::BufferLayoutDescriptor* MTL::BufferLayoutDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLBufferLayoutDescriptor)); +} + +_MTL_INLINE MTL::BufferLayoutDescriptor* MTL::BufferLayoutDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::BufferLayoutDescriptor::setStepFunction(MTL::StepFunction stepFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStepFunction_), stepFunction); +} + +_MTL_INLINE void MTL::BufferLayoutDescriptor::setStepRate(NS::UInteger stepRate) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStepRate_), stepRate); +} + +_MTL_INLINE void MTL::BufferLayoutDescriptor::setStride(NS::UInteger stride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStride_), stride); +} + +_MTL_INLINE MTL::StepFunction MTL::BufferLayoutDescriptor::stepFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stepFunction)); +} + +_MTL_INLINE NS::UInteger MTL::BufferLayoutDescriptor::stepRate() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stepRate)); +} + +_MTL_INLINE NS::UInteger MTL::BufferLayoutDescriptor::stride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stride)); +} + +_MTL_INLINE MTL::BufferLayoutDescriptorArray* MTL::BufferLayoutDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLBufferLayoutDescriptorArray)); +} + +_MTL_INLINE MTL::BufferLayoutDescriptorArray* MTL::BufferLayoutDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::BufferLayoutDescriptor* MTL::BufferLayoutDescriptorArray::object(NS::UInteger index) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), index); +} + +_MTL_INLINE void MTL::BufferLayoutDescriptorArray::setObject(const MTL::BufferLayoutDescriptor* bufferDesc, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), bufferDesc, index); +} + +_MTL_INLINE MTL::AttributeDescriptor* MTL::AttributeDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAttributeDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::AttributeDescriptor::bufferIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferIndex)); +} + +_MTL_INLINE MTL::AttributeFormat MTL::AttributeDescriptor::format() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(format)); +} + +_MTL_INLINE MTL::AttributeDescriptor* MTL::AttributeDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL::AttributeDescriptor::offset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(offset)); +} + +_MTL_INLINE void MTL::AttributeDescriptor::setBufferIndex(NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBufferIndex_), bufferIndex); +} + +_MTL_INLINE void MTL::AttributeDescriptor::setFormat(MTL::AttributeFormat format) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFormat_), format); +} + +_MTL_INLINE void MTL::AttributeDescriptor::setOffset(NS::UInteger offset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOffset_), offset); +} + +_MTL_INLINE MTL::AttributeDescriptorArray* MTL::AttributeDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAttributeDescriptorArray)); +} + +_MTL_INLINE MTL::AttributeDescriptorArray* MTL::AttributeDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::AttributeDescriptor* MTL::AttributeDescriptorArray::object(NS::UInteger index) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), index); +} + +_MTL_INLINE void MTL::AttributeDescriptorArray::setObject(const MTL::AttributeDescriptor* attributeDesc, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), attributeDesc, index); +} + +_MTL_INLINE MTL::StageInputOutputDescriptor* MTL::StageInputOutputDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLStageInputOutputDescriptor)); +} + +_MTL_INLINE MTL::AttributeDescriptorArray* MTL::StageInputOutputDescriptor::attributes() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(attributes)); +} + +_MTL_INLINE NS::UInteger MTL::StageInputOutputDescriptor::indexBufferIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBufferIndex)); +} + +_MTL_INLINE MTL::IndexType MTL::StageInputOutputDescriptor::indexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE MTL::StageInputOutputDescriptor* MTL::StageInputOutputDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::BufferLayoutDescriptorArray* MTL::StageInputOutputDescriptor::layouts() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(layouts)); +} + +_MTL_INLINE void MTL::StageInputOutputDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL::StageInputOutputDescriptor::setIndexBufferIndex(NS::UInteger indexBufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBufferIndex_), indexBufferIndex); +} + +_MTL_INLINE void MTL::StageInputOutputDescriptor::setIndexType(MTL::IndexType indexType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexType_), indexType); +} + +_MTL_INLINE MTL::StageInputOutputDescriptor* MTL::StageInputOutputDescriptor::stageInputOutputDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLStageInputOutputDescriptor), _MTL_PRIVATE_SEL(stageInputOutputDescriptor)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLTensor.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLTensor.hpp new file mode 100644 index 00000000..5f8b04eb --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLTensor.hpp @@ -0,0 +1,297 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLTensor.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class Buffer; +class TensorDescriptor; +class TensorExtents; + +_MTL_CONST(NS::ErrorDomain, TensorDomain); + +_MTL_ENUM(NS::Integer, TensorDataType) { + TensorDataTypeNone = 0, + TensorDataTypeFloat32 = 3, + TensorDataTypeFloat16 = 16, + TensorDataTypeBFloat16 = 121, + TensorDataTypeInt8 = 45, + TensorDataTypeUInt8 = 49, + TensorDataTypeInt16 = 37, + TensorDataTypeUInt16 = 41, + TensorDataTypeInt32 = 29, + TensorDataTypeUInt32 = 33, +}; + +_MTL_ENUM(NS::Integer, TensorError) { + TensorErrorNone = 0, + TensorErrorInternalError = 1, + TensorErrorInvalidDescriptor = 2, +}; + +_MTL_OPTIONS(NS::UInteger, TensorUsage) { + TensorUsageCompute = 1, + TensorUsageRender = 1 << 1, + TensorUsageMachineLearning = 1 << 2, +}; + +class TensorExtents : public NS::Referencing +{ +public: + static TensorExtents* alloc(); + + NS::Integer extentAtDimensionIndex(NS::UInteger dimensionIndex); + + TensorExtents* init(); + TensorExtents* init(NS::UInteger rank, const NS::Integer* values); + + NS::UInteger rank() const; +}; +class TensorDescriptor : public NS::Copying +{ +public: + static TensorDescriptor* alloc(); + + CPUCacheMode cpuCacheMode() const; + + TensorDataType dataType() const; + + TensorExtents* dimensions() const; + + HazardTrackingMode hazardTrackingMode() const; + + TensorDescriptor* init(); + + ResourceOptions resourceOptions() const; + + void setCpuCacheMode(MTL::CPUCacheMode cpuCacheMode); + + void setDataType(MTL::TensorDataType dataType); + + void setDimensions(const MTL::TensorExtents* dimensions); + + void setHazardTrackingMode(MTL::HazardTrackingMode hazardTrackingMode); + + void setResourceOptions(MTL::ResourceOptions resourceOptions); + + void setStorageMode(MTL::StorageMode storageMode); + + void setStrides(const MTL::TensorExtents* strides); + + void setUsage(MTL::TensorUsage usage); + + StorageMode storageMode() const; + + TensorExtents* strides() const; + + TensorUsage usage() const; +}; +class Tensor : public NS::Referencing +{ +public: + Buffer* buffer() const; + NS::UInteger bufferOffset() const; + + TensorDataType dataType() const; + + TensorExtents* dimensions() const; + + void getBytes(void* bytes, const MTL::TensorExtents* strides, const MTL::TensorExtents* sliceOrigin, const MTL::TensorExtents* sliceDimensions); + + ResourceID gpuResourceID() const; + + void replaceSliceOrigin(const MTL::TensorExtents* sliceOrigin, const MTL::TensorExtents* sliceDimensions, const void* bytes, const MTL::TensorExtents* strides); + + TensorExtents* strides() const; + + TensorUsage usage() const; +}; + +} + +_MTL_PRIVATE_DEF_WEAK_CONST(NS::ErrorDomain, TensorDomain); + +_MTL_INLINE MTL::TensorExtents* MTL::TensorExtents::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLTensorExtents)); +} + +_MTL_INLINE NS::Integer MTL::TensorExtents::extentAtDimensionIndex(NS::UInteger dimensionIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(extentAtDimensionIndex_), dimensionIndex); +} + +_MTL_INLINE MTL::TensorExtents* MTL::TensorExtents::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::TensorExtents* MTL::TensorExtents::init(NS::UInteger rank, const NS::Integer* values) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(initWithRank_values_), rank, values); +} + +_MTL_INLINE NS::UInteger MTL::TensorExtents::rank() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rank)); +} + +_MTL_INLINE MTL::TensorDescriptor* MTL::TensorDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLTensorDescriptor)); +} + +_MTL_INLINE MTL::CPUCacheMode MTL::TensorDescriptor::cpuCacheMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(cpuCacheMode)); +} + +_MTL_INLINE MTL::TensorDataType MTL::TensorDescriptor::dataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dataType)); +} + +_MTL_INLINE MTL::TensorExtents* MTL::TensorDescriptor::dimensions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dimensions)); +} + +_MTL_INLINE MTL::HazardTrackingMode MTL::TensorDescriptor::hazardTrackingMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(hazardTrackingMode)); +} + +_MTL_INLINE MTL::TensorDescriptor* MTL::TensorDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::ResourceOptions MTL::TensorDescriptor::resourceOptions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resourceOptions)); +} + +_MTL_INLINE void MTL::TensorDescriptor::setCpuCacheMode(MTL::CPUCacheMode cpuCacheMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCpuCacheMode_), cpuCacheMode); +} + +_MTL_INLINE void MTL::TensorDescriptor::setDataType(MTL::TensorDataType dataType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDataType_), dataType); +} + +_MTL_INLINE void MTL::TensorDescriptor::setDimensions(const MTL::TensorExtents* dimensions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDimensions_), dimensions); +} + +_MTL_INLINE void MTL::TensorDescriptor::setHazardTrackingMode(MTL::HazardTrackingMode hazardTrackingMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setHazardTrackingMode_), hazardTrackingMode); +} + +_MTL_INLINE void MTL::TensorDescriptor::setResourceOptions(MTL::ResourceOptions resourceOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setResourceOptions_), resourceOptions); +} + +_MTL_INLINE void MTL::TensorDescriptor::setStorageMode(MTL::StorageMode storageMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStorageMode_), storageMode); +} + +_MTL_INLINE void MTL::TensorDescriptor::setStrides(const MTL::TensorExtents* strides) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStrides_), strides); +} + +_MTL_INLINE void MTL::TensorDescriptor::setUsage(MTL::TensorUsage usage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setUsage_), usage); +} + +_MTL_INLINE MTL::StorageMode MTL::TensorDescriptor::storageMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(storageMode)); +} + +_MTL_INLINE MTL::TensorExtents* MTL::TensorDescriptor::strides() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(strides)); +} + +_MTL_INLINE MTL::TensorUsage MTL::TensorDescriptor::usage() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(usage)); +} + +_MTL_INLINE MTL::Buffer* MTL::Tensor::buffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(buffer)); +} + +_MTL_INLINE NS::UInteger MTL::Tensor::bufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferOffset)); +} + +_MTL_INLINE MTL::TensorDataType MTL::Tensor::dataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dataType)); +} + +_MTL_INLINE MTL::TensorExtents* MTL::Tensor::dimensions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dimensions)); +} + +_MTL_INLINE void MTL::Tensor::getBytes(void* bytes, const MTL::TensorExtents* strides, const MTL::TensorExtents* sliceOrigin, const MTL::TensorExtents* sliceDimensions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(getBytes_strides_fromSliceOrigin_sliceDimensions_), bytes, strides, sliceOrigin, sliceDimensions); +} + +_MTL_INLINE MTL::ResourceID MTL::Tensor::gpuResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuResourceID)); +} + +_MTL_INLINE void MTL::Tensor::replaceSliceOrigin(const MTL::TensorExtents* sliceOrigin, const MTL::TensorExtents* sliceDimensions, const void* bytes, const MTL::TensorExtents* strides) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(replaceSliceOrigin_sliceDimensions_withBytes_strides_), sliceOrigin, sliceDimensions, bytes, strides); +} + +_MTL_INLINE MTL::TensorExtents* MTL::Tensor::strides() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(strides)); +} + +_MTL_INLINE MTL::TensorUsage MTL::Tensor::usage() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(usage)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLTexture.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLTexture.hpp new file mode 100644 index 00000000..631d9202 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLTexture.hpp @@ -0,0 +1,803 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLTexture.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPixelFormat.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL +{ +class Buffer; +class Device; +class Resource; +class SharedTextureHandle; +class Texture; +class TextureDescriptor; +class TextureViewDescriptor; +} + +namespace MTL +{ +_MTL_ENUM(NS::UInteger, TextureType) { + TextureType1D = 0, + TextureType1DArray = 1, + TextureType2D = 2, + TextureType2DArray = 3, + TextureType2DMultisample = 4, + TextureTypeCube = 5, + TextureTypeCubeArray = 6, + TextureType3D = 7, + TextureType2DMultisampleArray = 8, + TextureTypeTextureBuffer = 9, +}; + +_MTL_ENUM(uint8_t, TextureSwizzle) { + TextureSwizzleZero = 0, + TextureSwizzleOne = 1, + TextureSwizzleRed = 2, + TextureSwizzleGreen = 3, + TextureSwizzleBlue = 4, + TextureSwizzleAlpha = 5, +}; + +_MTL_ENUM(NS::Integer, TextureCompressionType) { + TextureCompressionTypeLossless = 0, + TextureCompressionTypeLossy = 1, +}; + +_MTL_OPTIONS(NS::UInteger, TextureUsage) { + TextureUsageUnknown = 0, + TextureUsageShaderRead = 1, + TextureUsageShaderWrite = 1 << 1, + TextureUsageRenderTarget = 1 << 2, + TextureUsagePixelFormatView = 1 << 4, + TextureUsageShaderAtomic = 1 << 5, +}; + +struct TextureSwizzleChannels +{ + + TextureSwizzleChannels(MTL::TextureSwizzle r, MTL::TextureSwizzle g, MTL::TextureSwizzle b, MTL::TextureSwizzle a); + + TextureSwizzleChannels(); + + static TextureSwizzleChannels Default(); + + static TextureSwizzleChannels Make(MTL::TextureSwizzle r, MTL::TextureSwizzle g, MTL::TextureSwizzle b, MTL::TextureSwizzle a); + + MTL::TextureSwizzle red; + MTL::TextureSwizzle green; + MTL::TextureSwizzle blue; + MTL::TextureSwizzle alpha; +} _MTL_PACKED; + +class SharedTextureHandle : public NS::SecureCoding +{ +public: + static SharedTextureHandle* alloc(); + + Device* device() const; + + SharedTextureHandle* init(); + + NS::String* label() const; +}; +class TextureDescriptor : public NS::Copying +{ +public: + static TextureDescriptor* alloc(); + + bool allowGPUOptimizedContents() const; + + NS::UInteger arrayLength() const; + + TextureCompressionType compressionType() const; + + CPUCacheMode cpuCacheMode() const; + + NS::UInteger depth() const; + + HazardTrackingMode hazardTrackingMode() const; + + NS::UInteger height() const; + + TextureDescriptor* init(); + + NS::UInteger mipmapLevelCount() const; + + PixelFormat pixelFormat() const; + + SparsePageSize placementSparsePageSize() const; + + ResourceOptions resourceOptions() const; + + NS::UInteger sampleCount() const; + + void setAllowGPUOptimizedContents(bool allowGPUOptimizedContents); + + void setArrayLength(NS::UInteger arrayLength); + + void setCompressionType(MTL::TextureCompressionType compressionType); + + void setCpuCacheMode(MTL::CPUCacheMode cpuCacheMode); + + void setDepth(NS::UInteger depth); + + void setHazardTrackingMode(MTL::HazardTrackingMode hazardTrackingMode); + + void setHeight(NS::UInteger height); + + void setMipmapLevelCount(NS::UInteger mipmapLevelCount); + + void setPixelFormat(MTL::PixelFormat pixelFormat); + + void setPlacementSparsePageSize(MTL::SparsePageSize placementSparsePageSize); + + void setResourceOptions(MTL::ResourceOptions resourceOptions); + + void setSampleCount(NS::UInteger sampleCount); + + void setStorageMode(MTL::StorageMode storageMode); + + void setSwizzle(MTL::TextureSwizzleChannels swizzle); + + void setTextureType(MTL::TextureType textureType); + + void setUsage(MTL::TextureUsage usage); + + void setWidth(NS::UInteger width); + + StorageMode storageMode() const; + + TextureSwizzleChannels swizzle() const; + + static TextureDescriptor* texture2DDescriptor(MTL::PixelFormat pixelFormat, NS::UInteger width, NS::UInteger height, bool mipmapped); + + static TextureDescriptor* textureBufferDescriptor(MTL::PixelFormat pixelFormat, NS::UInteger width, MTL::ResourceOptions resourceOptions, MTL::TextureUsage usage); + + static TextureDescriptor* textureCubeDescriptor(MTL::PixelFormat pixelFormat, NS::UInteger size, bool mipmapped); + + TextureType textureType() const; + + TextureUsage usage() const; + + NS::UInteger width() const; +}; +class TextureViewDescriptor : public NS::Copying +{ +public: + static TextureViewDescriptor* alloc(); + + TextureViewDescriptor* init(); + + NS::Range levelRange() const; + + PixelFormat pixelFormat() const; + + void setLevelRange(NS::Range levelRange); + + void setPixelFormat(MTL::PixelFormat pixelFormat); + + void setSliceRange(NS::Range sliceRange); + + void setSwizzle(MTL::TextureSwizzleChannels swizzle); + + void setTextureType(MTL::TextureType textureType); + + NS::Range sliceRange() const; + + TextureSwizzleChannels swizzle() const; + + TextureType textureType() const; +}; +class Texture : public NS::Referencing +{ +public: + bool allowGPUOptimizedContents() const; + + NS::UInteger arrayLength() const; + + Buffer* buffer() const; + NS::UInteger bufferBytesPerRow() const; + + NS::UInteger bufferOffset() const; + + TextureCompressionType compressionType() const; + + NS::UInteger depth() const; + + NS::UInteger firstMipmapInTail() const; + + [[deprecated("please use isFramebufferOnly instead")]] + bool framebufferOnly() const; + + void getBytes(void* pixelBytes, NS::UInteger bytesPerRow, NS::UInteger bytesPerImage, MTL::Region region, NS::UInteger level, NS::UInteger slice); + void getBytes(void* pixelBytes, NS::UInteger bytesPerRow, MTL::Region region, NS::UInteger level); + + ResourceID gpuResourceID() const; + + NS::UInteger height() const; + + IOSurfaceRef iosurface() const; + NS::UInteger iosurfacePlane() const; + + bool isFramebufferOnly() const; + + bool isShareable() const; + + bool isSparse() const; + + NS::UInteger mipmapLevelCount() const; + + Texture* newRemoteTextureViewForDevice(const MTL::Device* device); + + SharedTextureHandle* newSharedTextureHandle(); + + Texture* newTextureView(MTL::PixelFormat pixelFormat); + Texture* newTextureView(MTL::PixelFormat pixelFormat, MTL::TextureType textureType, NS::Range levelRange, NS::Range sliceRange); + Texture* newTextureView(const MTL::TextureViewDescriptor* descriptor); + Texture* newTextureView(MTL::PixelFormat pixelFormat, MTL::TextureType textureType, NS::Range levelRange, NS::Range sliceRange, MTL::TextureSwizzleChannels swizzle); + + NS::UInteger parentRelativeLevel() const; + + NS::UInteger parentRelativeSlice() const; + + Texture* parentTexture() const; + + PixelFormat pixelFormat() const; + + Texture* remoteStorageTexture() const; + + void replaceRegion(MTL::Region region, NS::UInteger level, NS::UInteger slice, const void* pixelBytes, NS::UInteger bytesPerRow, NS::UInteger bytesPerImage); + void replaceRegion(MTL::Region region, NS::UInteger level, const void* pixelBytes, NS::UInteger bytesPerRow); + + Resource* rootResource() const; + + NS::UInteger sampleCount() const; + + [[deprecated("please use isShareable instead")]] + bool shareable() const; + + TextureSparseTier sparseTextureTier() const; + + TextureSwizzleChannels swizzle() const; + + NS::UInteger tailSizeInBytes() const; + + TextureType textureType() const; + + TextureUsage usage() const; + + NS::UInteger width() const; +}; + +} +_MTL_INLINE MTL::TextureSwizzleChannels::TextureSwizzleChannels(MTL::TextureSwizzle r, MTL::TextureSwizzle g, MTL::TextureSwizzle b, MTL::TextureSwizzle a) + : red(r) + , green(g) + , blue(b) + , alpha(a) +{ +} + +_MTL_INLINE MTL::TextureSwizzleChannels::TextureSwizzleChannels() + : red(MTL::TextureSwizzleRed) + , green(MTL::TextureSwizzleGreen) + , blue(MTL::TextureSwizzleBlue) + , alpha(MTL::TextureSwizzleAlpha) +{ +} + +_MTL_INLINE MTL::TextureSwizzleChannels MTL::TextureSwizzleChannels::Default() +{ + return MTL::TextureSwizzleChannels(); +} + +_MTL_INLINE MTL::TextureSwizzleChannels MTL::TextureSwizzleChannels::Make(MTL::TextureSwizzle r, MTL::TextureSwizzle g, MTL::TextureSwizzle b, MTL::TextureSwizzle a) +{ + return TextureSwizzleChannels(r, g, b, a); +} + +_MTL_INLINE MTL::SharedTextureHandle* MTL::SharedTextureHandle::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLSharedTextureHandle)); +} + +_MTL_INLINE MTL::Device* MTL::SharedTextureHandle::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE MTL::SharedTextureHandle* MTL::SharedTextureHandle::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::SharedTextureHandle::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL::TextureDescriptor* MTL::TextureDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLTextureDescriptor)); +} + +_MTL_INLINE bool MTL::TextureDescriptor::allowGPUOptimizedContents() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allowGPUOptimizedContents)); +} + +_MTL_INLINE NS::UInteger MTL::TextureDescriptor::arrayLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(arrayLength)); +} + +_MTL_INLINE MTL::TextureCompressionType MTL::TextureDescriptor::compressionType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(compressionType)); +} + +_MTL_INLINE MTL::CPUCacheMode MTL::TextureDescriptor::cpuCacheMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(cpuCacheMode)); +} + +_MTL_INLINE NS::UInteger MTL::TextureDescriptor::depth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(depth)); +} + +_MTL_INLINE MTL::HazardTrackingMode MTL::TextureDescriptor::hazardTrackingMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(hazardTrackingMode)); +} + +_MTL_INLINE NS::UInteger MTL::TextureDescriptor::height() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(height)); +} + +_MTL_INLINE MTL::TextureDescriptor* MTL::TextureDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL::TextureDescriptor::mipmapLevelCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(mipmapLevelCount)); +} + +_MTL_INLINE MTL::PixelFormat MTL::TextureDescriptor::pixelFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(pixelFormat)); +} + +_MTL_INLINE MTL::SparsePageSize MTL::TextureDescriptor::placementSparsePageSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(placementSparsePageSize)); +} + +_MTL_INLINE MTL::ResourceOptions MTL::TextureDescriptor::resourceOptions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resourceOptions)); +} + +_MTL_INLINE NS::UInteger MTL::TextureDescriptor::sampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleCount)); +} + +_MTL_INLINE void MTL::TextureDescriptor::setAllowGPUOptimizedContents(bool allowGPUOptimizedContents) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAllowGPUOptimizedContents_), allowGPUOptimizedContents); +} + +_MTL_INLINE void MTL::TextureDescriptor::setArrayLength(NS::UInteger arrayLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setArrayLength_), arrayLength); +} + +_MTL_INLINE void MTL::TextureDescriptor::setCompressionType(MTL::TextureCompressionType compressionType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCompressionType_), compressionType); +} + +_MTL_INLINE void MTL::TextureDescriptor::setCpuCacheMode(MTL::CPUCacheMode cpuCacheMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCpuCacheMode_), cpuCacheMode); +} + +_MTL_INLINE void MTL::TextureDescriptor::setDepth(NS::UInteger depth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepth_), depth); +} + +_MTL_INLINE void MTL::TextureDescriptor::setHazardTrackingMode(MTL::HazardTrackingMode hazardTrackingMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setHazardTrackingMode_), hazardTrackingMode); +} + +_MTL_INLINE void MTL::TextureDescriptor::setHeight(NS::UInteger height) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setHeight_), height); +} + +_MTL_INLINE void MTL::TextureDescriptor::setMipmapLevelCount(NS::UInteger mipmapLevelCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMipmapLevelCount_), mipmapLevelCount); +} + +_MTL_INLINE void MTL::TextureDescriptor::setPixelFormat(MTL::PixelFormat pixelFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPixelFormat_), pixelFormat); +} + +_MTL_INLINE void MTL::TextureDescriptor::setPlacementSparsePageSize(MTL::SparsePageSize placementSparsePageSize) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPlacementSparsePageSize_), placementSparsePageSize); +} + +_MTL_INLINE void MTL::TextureDescriptor::setResourceOptions(MTL::ResourceOptions resourceOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setResourceOptions_), resourceOptions); +} + +_MTL_INLINE void MTL::TextureDescriptor::setSampleCount(NS::UInteger sampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSampleCount_), sampleCount); +} + +_MTL_INLINE void MTL::TextureDescriptor::setStorageMode(MTL::StorageMode storageMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStorageMode_), storageMode); +} + +_MTL_INLINE void MTL::TextureDescriptor::setSwizzle(MTL::TextureSwizzleChannels swizzle) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSwizzle_), swizzle); +} + +_MTL_INLINE void MTL::TextureDescriptor::setTextureType(MTL::TextureType textureType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTextureType_), textureType); +} + +_MTL_INLINE void MTL::TextureDescriptor::setUsage(MTL::TextureUsage usage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setUsage_), usage); +} + +_MTL_INLINE void MTL::TextureDescriptor::setWidth(NS::UInteger width) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setWidth_), width); +} + +_MTL_INLINE MTL::StorageMode MTL::TextureDescriptor::storageMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(storageMode)); +} + +_MTL_INLINE MTL::TextureSwizzleChannels MTL::TextureDescriptor::swizzle() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(swizzle)); +} + +_MTL_INLINE MTL::TextureDescriptor* MTL::TextureDescriptor::texture2DDescriptor(MTL::PixelFormat pixelFormat, NS::UInteger width, NS::UInteger height, bool mipmapped) +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLTextureDescriptor), _MTL_PRIVATE_SEL(texture2DDescriptorWithPixelFormat_width_height_mipmapped_), pixelFormat, width, height, mipmapped); +} + +_MTL_INLINE MTL::TextureDescriptor* MTL::TextureDescriptor::textureBufferDescriptor(MTL::PixelFormat pixelFormat, NS::UInteger width, MTL::ResourceOptions resourceOptions, MTL::TextureUsage usage) +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLTextureDescriptor), _MTL_PRIVATE_SEL(textureBufferDescriptorWithPixelFormat_width_resourceOptions_usage_), pixelFormat, width, resourceOptions, usage); +} + +_MTL_INLINE MTL::TextureDescriptor* MTL::TextureDescriptor::textureCubeDescriptor(MTL::PixelFormat pixelFormat, NS::UInteger size, bool mipmapped) +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLTextureDescriptor), _MTL_PRIVATE_SEL(textureCubeDescriptorWithPixelFormat_size_mipmapped_), pixelFormat, size, mipmapped); +} + +_MTL_INLINE MTL::TextureType MTL::TextureDescriptor::textureType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(textureType)); +} + +_MTL_INLINE MTL::TextureUsage MTL::TextureDescriptor::usage() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(usage)); +} + +_MTL_INLINE NS::UInteger MTL::TextureDescriptor::width() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(width)); +} + +_MTL_INLINE MTL::TextureViewDescriptor* MTL::TextureViewDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLTextureViewDescriptor)); +} + +_MTL_INLINE MTL::TextureViewDescriptor* MTL::TextureViewDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::Range MTL::TextureViewDescriptor::levelRange() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(levelRange)); +} + +_MTL_INLINE MTL::PixelFormat MTL::TextureViewDescriptor::pixelFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(pixelFormat)); +} + +_MTL_INLINE void MTL::TextureViewDescriptor::setLevelRange(NS::Range levelRange) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLevelRange_), levelRange); +} + +_MTL_INLINE void MTL::TextureViewDescriptor::setPixelFormat(MTL::PixelFormat pixelFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPixelFormat_), pixelFormat); +} + +_MTL_INLINE void MTL::TextureViewDescriptor::setSliceRange(NS::Range sliceRange) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSliceRange_), sliceRange); +} + +_MTL_INLINE void MTL::TextureViewDescriptor::setSwizzle(MTL::TextureSwizzleChannels swizzle) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSwizzle_), swizzle); +} + +_MTL_INLINE void MTL::TextureViewDescriptor::setTextureType(MTL::TextureType textureType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTextureType_), textureType); +} + +_MTL_INLINE NS::Range MTL::TextureViewDescriptor::sliceRange() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sliceRange)); +} + +_MTL_INLINE MTL::TextureSwizzleChannels MTL::TextureViewDescriptor::swizzle() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(swizzle)); +} + +_MTL_INLINE MTL::TextureType MTL::TextureViewDescriptor::textureType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(textureType)); +} + +_MTL_INLINE bool MTL::Texture::allowGPUOptimizedContents() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allowGPUOptimizedContents)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::arrayLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(arrayLength)); +} + +_MTL_INLINE MTL::Buffer* MTL::Texture::buffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(buffer)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::bufferBytesPerRow() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferBytesPerRow)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::bufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferOffset)); +} + +_MTL_INLINE MTL::TextureCompressionType MTL::Texture::compressionType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(compressionType)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::depth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(depth)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::firstMipmapInTail() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(firstMipmapInTail)); +} + +_MTL_INLINE bool MTL::Texture::framebufferOnly() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isFramebufferOnly)); +} + +_MTL_INLINE void MTL::Texture::getBytes(void* pixelBytes, NS::UInteger bytesPerRow, NS::UInteger bytesPerImage, MTL::Region region, NS::UInteger level, NS::UInteger slice) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(getBytes_bytesPerRow_bytesPerImage_fromRegion_mipmapLevel_slice_), pixelBytes, bytesPerRow, bytesPerImage, region, level, slice); +} + +_MTL_INLINE void MTL::Texture::getBytes(void* pixelBytes, NS::UInteger bytesPerRow, MTL::Region region, NS::UInteger level) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(getBytes_bytesPerRow_fromRegion_mipmapLevel_), pixelBytes, bytesPerRow, region, level); +} + +_MTL_INLINE MTL::ResourceID MTL::Texture::gpuResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuResourceID)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::height() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(height)); +} + +_MTL_INLINE IOSurfaceRef MTL::Texture::iosurface() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(iosurface)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::iosurfacePlane() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(iosurfacePlane)); +} + +_MTL_INLINE bool MTL::Texture::isFramebufferOnly() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isFramebufferOnly)); +} + +_MTL_INLINE bool MTL::Texture::isShareable() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isShareable)); +} + +_MTL_INLINE bool MTL::Texture::isSparse() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isSparse)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::mipmapLevelCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(mipmapLevelCount)); +} + +_MTL_INLINE MTL::Texture* MTL::Texture::newRemoteTextureViewForDevice(const MTL::Device* device) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRemoteTextureViewForDevice_), device); +} + +_MTL_INLINE MTL::SharedTextureHandle* MTL::Texture::newSharedTextureHandle() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newSharedTextureHandle)); +} + +_MTL_INLINE MTL::Texture* MTL::Texture::newTextureView(MTL::PixelFormat pixelFormat) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTextureViewWithPixelFormat_), pixelFormat); +} + +_MTL_INLINE MTL::Texture* MTL::Texture::newTextureView(MTL::PixelFormat pixelFormat, MTL::TextureType textureType, NS::Range levelRange, NS::Range sliceRange) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTextureViewWithPixelFormat_textureType_levels_slices_), pixelFormat, textureType, levelRange, sliceRange); +} + +_MTL_INLINE MTL::Texture* MTL::Texture::newTextureView(const MTL::TextureViewDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTextureViewWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::Texture* MTL::Texture::newTextureView(MTL::PixelFormat pixelFormat, MTL::TextureType textureType, NS::Range levelRange, NS::Range sliceRange, MTL::TextureSwizzleChannels swizzle) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTextureViewWithPixelFormat_textureType_levels_slices_swizzle_), pixelFormat, textureType, levelRange, sliceRange, swizzle); +} + +_MTL_INLINE NS::UInteger MTL::Texture::parentRelativeLevel() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(parentRelativeLevel)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::parentRelativeSlice() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(parentRelativeSlice)); +} + +_MTL_INLINE MTL::Texture* MTL::Texture::parentTexture() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(parentTexture)); +} + +_MTL_INLINE MTL::PixelFormat MTL::Texture::pixelFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(pixelFormat)); +} + +_MTL_INLINE MTL::Texture* MTL::Texture::remoteStorageTexture() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(remoteStorageTexture)); +} + +_MTL_INLINE void MTL::Texture::replaceRegion(MTL::Region region, NS::UInteger level, NS::UInteger slice, const void* pixelBytes, NS::UInteger bytesPerRow, NS::UInteger bytesPerImage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(replaceRegion_mipmapLevel_slice_withBytes_bytesPerRow_bytesPerImage_), region, level, slice, pixelBytes, bytesPerRow, bytesPerImage); +} + +_MTL_INLINE void MTL::Texture::replaceRegion(MTL::Region region, NS::UInteger level, const void* pixelBytes, NS::UInteger bytesPerRow) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(replaceRegion_mipmapLevel_withBytes_bytesPerRow_), region, level, pixelBytes, bytesPerRow); +} + +_MTL_INLINE MTL::Resource* MTL::Texture::rootResource() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rootResource)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::sampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleCount)); +} + +_MTL_INLINE bool MTL::Texture::shareable() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isShareable)); +} + +_MTL_INLINE MTL::TextureSparseTier MTL::Texture::sparseTextureTier() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sparseTextureTier)); +} + +_MTL_INLINE MTL::TextureSwizzleChannels MTL::Texture::swizzle() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(swizzle)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::tailSizeInBytes() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tailSizeInBytes)); +} + +_MTL_INLINE MTL::TextureType MTL::Texture::textureType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(textureType)); +} + +_MTL_INLINE MTL::TextureUsage MTL::Texture::usage() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(usage)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::width() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(width)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLTextureViewPool.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLTextureViewPool.hpp new file mode 100644 index 00000000..cb7556f5 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLTextureViewPool.hpp @@ -0,0 +1,59 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLTextureViewPool.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLResourceViewPool.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class Buffer; +class Texture; +class TextureDescriptor; +class TextureViewDescriptor; + +class TextureViewPool : public NS::Referencing +{ +public: + ResourceID setTextureView(const MTL::Texture* texture, NS::UInteger index); + ResourceID setTextureView(const MTL::Texture* texture, const MTL::TextureViewDescriptor* descriptor, NS::UInteger index); + ResourceID setTextureViewFromBuffer(const MTL::Buffer* buffer, const MTL::TextureDescriptor* descriptor, NS::UInteger offset, NS::UInteger bytesPerRow, NS::UInteger index); +}; + +} +_MTL_INLINE MTL::ResourceID MTL::TextureViewPool::setTextureView(const MTL::Texture* texture, NS::UInteger index) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(setTextureView_atIndex_), texture, index); +} + +_MTL_INLINE MTL::ResourceID MTL::TextureViewPool::setTextureView(const MTL::Texture* texture, const MTL::TextureViewDescriptor* descriptor, NS::UInteger index) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(setTextureView_descriptor_atIndex_), texture, descriptor, index); +} + +_MTL_INLINE MTL::ResourceID MTL::TextureViewPool::setTextureViewFromBuffer(const MTL::Buffer* buffer, const MTL::TextureDescriptor* descriptor, NS::UInteger offset, NS::UInteger bytesPerRow, NS::UInteger index) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(setTextureViewFromBuffer_descriptor_offset_bytesPerRow_atIndex_), buffer, descriptor, offset, bytesPerRow, index); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLTypes.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLTypes.hpp new file mode 100644 index 00000000..c6bbc031 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLTypes.hpp @@ -0,0 +1,164 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLTypes.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +struct SamplePosition; + +using Coordinate2D = MTL::SamplePosition; + +struct Origin +{ + Origin() = default; + + Origin(NS::UInteger x, NS::UInteger y, NS::UInteger z); + + static Origin Make(NS::UInteger x, NS::UInteger y, NS::UInteger z); + + NS::UInteger x; + NS::UInteger y; + NS::UInteger z; +} _MTL_PACKED; + +struct Size +{ + Size() = default; + + Size(NS::UInteger width, NS::UInteger height, NS::UInteger depth); + + static Size Make(NS::UInteger width, NS::UInteger height, NS::UInteger depth); + + NS::UInteger width; + NS::UInteger height; + NS::UInteger depth; +} _MTL_PACKED; + +struct Region +{ + Region() = default; + + Region(NS::UInteger x, NS::UInteger width); + + Region(NS::UInteger x, NS::UInteger y, NS::UInteger width, NS::UInteger height); + + Region(NS::UInteger x, NS::UInteger y, NS::UInteger z, NS::UInteger width, NS::UInteger height, NS::UInteger depth); + + static Region Make1D(NS::UInteger x, NS::UInteger width); + + static Region Make2D(NS::UInteger x, NS::UInteger y, NS::UInteger width, NS::UInteger height); + + static Region Make3D(NS::UInteger x, NS::UInteger y, NS::UInteger z, NS::UInteger width, NS::UInteger height, NS::UInteger depth); + + MTL::Origin origin; + MTL::Size size; +} _MTL_PACKED; + +struct SamplePosition +{ + SamplePosition() = default; + + SamplePosition(float x, float y); + + static SamplePosition Make(float x, float y); + + float x; + float y; +} _MTL_PACKED; + +struct ResourceID +{ + uint64_t _impl; +} _MTL_PACKED; + +} +_MTL_INLINE MTL::Origin::Origin(NS::UInteger x, NS::UInteger y, NS::UInteger z) + : x(x) + , y(y) + , z(z) +{ +} + +_MTL_INLINE MTL::Origin MTL::Origin::Make(NS::UInteger x, NS::UInteger y, NS::UInteger z) +{ + return Origin(x, y, z); +} + +_MTL_INLINE MTL::Size::Size(NS::UInteger width, NS::UInteger height, NS::UInteger depth) + : width(width) + , height(height) + , depth(depth) +{ +} + +_MTL_INLINE MTL::Size MTL::Size::Make(NS::UInteger width, NS::UInteger height, NS::UInteger depth) +{ + return Size(width, height, depth); +} + +_MTL_INLINE MTL::Region::Region(NS::UInteger x, NS::UInteger width) + : origin(x, 0, 0) + , size(width, 1, 1) +{ +} + +_MTL_INLINE MTL::Region::Region(NS::UInteger x, NS::UInteger y, NS::UInteger width, NS::UInteger height) + : origin(x, y, 0) + , size(width, height, 1) +{ +} + +_MTL_INLINE MTL::Region::Region(NS::UInteger x, NS::UInteger y, NS::UInteger z, NS::UInteger width, NS::UInteger height, NS::UInteger depth) + : origin(x, y, z) + , size(width, height, depth) +{ +} + +_MTL_INLINE MTL::Region MTL::Region::Make1D(NS::UInteger x, NS::UInteger width) +{ + return Region(x, width); +} + +_MTL_INLINE MTL::Region MTL::Region::Make2D(NS::UInteger x, NS::UInteger y, NS::UInteger width, NS::UInteger height) +{ + return Region(x, y, width, height); +} + +_MTL_INLINE MTL::Region MTL::Region::Make3D(NS::UInteger x, NS::UInteger y, NS::UInteger z, NS::UInteger width, NS::UInteger height, NS::UInteger depth) +{ + return Region(x, y, z, width, height, depth); +} + +_MTL_INLINE MTL::SamplePosition::SamplePosition(float x, float y) + : x(x) + , y(y) +{ +} + +_MTL_INLINE MTL::SamplePosition MTL::SamplePosition::Make(float x, float y) +{ + return SamplePosition(x, y); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLVersion.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLVersion.hpp new file mode 100644 index 00000000..d3503972 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLVersion.hpp @@ -0,0 +1,32 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLVersion.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#define METALCPP_VERSION_MAJOR 370 +#define METALCPP_VERSION_MINOR 63 +#define METALCPP_VERSION_PATCH 1 + +#define METALCPP_SUPPORTS_VERSION(major, minor, patch) \ + ((major < METALCPP_VERSION_MAJOR) || \ + (major == METALCPP_VERSION_MAJOR && minor < METALCPP_VERSION_MINOR) || \ + (major == METALCPP_VERSION_MAJOR && minor == METALCPP_VERSION_MINOR && patch <= METALCPP_VERSION_PATCH)) diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLVertexDescriptor.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLVertexDescriptor.hpp new file mode 100644 index 00000000..4a38f3bc --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLVertexDescriptor.hpp @@ -0,0 +1,326 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLVertexDescriptor.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class VertexAttributeDescriptor; +class VertexAttributeDescriptorArray; +class VertexBufferLayoutDescriptor; +class VertexBufferLayoutDescriptorArray; +class VertexDescriptor; +_MTL_ENUM(NS::UInteger, VertexFormat) { + VertexFormatInvalid = 0, + VertexFormatUChar2 = 1, + VertexFormatUChar3 = 2, + VertexFormatUChar4 = 3, + VertexFormatChar2 = 4, + VertexFormatChar3 = 5, + VertexFormatChar4 = 6, + VertexFormatUChar2Normalized = 7, + VertexFormatUChar3Normalized = 8, + VertexFormatUChar4Normalized = 9, + VertexFormatChar2Normalized = 10, + VertexFormatChar3Normalized = 11, + VertexFormatChar4Normalized = 12, + VertexFormatUShort2 = 13, + VertexFormatUShort3 = 14, + VertexFormatUShort4 = 15, + VertexFormatShort2 = 16, + VertexFormatShort3 = 17, + VertexFormatShort4 = 18, + VertexFormatUShort2Normalized = 19, + VertexFormatUShort3Normalized = 20, + VertexFormatUShort4Normalized = 21, + VertexFormatShort2Normalized = 22, + VertexFormatShort3Normalized = 23, + VertexFormatShort4Normalized = 24, + VertexFormatHalf2 = 25, + VertexFormatHalf3 = 26, + VertexFormatHalf4 = 27, + VertexFormatFloat = 28, + VertexFormatFloat2 = 29, + VertexFormatFloat3 = 30, + VertexFormatFloat4 = 31, + VertexFormatInt = 32, + VertexFormatInt2 = 33, + VertexFormatInt3 = 34, + VertexFormatInt4 = 35, + VertexFormatUInt = 36, + VertexFormatUInt2 = 37, + VertexFormatUInt3 = 38, + VertexFormatUInt4 = 39, + VertexFormatInt1010102Normalized = 40, + VertexFormatUInt1010102Normalized = 41, + VertexFormatUChar4Normalized_BGRA = 42, + VertexFormatUChar = 45, + VertexFormatChar = 46, + VertexFormatUCharNormalized = 47, + VertexFormatCharNormalized = 48, + VertexFormatUShort = 49, + VertexFormatShort = 50, + VertexFormatUShortNormalized = 51, + VertexFormatShortNormalized = 52, + VertexFormatHalf = 53, + VertexFormatFloatRG11B10 = 54, + VertexFormatFloatRGB9E5 = 55, +}; + +_MTL_ENUM(NS::UInteger, VertexStepFunction) { + VertexStepFunctionConstant = 0, + VertexStepFunctionPerVertex = 1, + VertexStepFunctionPerInstance = 2, + VertexStepFunctionPerPatch = 3, + VertexStepFunctionPerPatchControlPoint = 4, +}; + +static const NS::UInteger BufferLayoutStrideDynamic = NS::UIntegerMax; + +class VertexBufferLayoutDescriptor : public NS::Copying +{ +public: + static VertexBufferLayoutDescriptor* alloc(); + + VertexBufferLayoutDescriptor* init(); + + void setStepFunction(MTL::VertexStepFunction stepFunction); + + void setStepRate(NS::UInteger stepRate); + + void setStride(NS::UInteger stride); + + VertexStepFunction stepFunction() const; + + NS::UInteger stepRate() const; + + NS::UInteger stride() const; +}; +class VertexBufferLayoutDescriptorArray : public NS::Referencing +{ +public: + static VertexBufferLayoutDescriptorArray* alloc(); + + VertexBufferLayoutDescriptorArray* init(); + + VertexBufferLayoutDescriptor* object(NS::UInteger index); + void setObject(const MTL::VertexBufferLayoutDescriptor* bufferDesc, NS::UInteger index); +}; +class VertexAttributeDescriptor : public NS::Copying +{ +public: + static VertexAttributeDescriptor* alloc(); + + NS::UInteger bufferIndex() const; + + VertexFormat format() const; + + VertexAttributeDescriptor* init(); + + NS::UInteger offset() const; + + void setBufferIndex(NS::UInteger bufferIndex); + + void setFormat(MTL::VertexFormat format); + + void setOffset(NS::UInteger offset); +}; +class VertexAttributeDescriptorArray : public NS::Referencing +{ +public: + static VertexAttributeDescriptorArray* alloc(); + + VertexAttributeDescriptorArray* init(); + + VertexAttributeDescriptor* object(NS::UInteger index); + void setObject(const MTL::VertexAttributeDescriptor* attributeDesc, NS::UInteger index); +}; +class VertexDescriptor : public NS::Copying +{ +public: + static VertexDescriptor* alloc(); + + VertexAttributeDescriptorArray* attributes() const; + + VertexDescriptor* init(); + + VertexBufferLayoutDescriptorArray* layouts() const; + + void reset(); + + static VertexDescriptor* vertexDescriptor(); +}; + +} + +_MTL_INLINE MTL::VertexBufferLayoutDescriptor* MTL::VertexBufferLayoutDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLVertexBufferLayoutDescriptor)); +} + +_MTL_INLINE MTL::VertexBufferLayoutDescriptor* MTL::VertexBufferLayoutDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::VertexBufferLayoutDescriptor::setStepFunction(MTL::VertexStepFunction stepFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStepFunction_), stepFunction); +} + +_MTL_INLINE void MTL::VertexBufferLayoutDescriptor::setStepRate(NS::UInteger stepRate) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStepRate_), stepRate); +} + +_MTL_INLINE void MTL::VertexBufferLayoutDescriptor::setStride(NS::UInteger stride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStride_), stride); +} + +_MTL_INLINE MTL::VertexStepFunction MTL::VertexBufferLayoutDescriptor::stepFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stepFunction)); +} + +_MTL_INLINE NS::UInteger MTL::VertexBufferLayoutDescriptor::stepRate() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stepRate)); +} + +_MTL_INLINE NS::UInteger MTL::VertexBufferLayoutDescriptor::stride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stride)); +} + +_MTL_INLINE MTL::VertexBufferLayoutDescriptorArray* MTL::VertexBufferLayoutDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLVertexBufferLayoutDescriptorArray)); +} + +_MTL_INLINE MTL::VertexBufferLayoutDescriptorArray* MTL::VertexBufferLayoutDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::VertexBufferLayoutDescriptor* MTL::VertexBufferLayoutDescriptorArray::object(NS::UInteger index) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), index); +} + +_MTL_INLINE void MTL::VertexBufferLayoutDescriptorArray::setObject(const MTL::VertexBufferLayoutDescriptor* bufferDesc, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), bufferDesc, index); +} + +_MTL_INLINE MTL::VertexAttributeDescriptor* MTL::VertexAttributeDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLVertexAttributeDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::VertexAttributeDescriptor::bufferIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferIndex)); +} + +_MTL_INLINE MTL::VertexFormat MTL::VertexAttributeDescriptor::format() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(format)); +} + +_MTL_INLINE MTL::VertexAttributeDescriptor* MTL::VertexAttributeDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL::VertexAttributeDescriptor::offset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(offset)); +} + +_MTL_INLINE void MTL::VertexAttributeDescriptor::setBufferIndex(NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBufferIndex_), bufferIndex); +} + +_MTL_INLINE void MTL::VertexAttributeDescriptor::setFormat(MTL::VertexFormat format) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFormat_), format); +} + +_MTL_INLINE void MTL::VertexAttributeDescriptor::setOffset(NS::UInteger offset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOffset_), offset); +} + +_MTL_INLINE MTL::VertexAttributeDescriptorArray* MTL::VertexAttributeDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLVertexAttributeDescriptorArray)); +} + +_MTL_INLINE MTL::VertexAttributeDescriptorArray* MTL::VertexAttributeDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::VertexAttributeDescriptor* MTL::VertexAttributeDescriptorArray::object(NS::UInteger index) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), index); +} + +_MTL_INLINE void MTL::VertexAttributeDescriptorArray::setObject(const MTL::VertexAttributeDescriptor* attributeDesc, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), attributeDesc, index); +} + +_MTL_INLINE MTL::VertexDescriptor* MTL::VertexDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLVertexDescriptor)); +} + +_MTL_INLINE MTL::VertexAttributeDescriptorArray* MTL::VertexDescriptor::attributes() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(attributes)); +} + +_MTL_INLINE MTL::VertexDescriptor* MTL::VertexDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::VertexBufferLayoutDescriptorArray* MTL::VertexDescriptor::layouts() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(layouts)); +} + +_MTL_INLINE void MTL::VertexDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE MTL::VertexDescriptor* MTL::VertexDescriptor::vertexDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLVertexDescriptor), _MTL_PRIVATE_SEL(vertexDescriptor)); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/MTLVisibleFunctionTable.hpp b/Source/Cxxmlx/metal-cpp/Metal/MTLVisibleFunctionTable.hpp new file mode 100644 index 00000000..de144ea2 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/MTLVisibleFunctionTable.hpp @@ -0,0 +1,96 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLVisibleFunctionTable.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class FunctionHandle; +class VisibleFunctionTableDescriptor; + +class VisibleFunctionTableDescriptor : public NS::Copying +{ +public: + static VisibleFunctionTableDescriptor* alloc(); + + NS::UInteger functionCount() const; + + VisibleFunctionTableDescriptor* init(); + + void setFunctionCount(NS::UInteger functionCount); + + static VisibleFunctionTableDescriptor* visibleFunctionTableDescriptor(); +}; +class VisibleFunctionTable : public NS::Referencing +{ +public: + ResourceID gpuResourceID() const; + + void setFunction(const MTL::FunctionHandle* function, NS::UInteger index); + void setFunctions(const MTL::FunctionHandle* const functions[], NS::Range range); +}; + +} +_MTL_INLINE MTL::VisibleFunctionTableDescriptor* MTL::VisibleFunctionTableDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLVisibleFunctionTableDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::VisibleFunctionTableDescriptor::functionCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionCount)); +} + +_MTL_INLINE MTL::VisibleFunctionTableDescriptor* MTL::VisibleFunctionTableDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::VisibleFunctionTableDescriptor::setFunctionCount(NS::UInteger functionCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctionCount_), functionCount); +} + +_MTL_INLINE MTL::VisibleFunctionTableDescriptor* MTL::VisibleFunctionTableDescriptor::visibleFunctionTableDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLVisibleFunctionTableDescriptor), _MTL_PRIVATE_SEL(visibleFunctionTableDescriptor)); +} + +_MTL_INLINE MTL::ResourceID MTL::VisibleFunctionTable::gpuResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuResourceID)); +} + +_MTL_INLINE void MTL::VisibleFunctionTable::setFunction(const MTL::FunctionHandle* function, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunction_atIndex_), function, index); +} + +_MTL_INLINE void MTL::VisibleFunctionTable::setFunctions(const MTL::FunctionHandle* const functions[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctions_withRange_), functions, range); +} diff --git a/Source/Cxxmlx/metal-cpp/Metal/Metal.hpp b/Source/Cxxmlx/metal-cpp/Metal/Metal.hpp new file mode 100644 index 00000000..0d89cc04 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/Metal/Metal.hpp @@ -0,0 +1,120 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/Metal.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLAccelerationStructure.hpp" +#include "MTLAccelerationStructureCommandEncoder.hpp" +#include "MTLAccelerationStructureTypes.hpp" +#include "MTLAllocation.hpp" +#include "MTLArgument.hpp" +#include "MTLArgumentEncoder.hpp" +#include "MTLBinaryArchive.hpp" +#include "MTLBlitCommandEncoder.hpp" +#include "MTLBlitPass.hpp" +#include "MTLBuffer.hpp" +#include "MTLCaptureManager.hpp" +#include "MTLCaptureScope.hpp" +#include "MTLCommandBuffer.hpp" +#include "MTLCommandEncoder.hpp" +#include "MTLCommandQueue.hpp" +#include "MTLComputeCommandEncoder.hpp" +#include "MTLComputePass.hpp" +#include "MTLComputePipeline.hpp" +#include "MTLCounters.hpp" +#include "MTLDefines.hpp" +#include "MTLDepthStencil.hpp" +#include "MTLDevice.hpp" +#include "MTLDrawable.hpp" +#include "MTLDynamicLibrary.hpp" +#include "MTLEvent.hpp" +#include "MTLFence.hpp" +#include "MTLFunctionConstantValues.hpp" +#include "MTLFunctionDescriptor.hpp" +#include "MTLFunctionHandle.hpp" +#include "MTLFunctionLog.hpp" +#include "MTLFunctionStitching.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLHeap.hpp" +#include "MTLIndirectCommandBuffer.hpp" +#include "MTLIndirectCommandEncoder.hpp" +#include "MTLIntersectionFunctionTable.hpp" +#include "MTLIOCommandBuffer.hpp" +#include "MTLIOCommandQueue.hpp" +#include "MTLIOCompressor.hpp" +#include "MTLLibrary.hpp" +#include "MTLLinkedFunctions.hpp" +#include "MTLLogState.hpp" +#include "MTLParallelRenderCommandEncoder.hpp" +#include "MTLPipeline.hpp" +#include "MTLPixelFormat.hpp" +#include "MTLPrivate.hpp" +#include "MTLRasterizationRate.hpp" +#include "MTLRenderCommandEncoder.hpp" +#include "MTLRenderPass.hpp" +#include "MTLRenderPipeline.hpp" +#include "MTLResidencySet.hpp" +#include "MTLResource.hpp" +#include "MTLResourceStateCommandEncoder.hpp" +#include "MTLResourceStatePass.hpp" +#include "MTLSampler.hpp" +#include "MTLStageInputOutputDescriptor.hpp" +#include "MTLTexture.hpp" +#include "MTLTypes.hpp" +#include "MTLVertexDescriptor.hpp" +#include "MTLVisibleFunctionTable.hpp" +#include "MTLVersion.hpp" +#include "MTLTensor.hpp" +#include "MTLResourceViewPool.hpp" +#include "MTLTextureViewPool.hpp" +#include "MTLDataType.hpp" +#include "MTL4ArgumentTable.hpp" +#include "MTL4BinaryFunction.hpp" +#include "MTL4CommandAllocator.hpp" +#include "MTL4CommandBuffer.hpp" +#include "MTL4CommandEncoder.hpp" +#include "MTL4CommandQueue.hpp" +#include "MTL4Counters.hpp" +#include "MTL4RenderPass.hpp" +#include "MTL4RenderCommandEncoder.hpp" +#include "MTL4ComputeCommandEncoder.hpp" +#include "MTL4MachineLearningCommandEncoder.hpp" +#include "MTL4Compiler.hpp" +#include "MTL4CompilerTask.hpp" +#include "MTL4LibraryDescriptor.hpp" +#include "MTL4FunctionDescriptor.hpp" +#include "MTL4LibraryFunctionDescriptor.hpp" +#include "MTL4SpecializedFunctionDescriptor.hpp" +#include "MTL4StitchedFunctionDescriptor.hpp" +#include "MTL4PipelineState.hpp" +#include "MTL4ComputePipeline.hpp" +#include "MTL4RenderPipeline.hpp" +#include "MTL4MachineLearningPipeline.hpp" +#include "MTL4TileRenderPipeline.hpp" +#include "MTL4MeshRenderPipeline.hpp" +#include "MTL4PipelineDataSetSerializer.hpp" +#include "MTL4Archive.hpp" +#include "MTL4CommitFeedback.hpp" +#include "MTL4BinaryFunctionDescriptor.hpp" +#include "MTL4LinkingDescriptor.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/MetalFX/MTL4FXFrameInterpolator.hpp b/Source/Cxxmlx/metal-cpp/MetalFX/MTL4FXFrameInterpolator.hpp new file mode 100644 index 00000000..1c50ec9e --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/MetalFX/MTL4FXFrameInterpolator.hpp @@ -0,0 +1,47 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MTL4FXFrameInterpolator.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLFXDefines.hpp" +#include "MTLFXPrivate.hpp" + +#include "MTLFXFrameInterpolator.hpp" +#include "../Metal/Metal.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL4FX +{ + class FrameInterpolator : public NS::Referencing + { + public: + void encodeToCommandBuffer(MTL4::CommandBuffer* commandBuffer); + }; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTL4FX::FrameInterpolator::encodeToCommandBuffer(MTL4::CommandBuffer* commandBuffer) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( encodeToCommandBuffer_ ), commandBuffer ); +} diff --git a/Source/Cxxmlx/metal-cpp/MetalFX/MTL4FXSpatialScaler.hpp b/Source/Cxxmlx/metal-cpp/MetalFX/MTL4FXSpatialScaler.hpp new file mode 100644 index 00000000..8ea8dfdd --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/MetalFX/MTL4FXSpatialScaler.hpp @@ -0,0 +1,49 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MTL4FXSpatialScaler.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLFXDefines.hpp" +#include "MTLFXPrivate.hpp" + +#include "MTLFXSpatialScaler.hpp" +#include "../Metal/Metal.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL4FX +{ + class SpatialScaler : public NS::Referencing< SpatialScaler, MTLFX::SpatialScalerBase > + { + public: + void encodeToCommandBuffer( MTL4::CommandBuffer* pCommandBuffer ); + }; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTL4FX::SpatialScaler::encodeToCommandBuffer( MTL4::CommandBuffer* pCommandBuffer ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( encodeToCommandBuffer_ ), pCommandBuffer ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/MetalFX/MTL4FXTemporalDenoisedScaler.hpp b/Source/Cxxmlx/metal-cpp/MetalFX/MTL4FXTemporalDenoisedScaler.hpp new file mode 100644 index 00000000..73014bbc --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/MetalFX/MTL4FXTemporalDenoisedScaler.hpp @@ -0,0 +1,49 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MTLFXTemporalDenoisedScaler.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLFXDefines.hpp" +#include "MTLFXPrivate.hpp" + +#include "MTLFXTemporalDenoisedScaler.hpp" +#include "../Metal/Metal.hpp" + +#include + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL4FX +{ + class TemporalDenoisedScaler : public NS::Referencing< TemporalDenoisedScaler, MTLFX::TemporalDenoisedScalerBase > + { + public: + void encodeToCommandBuffer(MTL4::CommandBuffer* commandBuffer); + }; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTL4FX::TemporalDenoisedScaler::encodeToCommandBuffer( MTL4::CommandBuffer* commandBuffer ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( encodeToCommandBuffer_ ), commandBuffer ); +} diff --git a/Source/Cxxmlx/metal-cpp/MetalFX/MTL4FXTemporalScaler.hpp b/Source/Cxxmlx/metal-cpp/MetalFX/MTL4FXTemporalScaler.hpp new file mode 100644 index 00000000..3bda5dca --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/MetalFX/MTL4FXTemporalScaler.hpp @@ -0,0 +1,49 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MTL4FXTemporalScaler.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLFXDefines.hpp" +#include "MTLFXPrivate.hpp" + +#include "MTLFXTemporalScaler.hpp" +#include "../Metal/Metal.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL4FX +{ + class TemporalScaler : public NS::Referencing< TemporalScaler, MTLFX::TemporalScalerBase > + { + public: + void encodeToCommandBuffer( MTL4::CommandBuffer* pCommandBuffer ); + }; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTL4FX::TemporalScaler::encodeToCommandBuffer( MTL4::CommandBuffer* pCommandBuffer ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( encodeToCommandBuffer_ ), pCommandBuffer ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/MetalFX/MTLFXDefines.hpp b/Source/Cxxmlx/metal-cpp/MetalFX/MTLFXDefines.hpp new file mode 100644 index 00000000..320e0aa8 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/MetalFX/MTLFXDefines.hpp @@ -0,0 +1,41 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MTLFXDefines.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "../Foundation/NSDefines.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#define _MTLFX_EXPORT _NS_EXPORT +#define _MTLFX_EXTERN _NS_EXTERN +#define _MTLFX_INLINE _NS_INLINE +#define _MTLFX_PACKED _NS_PACKED + +#define _MTLFX_CONST( type, name ) _NS_CONST( type, name ) +#define _MTLFX_ENUM( type, name ) _NS_ENUM( type, name ) +#define _MTLFX_OPTIONS( type, name ) _NS_OPTIONS( type, name ) + +#define _MTLFX_VALIDATE_SIZE( mtlfx, name ) _NS_VALIDATE_SIZE( mtlfx, name ) +#define _MTLFX_VALIDATE_ENUM( mtlfx, name ) _NS_VALIDATE_ENUM( mtlfx, name ) + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/MetalFX/MTLFXFrameInterpolator.hpp b/Source/Cxxmlx/metal-cpp/MetalFX/MTLFXFrameInterpolator.hpp new file mode 100644 index 00000000..10ff69cb --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/MetalFX/MTLFXFrameInterpolator.hpp @@ -0,0 +1,719 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MTLFXFrameInterpolator.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLFXDefines.hpp" +#include "MTLFXPrivate.hpp" +#include "MTLFXTemporalScaler.hpp" + +#include "../Metal/Metal.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL4FX +{ + class TemporalScaler; + class TemporalDenoisedScaler; + class FrameInterpolator; +} + +namespace MTLFX +{ + class FrameInterpolatorDescriptor : public NS::Copying< FrameInterpolatorDescriptor > + { + public: + static FrameInterpolatorDescriptor* alloc(); + FrameInterpolatorDescriptor* init(); + + MTL::PixelFormat colorTextureFormat() const; + void setColorTextureFormat(MTL::PixelFormat colorTextureFormat); + + MTL::PixelFormat outputTextureFormat() const; + void setOutputTextureFormat(MTL::PixelFormat outputTextureFormat); + + MTL::PixelFormat depthTextureFormat() const; + void setDepthTextureFormat(MTL::PixelFormat depthTextureFormat); + + MTL::PixelFormat motionTextureFormat() const; + void setMotionTextureFormat(MTL::PixelFormat motionTextureFormat); + + MTL::PixelFormat uiTextureFormat() const; + void setUITextureFormat(MTL::PixelFormat uiTextureFormat); + + MTLFX::FrameInterpolatableScaler* scaler() const; + void setScaler(MTLFX::FrameInterpolatableScaler* scaler); + + NS::UInteger inputWidth() const; + void setInputWidth( NS::UInteger inputWidth ); + + NS::UInteger inputHeight() const; + void setInputHeight( NS::UInteger inputHeight ); + + NS::UInteger outputWidth() const; + void setOutputWidth( NS::UInteger outputWidth ); + + NS::UInteger outputHeight() const; + void setOutputHeight( NS::UInteger outputHeight ); + + class FrameInterpolator* newFrameInterpolator( const MTL::Device* pDevice) const; + MTL4FX::FrameInterpolator* newFrameInterpolator( const MTL::Device* pDevice, const MTL4::Compiler* pCompiler) const; + + static bool supportsMetal4FX(MTL::Device* device); + static bool supportsDevice(MTL::Device* device); + }; + + class FrameInterpolatorBase : public NS::Referencing + { + public: + MTL::TextureUsage colorTextureUsage() const; + MTL::TextureUsage outputTextureUsage() const; + MTL::TextureUsage depthTextureUsage() const; + MTL::TextureUsage motionTextureUsage() const; + MTL::TextureUsage uiTextureUsage() const; + + MTL::PixelFormat colorTextureFormat() const; + MTL::PixelFormat depthTextureFormat() const; + MTL::PixelFormat motionTextureFormat() const; + MTL::PixelFormat outputTextureFormat() const; + + NS::UInteger inputWidth() const; + NS::UInteger inputHeight() const; + NS::UInteger outputWidth() const; + NS::UInteger outputHeight() const; + MTL::PixelFormat uiTextureFormat() const; + + MTL::Texture* colorTexture() const; + void setColorTexture(MTL::Texture* colorTexture); + + MTL::Texture* prevColorTexture() const; + void setPrevColorTexture(MTL::Texture* prevColorTexture); + + MTL::Texture* depthTexture() const; + void setDepthTexture(MTL::Texture* depthTexture); + + MTL::Texture* motionTexture() const; + void setMotionTexture(MTL::Texture* motionTexture); + + float motionVectorScaleX() const; + void setMotionVectorScaleX(float scaleX); + + float motionVectorScaleY() const; + void setMotionVectorScaleY(float scaleY); + + float deltaTime() const; + void setDeltaTime( float deltaTime ); + + float nearPlane() const; + void setNearPlane( float nearPlane ); + + float farPlane() const; + void setFarPlane( float farPlane ); + + float fieldOfView() const; + void setFieldOfView( float fieldOfView ); + + float aspectRatio() const; + void setAspectRatio( float aspectRatio ); + + MTL::Texture* uiTexture() const; + void setUITexture(MTL::Texture* uiTexture); + + float jitterOffsetX() const; + void setJitterOffsetX( float jitterOffsetX ); + + float jitterOffsetY() const; + void setJitterOffsetY( float jitterOffsetY ); + + bool isUITextureComposited() const; + void setIsUITextureComposited( bool uiTextureComposited ); + + bool shouldResetHistory() const; + void setShouldResetHistory( bool shouldResetHistory ); + + MTL::Texture* outputTexture() const; + void setOutputTexture( MTL::Texture* outputTexture ); + + MTL::Fence* fence() const; + void setFence( MTL::Fence* fence ); + + bool isDepthReversed() const; + void setDepthReversed( bool depthReversed ); + }; + + class FrameInterpolator : public NS::Referencing + { + public: + void encodeToCommandBuffer(MTL::CommandBuffer* commandBuffer); + }; + +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::FrameInterpolatorDescriptor* MTLFX::FrameInterpolatorDescriptor::alloc() +{ + return NS::Object::alloc< FrameInterpolatorDescriptor >( _MTLFX_PRIVATE_CLS( MTLFXFrameInterpolatorDescriptor ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::FrameInterpolatorDescriptor* MTLFX::FrameInterpolatorDescriptor::init() +{ + return NS::Object::init< FrameInterpolatorDescriptor >(); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::FrameInterpolatorDescriptor::colorTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( colorTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorDescriptor::setColorTextureFormat( MTL::PixelFormat colorTextureFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setColorTextureFormat_ ), colorTextureFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::FrameInterpolatorDescriptor::outputTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( outputTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorDescriptor::setOutputTextureFormat( MTL::PixelFormat outputTextureFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputTextureFormat_ ), outputTextureFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::FrameInterpolatorDescriptor::depthTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( depthTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorDescriptor::setDepthTextureFormat( MTL::PixelFormat depthTextureFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDepthTextureFormat_ ), depthTextureFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::FrameInterpolatorDescriptor::motionTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( motionTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorDescriptor::setMotionTextureFormat( MTL::PixelFormat motionTextureFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionTextureFormat_ ), motionTextureFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::FrameInterpolatorDescriptor::uiTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( uiTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorDescriptor::setUITextureFormat( MTL::PixelFormat uiTextureFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setUITextureFormat_ ), uiTextureFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::FrameInterpolatableScaler* MTLFX::FrameInterpolatorDescriptor::scaler() const +{ + return NS::Object::sendMessage< MTLFX::FrameInterpolatableScaler* >( this, _MTLFX_PRIVATE_SEL( scaler ) ); +} + +_MTLFX_INLINE void MTLFX::FrameInterpolatorDescriptor::setScaler(MTLFX::FrameInterpolatableScaler* scaler) +{ + NS::Object::sendMessage< void >(this, _MTLFX_PRIVATE_SEL( setScaler_ ), scaler ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::FrameInterpolatorDescriptor::inputWidth() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorDescriptor::setInputWidth( NS::UInteger inputWidth ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputWidth_ ), inputWidth ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::FrameInterpolatorDescriptor::inputHeight() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorDescriptor::setInputHeight( NS::UInteger inputHeight ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputHeight_ ), inputHeight ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::FrameInterpolatorDescriptor::outputWidth() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorDescriptor::setOutputWidth( NS::UInteger outputWidth ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputWidth_ ), outputWidth ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::FrameInterpolatorDescriptor::outputHeight() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorDescriptor::setOutputHeight( NS::UInteger outputHeight ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputHeight_ ), outputHeight ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::FrameInterpolator* MTLFX::FrameInterpolatorDescriptor::newFrameInterpolator( const MTL::Device* device ) const +{ + return NS::Object::sendMessage< MTLFX::FrameInterpolator* >( this, _MTLFX_PRIVATE_SEL( newFrameInterpolatorWithDevice_ ), device ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL4FX::FrameInterpolator* MTLFX::FrameInterpolatorDescriptor::newFrameInterpolator( const MTL::Device* device, const MTL4::Compiler* compiler ) const +{ + return NS::Object::sendMessage< MTL4FX::FrameInterpolator* >( this, _MTLFX_PRIVATE_SEL( newFrameInterpolatorWithDevice_compiler_ ), device, compiler ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::FrameInterpolatorDescriptor::supportsMetal4FX(MTL::Device* device) +{ + return NS::Object::sendMessageSafe< bool >( _MTLFX_PRIVATE_CLS(MTLFXFrameInterpolatorDescriptor), _MTLFX_PRIVATE_SEL( supportsMetal4FX_ ), device ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::FrameInterpolatorDescriptor::supportsDevice(MTL::Device* device) +{ + return NS::Object::sendMessageSafe< bool >( _MTLFX_PRIVATE_CLS(MTLFXFrameInterpolatorDescriptor), _MTLFX_PRIVATE_SEL( supportsDevice_ ), device ); +} + + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::FrameInterpolatorBase::colorTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( colorTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::FrameInterpolatorBase::outputTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( outputTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::FrameInterpolatorBase::depthTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( depthTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::FrameInterpolatorBase::motionTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( motionTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::FrameInterpolatorBase::uiTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( uiTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::FrameInterpolatorBase::colorTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( colorTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::FrameInterpolatorBase::depthTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( depthTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::FrameInterpolatorBase::motionTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( motionTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::FrameInterpolatorBase::outputTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( outputTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::FrameInterpolatorBase::inputWidth() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::FrameInterpolatorBase::inputHeight() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::FrameInterpolatorBase::outputWidth() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::FrameInterpolatorBase::outputHeight() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::FrameInterpolatorBase::uiTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( uiTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::FrameInterpolatorBase::colorTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( colorTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setColorTexture(MTL::Texture* colorTexture) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setColorTexture_ ), colorTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::FrameInterpolatorBase::prevColorTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( prevColorTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setPrevColorTexture(MTL::Texture* prevColorTexture) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setPrevColorTexture_ ), prevColorTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::FrameInterpolatorBase::depthTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( depthTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setDepthTexture(MTL::Texture* depthTexture) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDepthTexture_ ), depthTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::FrameInterpolatorBase::motionTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( motionTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setMotionTexture(MTL::Texture* motionTexture) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionTexture_ ), motionTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::FrameInterpolatorBase::motionVectorScaleX() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( motionVectorScaleX ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setMotionVectorScaleX(float scaleX) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionVectorScaleX_ ), scaleX ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::FrameInterpolatorBase::motionVectorScaleY() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( motionVectorScaleY ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setMotionVectorScaleY(float scaleY) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionVectorScaleY_ ), scaleY ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::FrameInterpolatorBase::deltaTime() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( deltaTime ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setDeltaTime( float deltaTime ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDeltaTime_ ), deltaTime ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::FrameInterpolatorBase::nearPlane() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( nearPlane ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setNearPlane( float nearPlane ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setNearPlane_ ), nearPlane ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::FrameInterpolatorBase::farPlane() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( farPlane ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setFarPlane( float farPlane ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setFarPlane_ ), farPlane ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::FrameInterpolatorBase::fieldOfView() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( fieldOfView ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setFieldOfView( float fieldOfView ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setFieldOfView_ ), fieldOfView ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::FrameInterpolatorBase::aspectRatio() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( aspectRatio ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setAspectRatio( float aspectRatio ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setAspectRatio_ ), aspectRatio ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::FrameInterpolatorBase::uiTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( uiTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setUITexture(MTL::Texture* uiTexture) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setUITexture_ ), uiTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::FrameInterpolatorBase::jitterOffsetX() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( jitterOffsetX ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setJitterOffsetX( float jitterOffsetX ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setJitterOffsetX_ ), jitterOffsetX ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::FrameInterpolatorBase::jitterOffsetY() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( jitterOffsetY ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setJitterOffsetY( float jitterOffsetY ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setJitterOffsetY_ ), jitterOffsetY ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::FrameInterpolatorBase::isUITextureComposited() const +{ + return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isUITextureComposited ) ); +} + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setIsUITextureComposited( bool uiTextureComposited ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setIsUITextureComposited_ ), uiTextureComposited ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::FrameInterpolatorBase::shouldResetHistory() const +{ + return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( shouldResetHistory ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setShouldResetHistory(bool shouldResetHistory) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setShouldResetHistory_ ), shouldResetHistory ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::FrameInterpolatorBase::outputTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( outputTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setOutputTexture(MTL::Texture* outputTexture) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputTexture_ ), outputTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Fence* MTLFX::FrameInterpolatorBase::fence() const +{ + return NS::Object::sendMessage< MTL::Fence* >( this, _MTLFX_PRIVATE_SEL( fence ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setFence(MTL::Fence* fence) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setFence_ ), fence ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::FrameInterpolatorBase::isDepthReversed() const +{ + return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isDepthReversed ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setDepthReversed(bool depthReversed) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDepthReversed_ ), depthReversed ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolator::encodeToCommandBuffer(MTL::CommandBuffer* commandBuffer) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( encodeToCommandBuffer_ ), commandBuffer ); +} diff --git a/Source/Cxxmlx/metal-cpp/MetalFX/MTLFXPrivate.hpp b/Source/Cxxmlx/metal-cpp/MetalFX/MTLFXPrivate.hpp new file mode 100644 index 00000000..21fd728e --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/MetalFX/MTLFXPrivate.hpp @@ -0,0 +1,482 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MTLFXPrivate.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLFXDefines.hpp" + +#include + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#define _MTLFX_PRIVATE_CLS( symbol ) ( MTLFX::Private::Class::s_k##symbol ) +#define _MTLFX_PRIVATE_SEL( accessor ) ( MTLFX::Private::Selector::s_k##accessor ) + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#if defined( MTLFX_PRIVATE_IMPLEMENTATION ) + +#if defined( METALCPP_SYMBOL_VISIBILITY_HIDDEN ) +#define _MTLFX_PRIVATE_VISIBILITY __attribute__( ( visibility("hidden" ) ) ) +#else +#define _MTLFX_PRIVATE_VISIBILITY __attribute__( ( visibility("default" ) ) ) +#endif // METALCPP_SYMBOL_VISIBILITY_HIDDEN + +#define _MTLFX_PRIVATE_IMPORT __attribute__( ( weak_import ) ) + +#ifdef __OBJC__ +#define _MTLFX_PRIVATE_OBJC_LOOKUP_CLASS( symbol ) ( ( __bridge void* ) objc_lookUpClass( #symbol ) ) +#define _MTLFX_PRIVATE_OBJC_GET_PROTOCOL( symbol ) ( ( __bridge void* ) objc_getProtocol( #symbol ) ) +#else +#define _MTLFX_PRIVATE_OBJC_LOOKUP_CLASS( symbol ) objc_lookUpClass(#symbol) +#define _MTLFX_PRIVATE_OBJC_GET_PROTOCOL( symbol ) objc_getProtocol(#symbol) +#endif // __OBJC__ + +#define _MTLFX_PRIVATE_DEF_CLS( symbol ) void* s_k##symbol _MTLFX_PRIVATE_VISIBILITY = _MTLFX_PRIVATE_OBJC_LOOKUP_CLASS( symbol ) +#define _MTLFX_PRIVATE_DEF_PRO( symbol ) void* s_k##symbol _MTLFX_PRIVATE_VISIBILITY = _MTLFX_PRIVATE_OBJC_GET_PROTOCOL( symbol ) +#define _MTLFX_PRIVATE_DEF_SEL( accessor, symbol ) SEL s_k##accessor _MTLFX_PRIVATE_VISIBILITY = sel_registerName( symbol ) + +#include +#define MTLFX_DEF_FUNC( name, signature ) using Fn##name = signature; \ + Fn##name name = reinterpret_cast< Fn##name >( dlsym( RTLD_DEFAULT, #name ) ) + +namespace MTLFX::Private +{ + template + + inline _Type const LoadSymbol(const char* pSymbol) + { + const _Type* pAddress = static_cast<_Type*>(dlsym(RTLD_DEFAULT, pSymbol)); + + return pAddress ? *pAddress : nullptr; + } +} // MTLFX::Private + +#if defined(__MAC_26_0) || defined(__IPHONE_26_0) || defined(__TVOS_26_0) + +#define _MTLFX_PRIVATE_DEF_STR( type, symbol ) \ + _MTLFX_EXTERN type const MTLFX##symbol _MTLFX_PRIVATE_IMPORT; \ + type const MTLFX::symbol = ( nullptr != &MTLFX##symbol ) ? MTLFX##ssymbol : nullptr + +#define _MTLFX_PRIVATE_DEF_CONST( type, symbol ) \ + _MTLFX_EXTERN type const MTLFX##ssymbol _MTLFX_PRIVATE_IMPORT; \ + type const MTLFX::symbol = (nullptr != &MTLFX##ssymbol) ? MTLFX##ssymbol : nullptr + +#define _MTLFX_PRIVATE_DEF_WEAK_CONST( type, symbol ) \ + _MTLFX_EXTERN type const MTLFX##ssymbol; \ + type const MTLFX::symbol = Private::LoadSymbol< type >( "MTLFX" #symbol ) + +#else + +#define _MTLFX_PRIVATE_DEF_STR( type, symbol ) \ + _MTLFX_EXTERN type const MTLFX##ssymbol; \ + type const MTLFX::symbol = Private::LoadSymbol< type >( "MTLFX" #symbol ) + +#define _MTLFX_PRIVATE_DEF_CONST( type, symbol ) \ + _MTLFX_EXTERN type const MTLFX##ssymbol; \ + type const MTLFX::symbol = Private::LoadSymbol< type >( "MTLFX" #symbol ) + +#define _MTLFX_PRIVATE_DEF_WEAK_CONST( type, symbol ) _MTLFX_PRIVATE_DEF_CONST( type, symbol ) + +#endif + +#else + +#define _MTLFX_PRIVATE_DEF_CLS( symbol ) extern void* s_k##symbol +#define _MTLFX_PRIVATE_DEF_PRO( symbol ) extern void* s_k##symbol +#define _MTLFX_PRIVATE_DEF_SEL( accessor, symbol ) extern SEL s_k##accessor +#define _MTLFX_PRIVATE_DEF_STR( type, symbol ) extern type const MTLFX::symbol +#define _MTLFX_PRIVATE_DEF_CONST( type, symbol ) extern type const MTLFX::symbol +#define _MTLFX_PRIVATE_DEF_WEAK_CONST( type, symbol ) extern type const MTLFX::symbol + +#endif // MTLFX_PRIVATE_IMPLEMENTATION + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTLFX +{ + namespace Private + { + namespace Class + { + _MTLFX_PRIVATE_DEF_CLS( MTLFXSpatialScalerDescriptor ); + _MTLFX_PRIVATE_DEF_CLS( MTLFXTemporalScalerDescriptor ); + _MTLFX_PRIVATE_DEF_CLS( MTLFXFrameInterpolatorDescriptor ); + _MTLFX_PRIVATE_DEF_CLS( MTLFXTemporalDenoisedScalerDescriptor ); + + _MTLFX_PRIVATE_DEF_CLS( MTL4FXSpatialScalerDescriptor ); + _MTLFX_PRIVATE_DEF_CLS( MTL4FXTemporalScalerDescriptor ); + _MTLFX_PRIVATE_DEF_CLS( MTL4FXFrameInterpolatorDescriptor ); + _MTLFX_PRIVATE_DEF_CLS( MTL4FXTemporalDenoisedScalerDescriptor ); + } // Class + } // Private +} // MTLFX + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTLFX +{ + namespace Private + { + namespace Protocol + { + _MTLFX_PRIVATE_DEF_PRO( MTLFXSpatialScaler ); + _MTLFX_PRIVATE_DEF_PRO( MTLFXTemporalScaler ); + } // Protocol + } // Private +} // MTLFX + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTLFX +{ + namespace Private + { + namespace Selector + { + _MTLFX_PRIVATE_DEF_SEL( aspectRatio, + "aspectRatio" ); + _MTLFX_PRIVATE_DEF_SEL( colorProcessingMode, + "colorProcessingMode" ); + _MTLFX_PRIVATE_DEF_SEL( colorTexture, + "colorTexture" ); + _MTLFX_PRIVATE_DEF_SEL( colorTextureFormat, + "colorTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( colorTextureUsage, + "colorTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( deltaTime, + "deltaTime" ); + _MTLFX_PRIVATE_DEF_SEL( denoiseStrengthMaskTexture, + "denoiseStrengthMaskTexture" ); + _MTLFX_PRIVATE_DEF_SEL( denoiseStrengthMaskTextureFormat, + "denoiseStrengthMaskTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( denoiseStrengthMaskTextureUsage, + "denoiseStrengthMaskTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( depthTexture, + "depthTexture" ); + _MTLFX_PRIVATE_DEF_SEL( depthTextureFormat, + "depthTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( depthTextureUsage, + "depthTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( diffuseAlbedoTexture, + "diffuseAlbedoTexture" ); + _MTLFX_PRIVATE_DEF_SEL( diffuseAlbedoTextureFormat, + "diffuseAlbedoTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( diffuseAlbedoTextureUsage, + "diffuseAlbedoTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( encodeToCommandBuffer_, + "encodeToCommandBuffer:" ); + _MTLFX_PRIVATE_DEF_SEL( exposureTexture, + "exposureTexture" ); + _MTLFX_PRIVATE_DEF_SEL( farPlane, + "farPlane" ); + _MTLFX_PRIVATE_DEF_SEL( fence, + "fence" ); + _MTLFX_PRIVATE_DEF_SEL( fieldOfView, + "fieldOfView" ); + _MTLFX_PRIVATE_DEF_SEL( height, + "height" ); + _MTLFX_PRIVATE_DEF_SEL( inputContentHeight, + "inputContentHeight" ); + _MTLFX_PRIVATE_DEF_SEL( inputContentMaxScale, + "inputContentMaxScale" ); + _MTLFX_PRIVATE_DEF_SEL( inputContentMinScale, + "inputContentMinScale" ); + _MTLFX_PRIVATE_DEF_SEL( inputContentWidth, + "inputContentWidth" ); + _MTLFX_PRIVATE_DEF_SEL( inputHeight, + "inputHeight" ); + _MTLFX_PRIVATE_DEF_SEL( inputWidth, + "inputWidth" ); + _MTLFX_PRIVATE_DEF_SEL( isAutoExposureEnabled, + "isAutoExposureEnabled" ); + _MTLFX_PRIVATE_DEF_SEL( isDenoiseStrengthMaskTextureEnabled, + "isDenoiseStrengthMaskTextureEnabled" ); + _MTLFX_PRIVATE_DEF_SEL( isDepthReversed, + "isDepthReversed" ); + _MTLFX_PRIVATE_DEF_SEL( isInputContentPropertiesEnabled, + "isInputContentPropertiesEnabled" ); + _MTLFX_PRIVATE_DEF_SEL( isTransparencyOverlayTextureEnabled, + "isTransparencyOverlayTextureEnabled" ); + _MTLFX_PRIVATE_DEF_SEL( isReactiveMaskTextureEnabled, + "isReactiveMaskTextureEnabled" ); + _MTLFX_PRIVATE_DEF_SEL( isSpecularHitDistanceTextureEnabled, + "isSpecularHitDistanceTextureEnabled" ); + _MTLFX_PRIVATE_DEF_SEL( isUITextureComposited, + "isUITextureComposited" ); + _MTLFX_PRIVATE_DEF_SEL( jitterOffsetX, + "jitterOffsetX" ); + _MTLFX_PRIVATE_DEF_SEL( jitterOffsetY, + "jitterOffsetY" ); + _MTLFX_PRIVATE_DEF_SEL( maskTexture, + "maskTexture" ); + _MTLFX_PRIVATE_DEF_SEL( maskTextureFormat, + "maskTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( maskTextureUsage, + "maskTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( motionTexture, + "motionTexture" ); + _MTLFX_PRIVATE_DEF_SEL( motionTextureFormat, + "motionTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( motionTextureUsage, + "motionTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( motionVectorScaleX, + "motionVectorScaleX" ); + _MTLFX_PRIVATE_DEF_SEL( motionVectorScaleY, + "motionVectorScaleY" ); + _MTLFX_PRIVATE_DEF_SEL( nearPlane, + "nearPlane" ); + _MTLFX_PRIVATE_DEF_SEL( newFrameInterpolatorWithDevice_, + "newFrameInterpolatorWithDevice:" ); + _MTLFX_PRIVATE_DEF_SEL( newFrameInterpolatorWithDevice_compiler_, + "newFrameInterpolatorWithDevice:compiler:" ); + _MTLFX_PRIVATE_DEF_SEL( newTemporalDenoisedScalerWithDevice_, + "newTemporalDenoisedScalerWithDevice:" ); + _MTLFX_PRIVATE_DEF_SEL( newTemporalDenoisedScalerWithDevice_compiler_, + "newTemporalDenoisedScalerWithDevice:compiler:" ); + _MTLFX_PRIVATE_DEF_SEL( newSpatialScalerWithDevice_, + "newSpatialScalerWithDevice:" ); + _MTLFX_PRIVATE_DEF_SEL( newSpatialScalerWithDevice_compiler_, + "newSpatialScalerWithDevice:compiler:" ); + _MTLFX_PRIVATE_DEF_SEL( newTemporalScalerWithDevice_, + "newTemporalScalerWithDevice:" ); + _MTLFX_PRIVATE_DEF_SEL( newTemporalScalerWithDevice_compiler_, + "newTemporalScalerWithDevice:compiler:" ); + _MTLFX_PRIVATE_DEF_SEL( normalTexture, + "normalTexture" ); + _MTLFX_PRIVATE_DEF_SEL( normalTextureFormat, + "normalTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( normalTextureUsage, + "normalTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( outputHeight, + "outputHeight" ); + _MTLFX_PRIVATE_DEF_SEL( outputTexture, + "outputTexture" ); + _MTLFX_PRIVATE_DEF_SEL( outputTextureFormat, + "outputTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( outputTextureUsage, + "outputTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( outputWidth, + "outputWidth" ); + _MTLFX_PRIVATE_DEF_SEL( preExposure, + "preExposure" ); + _MTLFX_PRIVATE_DEF_SEL( transparencyOverlayTextureFormat, + "transparencyOverlayTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( transparencyOverlayTextureUsage, + "transparencyOverlayTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( prevColorTexture, + "prevColorTexture" ); + _MTLFX_PRIVATE_DEF_SEL( reactiveMaskTextureFormat, + "reactiveMaskTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( reactiveTextureUsage, + "reactiveTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( reactiveMaskTexture, + "reactiveMaskTexture" ); + _MTLFX_PRIVATE_DEF_SEL( reset, + "reset" ); + _MTLFX_PRIVATE_DEF_SEL( requiresSynchronousInitialization, + "requiresSynchronousInitialization" ); + _MTLFX_PRIVATE_DEF_SEL( roughnessTextureFormat, + "roughnessTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( roughnessTextureUsage, + "roughnessTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( scaler, + "scaler" ); + _MTLFX_PRIVATE_DEF_SEL( scaler4, + "scaler4" ); + _MTLFX_PRIVATE_DEF_SEL( setAspectRatio_, + "setAspectRatio:" ); + _MTLFX_PRIVATE_DEF_SEL( setAutoExposureEnabled_, + "setAutoExposureEnabled:" ); + _MTLFX_PRIVATE_DEF_SEL( setColorProcessingMode_, + "setColorProcessingMode:" ); + _MTLFX_PRIVATE_DEF_SEL( setColorTexture_, + "setColorTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setColorTextureFormat_, + "setColorTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setDeltaTime_, + "setDeltaTime:" ); + _MTLFX_PRIVATE_DEF_SEL( setDenoiseStrengthMaskTexture_, + "setDenoiseStrengthMaskTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setDenoiseStrengthMaskTextureEnabled_, + "setDenoiseStrengthMaskTextureEnabled:" ); + _MTLFX_PRIVATE_DEF_SEL( setDenoiseStrengthMaskTextureFormat_, + "setDenoiseStrengthMaskTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setDepthInverted_, + "setDepthInverted:" ); + _MTLFX_PRIVATE_DEF_SEL( setDepthReversed_, + "setDepthReversed:" ); + _MTLFX_PRIVATE_DEF_SEL( setDepthTexture_, + "setDepthTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setDepthTextureFormat_, + "setDepthTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setDiffuseAlbedoTexture_, + "setDiffuseAlbedoTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setDiffuseAlbedoTextureFormat_, + "setDiffuseAlbedoTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setExposureTexture_, + "setExposureTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setFarPlane_, + "setFarPlane:" ); + _MTLFX_PRIVATE_DEF_SEL( setFence_, + "setFence:" ); + _MTLFX_PRIVATE_DEF_SEL( setFieldOfView_, + "setFieldOfView:" ); + _MTLFX_PRIVATE_DEF_SEL( setHeight_, + "setHeight:" ); + _MTLFX_PRIVATE_DEF_SEL( setInputContentHeight_, + "setInputContentHeight:" ); + _MTLFX_PRIVATE_DEF_SEL( setInputContentMaxScale_, + "setInputContentMaxScale:" ); + _MTLFX_PRIVATE_DEF_SEL( setInputContentMinScale_, + "setInputContentMinScale:" ); + _MTLFX_PRIVATE_DEF_SEL( setInputContentPropertiesEnabled_, + "setInputContentPropertiesEnabled:" ); + _MTLFX_PRIVATE_DEF_SEL( setInputContentWidth_, + "setInputContentWidth:" ); + _MTLFX_PRIVATE_DEF_SEL( setInputHeight_, + "setInputHeight:" ); + _MTLFX_PRIVATE_DEF_SEL( setInputWidth_, + "setInputWidth:" ); + _MTLFX_PRIVATE_DEF_SEL( setIsUITextureComposited_, + "setIsUITextureComposited:" ); + _MTLFX_PRIVATE_DEF_SEL( setJitterOffsetX_, + "setJitterOffsetX:" ); + _MTLFX_PRIVATE_DEF_SEL( setJitterOffsetY_, + "setJitterOffsetY:" ); + _MTLFX_PRIVATE_DEF_SEL( setNearPlane_, + "setNearPlane:" ); + _MTLFX_PRIVATE_DEF_SEL( setMaskTexture_, + "setMaskTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setMaskTextureFormat_, + "setMaskTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setMotionTexture_, + "setMotionTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setMotionTextureFormat_, + "setMotionTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setMotionVectorScaleX_, + "setMotionVectorScaleX:" ); + _MTLFX_PRIVATE_DEF_SEL( setMotionVectorScaleY_, + "setMotionVectorScaleY:" ); + _MTLFX_PRIVATE_DEF_SEL( setNormalTexture_, + "setNormalTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setNormalTextureFormat_, + "setNormalTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setOutputHeight_, + "setOutputHeight:" ); + _MTLFX_PRIVATE_DEF_SEL( setOutputTexture_, + "setOutputTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setOutputTextureFormat_, + "setOutputTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setOutputWidth_, + "setOutputWidth:" ); + _MTLFX_PRIVATE_DEF_SEL( transparencyOverlayTexture, + "transparencyOverlayTexture" ); + _MTLFX_PRIVATE_DEF_SEL( setTransparencyOverlayTexture_, + "setTransparencyOverlayTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setTransparencyOverlayTextureEnabled_, + "setTransparencyOverlayTextureEnabled:" ); + _MTLFX_PRIVATE_DEF_SEL( setPreExposure_, + "setPreExposure:" ); + _MTLFX_PRIVATE_DEF_SEL( setTransparencyOverlayTextureFormat_, + "setTransparencyOverlayTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setPrevColorTexture_, + "setPrevColorTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setReactiveMaskTexture_, + "setReactiveMaskTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setReactiveMaskTextureEnabled_, + "setReactiveMaskTextureEnabled:" ); + _MTLFX_PRIVATE_DEF_SEL( setReactiveMaskTextureFormat_, + "setReactiveMaskTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setRequiresSynchronousInitialization_, + "setRequiresSynchronousInitialization:" ); + _MTLFX_PRIVATE_DEF_SEL( setReset_, + "setReset:" ); + _MTLFX_PRIVATE_DEF_SEL( roughnessTexture, + "roughnessTexture" ); + _MTLFX_PRIVATE_DEF_SEL( setRoughnessTexture_, + "setRoughnessTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setRoughnessTextureFormat_, + "setRoughnessTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setScaler_, + "setScaler:" ); + _MTLFX_PRIVATE_DEF_SEL( setShouldResetHistory_, + "setShouldResetHistory:" ); + _MTLFX_PRIVATE_DEF_SEL( specularHitDistanceTexture, + "specularHitDistanceTexture" ); + _MTLFX_PRIVATE_DEF_SEL( setSpecularHitDistanceTexture_, + "setSpecularHitDistanceTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setSpecularHitDistanceTextureEnabled_, + "setSpecularHitDistanceTextureEnabled:" ); + _MTLFX_PRIVATE_DEF_SEL( setSpecularAlbedoTexture_, + "setSpecularAlbedoTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setSpecularAlbedoTextureFormat_, + "setSpecularAlbedoTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setSpecularHitDistanceTextureFormat_, + "setSpecularHitDistanceTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setUITexture_, + "setUITexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setUITextureFormat_, + "setUITextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setViewToClipMatrix_, + "setViewToClipMatrix:" ); + _MTLFX_PRIVATE_DEF_SEL( setWidth_, + "setWidth:" ); + _MTLFX_PRIVATE_DEF_SEL( setWorldToViewMatrix_, + "setWorldToViewMatrix:" ); + _MTLFX_PRIVATE_DEF_SEL( shouldResetHistory, + "shouldResetHistory" ); + _MTLFX_PRIVATE_DEF_SEL( specularAlbedoTexture, + "specularAlbedoTexture" ); + _MTLFX_PRIVATE_DEF_SEL( specularAlbedoTextureFormat, + "specularAlbedoTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( specularAlbedoTextureUsage, + "specularAlbedoTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( specularHitDistanceTextureFormat, + "specularHitDistanceTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( specularHitDistanceTextureUsage, + "specularHitDistanceTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( supportedInputContentMaxScaleForDevice_, + "supportedInputContentMaxScaleForDevice:" ); + _MTLFX_PRIVATE_DEF_SEL( supportedInputContentMinScaleForDevice_, + "supportedInputContentMinScaleForDevice:" ); + _MTLFX_PRIVATE_DEF_SEL( supportsDevice_, + "supportsDevice:" ); + _MTLFX_PRIVATE_DEF_SEL( supportsMetal4FX_, + "supportsMetal4FX:" ); + _MTLFX_PRIVATE_DEF_SEL( uiTexture, + "uiTexture" ); + _MTLFX_PRIVATE_DEF_SEL( uiTextureFormat, + "uiTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( uiTextureUsage, + "uiTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( viewToClipMatrix, + "viewToClipMatrix" ); + _MTLFX_PRIVATE_DEF_SEL( width, + "width" ); + _MTLFX_PRIVATE_DEF_SEL( worldToViewMatrix, + "worldToViewMatrix" ); + } // Selector + } // Private +} // MTLFX + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/MetalFX/MTLFXSpatialScaler.hpp b/Source/Cxxmlx/metal-cpp/MetalFX/MTLFXSpatialScaler.hpp new file mode 100644 index 00000000..cb1186ed --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/MetalFX/MTLFXSpatialScaler.hpp @@ -0,0 +1,397 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MTLFXSpatialScaler.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLFXDefines.hpp" +#include "MTLFXPrivate.hpp" + +#include "../Metal/Metal.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL4FX +{ + class SpatialScaler; +} + +namespace MTLFX +{ + _MTLFX_ENUM( NS::Integer, SpatialScalerColorProcessingMode ) + { + SpatialScalerColorProcessingModePerceptual = 0, + SpatialScalerColorProcessingModeLinear = 1, + SpatialScalerColorProcessingModeHDR = 2 + }; + + class SpatialScalerDescriptor : public NS::Copying< SpatialScalerDescriptor > + { + public: + static class SpatialScalerDescriptor* alloc(); + class SpatialScalerDescriptor* init(); + + MTL::PixelFormat colorTextureFormat() const; + void setColorTextureFormat( MTL::PixelFormat format ); + + MTL::PixelFormat outputTextureFormat() const; + void setOutputTextureFormat( MTL::PixelFormat format ); + + NS::UInteger inputWidth() const; + void setInputWidth( NS::UInteger width ); + + NS::UInteger inputHeight() const; + void setInputHeight( NS::UInteger height ); + + NS::UInteger outputWidth() const; + void setOutputWidth( NS::UInteger width ); + + NS::UInteger outputHeight() const; + void setOutputHeight( NS::UInteger height ); + + SpatialScalerColorProcessingMode colorProcessingMode() const; + void setColorProcessingMode( SpatialScalerColorProcessingMode mode ); + + class SpatialScaler* newSpatialScaler( const MTL::Device* pDevice ) const; + MTL4FX::SpatialScaler* newSpatialScaler( const MTL::Device* pDevice, const MTL4::Compiler* pCompiler ) const; + + static bool supportsDevice( const MTL::Device* pDevice); + static bool supportsMetal4FX( const MTL::Device* pDevice ); + }; + + class SpatialScalerBase : public NS::Referencing< SpatialScaler > + { + public: + MTL::TextureUsage colorTextureUsage() const; + MTL::TextureUsage outputTextureUsage() const; + + NS::UInteger inputContentWidth() const; + void setInputContentWidth( NS::UInteger width ); + + NS::UInteger inputContentHeight() const; + void setInputContentHeight( NS::UInteger height ); + + MTL::Texture* colorTexture() const; + void setColorTexture( MTL::Texture* pTexture ); + + MTL::Texture* outputTexture() const; + void setOutputTexture( MTL::Texture* pTexture ); + + MTL::PixelFormat colorTextureFormat() const; + MTL::PixelFormat outputTextureFormat() const; + NS::UInteger inputWidth() const; + NS::UInteger inputHeight() const; + NS::UInteger outputWidth() const; + NS::UInteger outputHeight() const; + SpatialScalerColorProcessingMode colorProcessingMode() const; + + MTL::Fence* fence() const; + void setFence( MTL::Fence* pFence ); + }; + + class SpatialScaler : public NS::Referencing< SpatialScaler, SpatialScalerBase > + { + public: + void encodeToCommandBuffer( MTL::CommandBuffer* pCommandBuffer ); + }; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::SpatialScalerDescriptor* MTLFX::SpatialScalerDescriptor::alloc() +{ + return NS::Object::alloc< SpatialScalerDescriptor >( _MTLFX_PRIVATE_CLS( MTLFXSpatialScalerDescriptor ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::SpatialScalerDescriptor* MTLFX::SpatialScalerDescriptor::init() +{ + return NS::Object::init< SpatialScalerDescriptor >(); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::SpatialScalerDescriptor::colorTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( colorTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerDescriptor::setColorTextureFormat( MTL::PixelFormat format ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setColorTextureFormat_ ), format ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::SpatialScalerDescriptor::outputTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( outputTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerDescriptor::setOutputTextureFormat( MTL::PixelFormat format ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputTextureFormat_ ), format ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::SpatialScalerDescriptor::inputWidth() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerDescriptor::setInputWidth( NS::UInteger width ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputWidth_ ), width ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::SpatialScalerDescriptor::inputHeight() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerDescriptor::setInputHeight( NS::UInteger height ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputHeight_ ), height ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::SpatialScalerDescriptor::outputWidth() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerDescriptor::setOutputWidth( NS::UInteger width ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputWidth_ ), width ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::SpatialScalerDescriptor::outputHeight() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerDescriptor::setOutputHeight( NS::UInteger height ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputHeight_ ), height ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::SpatialScalerColorProcessingMode MTLFX::SpatialScalerDescriptor::colorProcessingMode() const +{ + return Object::sendMessage< SpatialScalerColorProcessingMode >( this, _MTLFX_PRIVATE_SEL( colorProcessingMode ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerDescriptor::setColorProcessingMode( SpatialScalerColorProcessingMode mode ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setColorProcessingMode_ ), mode ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::SpatialScaler* MTLFX::SpatialScalerDescriptor::newSpatialScaler( const MTL::Device* pDevice ) const +{ + return Object::sendMessage< SpatialScaler* >( this, _MTLFX_PRIVATE_SEL( newSpatialScalerWithDevice_ ), pDevice ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL4FX::SpatialScaler* MTLFX::SpatialScalerDescriptor::newSpatialScaler( const MTL::Device* pDevice, const MTL4::Compiler* pCompiler ) const +{ + return Object::sendMessage< MTL4FX::SpatialScaler* >( this, _MTLFX_PRIVATE_SEL( newSpatialScalerWithDevice_compiler_ ), pDevice, pCompiler ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::SpatialScalerDescriptor::supportsDevice( const MTL::Device* pDevice ) +{ + return Object::sendMessageSafe< bool >( _NS_PRIVATE_CLS( MTLFXSpatialScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportsDevice_ ), pDevice ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::SpatialScalerDescriptor::supportsMetal4FX( const MTL::Device* pDevice ) +{ + return Object::sendMessageSafe< bool >( _NS_PRIVATE_CLS( MTLFXSpatialScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportsMetal4FX_ ), pDevice ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::SpatialScalerBase::colorTextureUsage() const +{ + return Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( colorTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::SpatialScalerBase::outputTextureUsage() const +{ + return Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( outputTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::SpatialScalerBase::inputContentWidth() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputContentWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerBase::setInputContentWidth( NS::UInteger width ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputContentWidth_ ), width ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::SpatialScalerBase::inputContentHeight() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputContentHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerBase::setInputContentHeight( NS::UInteger height ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputContentHeight_ ), height ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::SpatialScalerBase::colorTexture() const +{ + return Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( colorTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerBase::setColorTexture( MTL::Texture* pTexture ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setColorTexture_ ), pTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::SpatialScalerBase::outputTexture() const +{ + return Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( outputTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerBase::setOutputTexture( MTL::Texture* pTexture ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputTexture_ ), pTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::SpatialScalerBase::colorTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( colorTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::SpatialScalerBase::outputTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( outputTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::SpatialScalerBase::inputWidth() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::SpatialScalerBase::inputHeight() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::SpatialScalerBase::outputWidth() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::SpatialScalerBase::outputHeight() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::SpatialScalerColorProcessingMode MTLFX::SpatialScalerBase::colorProcessingMode() const +{ + return Object::sendMessage< SpatialScalerColorProcessingMode >( this, _MTLFX_PRIVATE_SEL( colorProcessingMode ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Fence* MTLFX::SpatialScalerBase::fence() const +{ + return Object::sendMessage< MTL::Fence* >( this, _MTLFX_PRIVATE_SEL( fence ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerBase::setFence( MTL::Fence* pFence ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setFence_ ), pFence ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScaler::encodeToCommandBuffer( MTL::CommandBuffer* pCommandBuffer ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( encodeToCommandBuffer_ ), pCommandBuffer ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/MetalFX/MTLFXTemporalDenoisedScaler.hpp b/Source/Cxxmlx/metal-cpp/MetalFX/MTLFXTemporalDenoisedScaler.hpp new file mode 100644 index 00000000..5863e078 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/MetalFX/MTLFXTemporalDenoisedScaler.hpp @@ -0,0 +1,1208 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MTLFXTemporalDenoisedScaler.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLFXDefines.hpp" +#include "MTLFXPrivate.hpp" +#include "MTLFXTemporalScaler.hpp" + +#include "../Metal/Metal.hpp" + +#include + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL4FX +{ + class TemporalDenoisedScaler; +} + +namespace MTLFX +{ + class TemporalDenoisedScalerDescriptor : public NS::Copying< TemporalDenoisedScalerDescriptor > + { + public: + static class TemporalDenoisedScalerDescriptor* alloc(); + class TemporalDenoisedScalerDescriptor* init(); + + MTL::PixelFormat colorTextureFormat() const; + void setColorTextureFormat( MTL::PixelFormat pixelFormat ); + + MTL::PixelFormat depthTextureFormat() const; + void setDepthTextureFormat( MTL::PixelFormat pixelFormat ); + + MTL::PixelFormat motionTextureFormat() const; + void setMotionTextureFormat( MTL::PixelFormat pixelFormal ); + + MTL::PixelFormat diffuseAlbedoTextureFormat() const; + void setDiffuseAlbedoTextureFormat( MTL::PixelFormat pixelFormat ); + + MTL::PixelFormat specularAlbedoTextureFormat() const; + void setSpecularAlbedoTextureFormat( MTL::PixelFormat pixelFormat ); + + MTL::PixelFormat normalTextureFormat() const; + void setNormalTextureFormat( MTL::PixelFormat pixelFormat ); + + MTL::PixelFormat roughnessTextureFormat() const; + void setRoughnessTextureFormat( MTL::PixelFormat pixelFormat ); + + MTL::PixelFormat specularHitDistanceTextureFormat() const; + void setSpecularHitDistanceTextureFormat( MTL::PixelFormat pixelFormat ); + + MTL::PixelFormat denoiseStrengthMaskTextureFormat() const; + void setDenoiseStrengthMaskTextureFormat( MTL::PixelFormat pixelFormat ); + + MTL::PixelFormat transparencyOverlayTextureFormat() const; + void setTransparencyOverlayTextureFormat( MTL::PixelFormat pixelFormat ); + + MTL::PixelFormat outputTextureFormat() const; + void setOutputTextureFormat( MTL::PixelFormat pixelFormat ); + + NS::UInteger inputWidth() const; + void setInputWidth( NS::UInteger inputWidth ); + + NS::UInteger inputHeight() const; + void setInputHeight( NS::UInteger inputHeight ); + + NS::UInteger outputWidth() const; + void setOutputWidth( NS::UInteger outputWidth ); + + NS::UInteger outputHeight() const; + void setOutputHeight( NS::UInteger outputHeight ); + + bool requiresSynchronousInitialization() const; + void setRequiresSynchronousInitialization( bool requiresSynchronousInitialization ); + + bool isAutoExposureEnabled() const; + void setAutoExposureEnabled( bool enabled ); + + bool isInputContentPropertiesEnabled() const; + void setInputContentPropertiesEnabled( bool enabled ); + + float inputContentMinScale() const; + void setInputContentMinScale( float inputContentMinScale ); + + float inputContentMaxScale() const; + void setInputContentMaxScale( float inputContentMaxScale ); + + bool isReactiveMaskTextureEnabled() const; + void setReactiveMaskTextureEnabled( bool enabled ); + + MTL::PixelFormat reactiveMaskTextureFormat() const; + void setReactiveMaskTextureFormat( MTL::PixelFormat pixelFormat ); + + bool isSpecularHitDistanceTextureEnabled() const; + void setSpecularHitDistanceTextureEnabled( bool enabled ); + + bool isDenoiseStrengthMaskTextureEnabled() const; + void setDenoiseStrengthMaskTextureEnabled( bool enabled ); + + bool isTransparencyOverlayTextureEnabled() const; + void setTransparencyOverlayTextureEnabled( bool enabled ); + + class TemporalDenoisedScaler* newTemporalDenoisedScaler( const MTL::Device* device ) const; + MTL4FX::TemporalDenoisedScaler* newTemporalDenoisedScaler( const MTL::Device* device, const MTL4::Compiler* compiler) const; + + static float supportedInputContentMinScale(MTL::Device* device); + static float supportedInputContentMaxScale(MTL::Device* device); + + static bool supportsMetal4FX( MTL::Device* device); + static bool supportsDevice( MTL::Device* device); + }; + + class TemporalDenoisedScalerBase : public NS::Referencing< TemporalDenoisedScalerBase, FrameInterpolatableScaler > + { + public: + MTL::TextureUsage colorTextureUsage() const; + MTL::TextureUsage depthTextureUsage() const; + MTL::TextureUsage motionTextureUsage() const; + MTL::TextureUsage reactiveTextureUsage() const; + MTL::TextureUsage diffuseAlbedoTextureUsage() const; + MTL::TextureUsage specularAlbedoTextureUsage() const; + MTL::TextureUsage normalTextureUsage() const; + MTL::TextureUsage roughnessTextureUsage() const; + MTL::TextureUsage specularHitDistanceTextureUsage() const; + MTL::TextureUsage denoiseStrengthMaskTextureUsage() const; + MTL::TextureUsage transparencyOverlayTextureUsage() const; + MTL::TextureUsage outputTextureUsage() const; + + MTL::Texture* colorTexture() const; + void setColorTexture( MTL::Texture* colorTexture ); + + MTL::Texture* depthTexture() const; + void setDepthTexture( MTL::Texture* depthTexture ); + + MTL::Texture* motionTexture() const; + void setMotionTexture( MTL::Texture* motionTexture ); + + MTL::Texture* diffuseAlbedoTexture() const; + void setDiffuseAlbedoTexture( MTL::Texture* diffuseAlbedoTexture ); + + MTL::Texture* specularAlbedoTexture() const; + void setSpecularAlbedoTexture( MTL::Texture* specularAlbedoTexture ); + + MTL::Texture* normalTexture() const; + void setNormalTexture( MTL::Texture* normalTexture ); + + MTL::Texture* roughnessTexture() const; + void setRoughnessTexture( MTL::Texture* roughnessTexture ); + + MTL::Texture* specularHitDistanceTexture() const; + void setSpecularHitDistanceTexture( MTL::Texture* specularHitDistanceTexture ); + + MTL::Texture* denoiseStrengthMaskTexture() const; + void setDenoiseStrengthMaskTexture( MTL::Texture* denoiseStrengthMaskTexture ); + + MTL::Texture* transparencyOverlayTexture() const; + void setTransparencyOverlayTexture( MTL::Texture* transparencyOverlayTexture ); + + MTL::Texture* outputTexture() const; + void setOutputTexture( MTL::Texture* outputTexture ); + + MTL::Texture* exposureTexture() const; + void setExposureTexture( MTL::Texture* exposureTexture ); + + float preExposure() const; + void setPreExposure( float preExposure ); + + MTL::Texture* reactiveMaskTexture() const; + void setReactiveMaskTexture( MTL::Texture* reactiveMaskTexture ); + + float jitterOffsetX() const; + void setJitterOffsetX( float jitterOffsetX ); + + float jitterOffsetY() const; + void setJitterOffsetY( float jitterOffsetY ); + + float motionVectorScaleX() const; + void setMotionVectorScaleX( float motionVectorScaleX ); + + float motionVectorScaleY() const; + void setMotionVectorScaleY( float motionVectorScaleY ); + + bool shouldResetHistory() const; + void setShouldResetHistory( bool shouldResetHistory ); + + bool isDepthReversed() const; + void setDepthReversed( bool depthReversed ); + + MTL::PixelFormat colorTextureFormat() const; + MTL::PixelFormat depthTextureFormat() const; + MTL::PixelFormat motionTextureFormat() const; + MTL::PixelFormat diffuseAlbedoTextureFormat() const; + MTL::PixelFormat specularAlbedoTextureFormat() const; + MTL::PixelFormat normalTextureFormat() const; + MTL::PixelFormat roughnessTextureFormat() const; + MTL::PixelFormat specularHitDistanceTextureFormat() const; + MTL::PixelFormat denoiseStrengthMaskTextureFormat() const; + MTL::PixelFormat transparencyOverlayTextureFormat() const; + MTL::PixelFormat reactiveMaskTextureFormat() const; + MTL::PixelFormat outputTextureFormat() const; + + NS::UInteger inputWidth() const; + NS::UInteger inputHeight() const; + NS::UInteger outputWidth() const; + NS::UInteger outputHeight() const; + float inputContentMinScale() const; + float inputContentMaxScale() const; + + simd::float4x4 worldToViewMatrix() const; + void setWorldToViewMatrix( simd::float4x4 worldToViewMatrix ); + + simd::float4x4 viewToClipMatrix() const; + void setViewToClipMatrix( simd::float4x4 viewToClipMatrix ); + + MTL::Fence* fence() const; + void setFence( MTL::Fence* fence ); + }; + + class TemporalDenoisedScaler : public NS::Referencing< TemporalDenoisedScaler, TemporalDenoisedScalerBase > + { + public: + + void encodeToCommandBuffer(MTL::CommandBuffer* commandBuffer); + }; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::TemporalDenoisedScalerDescriptor* MTLFX::TemporalDenoisedScalerDescriptor::alloc() +{ + return NS::Object::alloc< TemporalDenoisedScalerDescriptor >( _MTLFX_PRIVATE_CLS( MTLFXTemporalDenoisedScalerDescriptor ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::TemporalDenoisedScalerDescriptor* MTLFX::TemporalDenoisedScalerDescriptor::init() +{ + return NS::Object::init< TemporalDenoisedScalerDescriptor >(); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::colorTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( colorTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setColorTextureFormat( MTL::PixelFormat pixelFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setColorTextureFormat_ ), pixelFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::depthTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( depthTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setDepthTextureFormat( MTL::PixelFormat pixelFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDepthTextureFormat_ ), pixelFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::motionTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( motionTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setMotionTextureFormat( MTL::PixelFormat pixelFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionTextureFormat_ ), pixelFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::diffuseAlbedoTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( diffuseAlbedoTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setDiffuseAlbedoTextureFormat( MTL::PixelFormat pixelFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDiffuseAlbedoTextureFormat_ ), pixelFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::specularAlbedoTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( specularAlbedoTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setSpecularAlbedoTextureFormat( MTL::PixelFormat pixelFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setSpecularAlbedoTextureFormat_ ), pixelFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::normalTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( normalTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setNormalTextureFormat( MTL::PixelFormat pixelFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setNormalTextureFormat_ ), pixelFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::roughnessTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( roughnessTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setRoughnessTextureFormat( MTL::PixelFormat pixelFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setRoughnessTextureFormat_ ), pixelFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::specularHitDistanceTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( specularHitDistanceTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setSpecularHitDistanceTextureFormat( MTL::PixelFormat pixelFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setSpecularHitDistanceTextureFormat_ ), pixelFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::denoiseStrengthMaskTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( denoiseStrengthMaskTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setDenoiseStrengthMaskTextureFormat( MTL::PixelFormat pixelFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDenoiseStrengthMaskTextureFormat_ ), pixelFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::transparencyOverlayTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( transparencyOverlayTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setTransparencyOverlayTextureFormat( MTL::PixelFormat pixelFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setTransparencyOverlayTextureFormat_ ), pixelFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::outputTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( outputTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setOutputTextureFormat( MTL::PixelFormat pixelFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputTextureFormat_ ), pixelFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalDenoisedScalerDescriptor::inputWidth() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setInputWidth( NS::UInteger inputWidth ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputWidth_ ), inputWidth ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalDenoisedScalerDescriptor::inputHeight() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setInputHeight( NS::UInteger inputHeight ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputHeight_ ), inputHeight ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalDenoisedScalerDescriptor::outputWidth() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setOutputWidth( NS::UInteger outputWidth ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputWidth_ ), outputWidth ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalDenoisedScalerDescriptor::outputHeight() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setOutputHeight( NS::UInteger outputHeight ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputHeight_ ), outputHeight ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerDescriptor::requiresSynchronousInitialization() const +{ + return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( requiresSynchronousInitialization ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setRequiresSynchronousInitialization( bool requiresSynchronousInitialization ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setRequiresSynchronousInitialization_ ), requiresSynchronousInitialization ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerDescriptor::isAutoExposureEnabled() const +{ + return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isAutoExposureEnabled ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setAutoExposureEnabled( bool enabled ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setAutoExposureEnabled_ ), enabled ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerDescriptor::isInputContentPropertiesEnabled() const +{ + return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isInputContentPropertiesEnabled ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setInputContentPropertiesEnabled( bool enabled ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputContentPropertiesEnabled_ ), enabled ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerDescriptor::inputContentMinScale() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( inputContentMinScale ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setInputContentMinScale( float inputContentMinScale ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputContentMinScale_ ), inputContentMinScale ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerDescriptor::inputContentMaxScale() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( inputContentMaxScale ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setInputContentMaxScale( float inputContentMaxScale ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputContentMaxScale_ ), inputContentMaxScale ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerDescriptor::isReactiveMaskTextureEnabled() const +{ + return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isReactiveMaskTextureEnabled ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setReactiveMaskTextureEnabled( bool enabled ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setReactiveMaskTextureEnabled_ ), enabled ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::reactiveMaskTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( reactiveMaskTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setReactiveMaskTextureFormat( MTL::PixelFormat pixelFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setReactiveMaskTextureFormat_ ), pixelFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerDescriptor::isSpecularHitDistanceTextureEnabled() const +{ + return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isSpecularHitDistanceTextureEnabled ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setSpecularHitDistanceTextureEnabled( bool enabled ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setSpecularHitDistanceTextureEnabled_ ), enabled ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerDescriptor::isDenoiseStrengthMaskTextureEnabled() const +{ + return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isDenoiseStrengthMaskTextureEnabled ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setDenoiseStrengthMaskTextureEnabled( bool enabled ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDenoiseStrengthMaskTextureEnabled_ ), enabled ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerDescriptor::isTransparencyOverlayTextureEnabled() const +{ + return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isTransparencyOverlayTextureEnabled ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setTransparencyOverlayTextureEnabled( bool enabled ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setTransparencyOverlayTextureEnabled_ ), enabled ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::TemporalDenoisedScaler* MTLFX::TemporalDenoisedScalerDescriptor::newTemporalDenoisedScaler( const MTL::Device* device ) const +{ + return NS::Object::sendMessage< TemporalDenoisedScaler* >( this, _MTLFX_PRIVATE_SEL( newTemporalDenoisedScalerWithDevice_ ), device ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL4FX::TemporalDenoisedScaler* MTLFX::TemporalDenoisedScalerDescriptor::newTemporalDenoisedScaler( const MTL::Device* device, const MTL4::Compiler* compiler ) const +{ + return NS::Object::sendMessage< MTL4FX::TemporalDenoisedScaler* >( this, _MTLFX_PRIVATE_SEL( newTemporalDenoisedScalerWithDevice_compiler_ ), device, compiler ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerDescriptor::supportedInputContentMinScale( MTL::Device* pDevice ) +{ + float scale = 1.0f; + + if ( nullptr != methodSignatureForSelector( _MTLFX_PRIVATE_CLS( MTLFXTemporalDenoisedScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportedInputContentMinScaleForDevice_ ) ) ) + { + scale = sendMessage< float >( _NS_PRIVATE_CLS( MTLFXTemporalDenoisedScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportedInputContentMinScaleForDevice_ ), pDevice ); + } + + return scale; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerDescriptor::supportedInputContentMaxScale( MTL::Device* pDevice ) +{ + float scale = 1.0f; + + if ( nullptr != methodSignatureForSelector( _MTLFX_PRIVATE_CLS( MTLFXTemporalDenoisedScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportedInputContentMaxScaleForDevice_ ) ) ) + { + scale = sendMessage< float >( _MTLFX_PRIVATE_CLS( MTLFXTemporalDenoisedScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportedInputContentMaxScaleForDevice_ ), pDevice ); + } + else if ( supportsDevice( pDevice ) ) + { + scale = 2.0f; + } + + return scale; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerDescriptor::supportsMetal4FX( MTL::Device* device ) +{ + return NS::Object::sendMessageSafe< bool >( _MTLFX_PRIVATE_CLS(MTLFXTemporalDenoisedScalerDescriptor), _MTLFX_PRIVATE_SEL( supportsMetal4FX_ ), device ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerDescriptor::supportsDevice( MTL::Device* device ) +{ + return NS::Object::sendMessageSafe< bool >( _MTLFX_PRIVATE_CLS(MTLFXTemporalDenoisedScalerDescriptor), _MTLFX_PRIVATE_SEL( supportsDevice_ ), device ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::colorTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( colorTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::depthTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( depthTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::motionTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( motionTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::reactiveTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( reactiveTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::diffuseAlbedoTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( diffuseAlbedoTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::specularAlbedoTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( specularAlbedoTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::normalTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( normalTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::roughnessTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( roughnessTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::specularHitDistanceTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( specularHitDistanceTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::denoiseStrengthMaskTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( denoiseStrengthMaskTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::transparencyOverlayTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( transparencyOverlayTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::outputTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( outputTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::colorTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( colorTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setColorTexture( MTL::Texture* colorTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setColorTexture_ ), colorTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::depthTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( depthTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setDepthTexture( MTL::Texture* depthTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDepthTexture_ ), depthTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::motionTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( motionTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setMotionTexture( MTL::Texture* motionTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionTexture_ ), motionTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::diffuseAlbedoTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( diffuseAlbedoTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setDiffuseAlbedoTexture( MTL::Texture* diffuseAlbedoTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDiffuseAlbedoTexture_ ), diffuseAlbedoTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::specularAlbedoTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( specularAlbedoTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setSpecularAlbedoTexture( MTL::Texture* specularAlbedoTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setSpecularAlbedoTexture_ ), specularAlbedoTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::normalTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( normalTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setNormalTexture( MTL::Texture* normalTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setNormalTexture_ ), normalTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::roughnessTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( roughnessTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setRoughnessTexture( MTL::Texture* roughnessTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setRoughnessTexture_ ), roughnessTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::specularHitDistanceTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( specularHitDistanceTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setSpecularHitDistanceTexture( MTL::Texture* specularHitDistanceTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setSpecularHitDistanceTexture_ ), specularHitDistanceTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::denoiseStrengthMaskTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( denoiseStrengthMaskTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setDenoiseStrengthMaskTexture( MTL::Texture* denoiseStrengthMaskTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDenoiseStrengthMaskTexture_ ), denoiseStrengthMaskTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::transparencyOverlayTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( transparencyOverlayTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setTransparencyOverlayTexture( MTL::Texture* transparencyOverlayTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setTransparencyOverlayTexture_ ), transparencyOverlayTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::outputTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( outputTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setOutputTexture( MTL::Texture* outputTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputTexture_ ), outputTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::exposureTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( exposureTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setExposureTexture( MTL::Texture* exposureTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setExposureTexture_ ), exposureTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerBase::preExposure() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( preExposure ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setPreExposure( float preExposure ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setPreExposure_ ), preExposure ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::reactiveMaskTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( reactiveMaskTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setReactiveMaskTexture( MTL::Texture* reactiveMaskTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setReactiveMaskTexture_ ), reactiveMaskTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerBase::jitterOffsetX() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( jitterOffsetX ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setJitterOffsetX( float jitterOffsetX ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setJitterOffsetX_ ), jitterOffsetX ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerBase::jitterOffsetY() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( jitterOffsetY ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setJitterOffsetY( float jitterOffsetY ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setJitterOffsetY_ ), jitterOffsetY ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerBase::motionVectorScaleX() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( motionVectorScaleX ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setMotionVectorScaleX( float motionVectorScaleX ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionVectorScaleX_ ), motionVectorScaleX ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerBase::motionVectorScaleY() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( motionVectorScaleY ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setMotionVectorScaleY( float motionVectorScaleY ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionVectorScaleY_ ), motionVectorScaleY ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerBase::shouldResetHistory() const +{ + return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( shouldResetHistory ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setShouldResetHistory( bool shouldResetHistory ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setShouldResetHistory_ ), shouldResetHistory ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerBase::isDepthReversed() const +{ + return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isDepthReversed ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setDepthReversed( bool depthReversed ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDepthReversed_ ), depthReversed ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::colorTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( colorTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::depthTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( depthTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::motionTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( motionTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::diffuseAlbedoTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( diffuseAlbedoTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::specularAlbedoTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( specularAlbedoTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::normalTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( normalTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::roughnessTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( roughnessTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::specularHitDistanceTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( specularHitDistanceTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::denoiseStrengthMaskTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( denoiseStrengthMaskTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::transparencyOverlayTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( transparencyOverlayTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::reactiveMaskTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( reactiveMaskTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::outputTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( outputTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalDenoisedScalerBase::inputWidth() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalDenoisedScalerBase::inputHeight() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalDenoisedScalerBase::outputWidth() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalDenoisedScalerBase::outputHeight() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerBase::inputContentMinScale() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( inputContentMinScale ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerBase::inputContentMaxScale() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( inputContentMaxScale ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE simd::float4x4 MTLFX::TemporalDenoisedScalerBase::worldToViewMatrix() const +{ + return NS::Object::sendMessage< simd::float4x4 >( this, _MTLFX_PRIVATE_SEL( worldToViewMatrix ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setWorldToViewMatrix( simd::float4x4 worldToViewMatrix ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setWorldToViewMatrix_ ), worldToViewMatrix ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE simd::float4x4 MTLFX::TemporalDenoisedScalerBase::viewToClipMatrix() const +{ + return NS::Object::sendMessage< simd::float4x4 >( this, _MTLFX_PRIVATE_SEL( viewToClipMatrix ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setViewToClipMatrix( simd::float4x4 viewToClipMatrix ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setViewToClipMatrix_ ), viewToClipMatrix ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Fence* MTLFX::TemporalDenoisedScalerBase::fence() const +{ + return NS::Object::sendMessage< MTL::Fence* >( this, _MTLFX_PRIVATE_SEL( fence ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setFence( MTL::Fence* fence ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setFence_ ), fence ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScaler::encodeToCommandBuffer( MTL::CommandBuffer* commandBuffer ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( encodeToCommandBuffer_ ), commandBuffer ); +} diff --git a/Source/Cxxmlx/metal-cpp/MetalFX/MTLFXTemporalScaler.hpp b/Source/Cxxmlx/metal-cpp/MetalFX/MTLFXTemporalScaler.hpp new file mode 100644 index 00000000..c13d4242 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/MetalFX/MTLFXTemporalScaler.hpp @@ -0,0 +1,803 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MTLFXTemporalScaler.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLFXDefines.hpp" +#include "MTLFXPrivate.hpp" + +#include "../Metal/Metal.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL4FX +{ + class TemporalScaler; +} + +namespace MTLFX +{ + class TemporalScalerDescriptor : public NS::Copying< TemporalScalerDescriptor > + { + public: + static class TemporalScalerDescriptor* alloc(); + class TemporalScalerDescriptor* init(); + + MTL::PixelFormat colorTextureFormat() const; + void setColorTextureFormat( MTL::PixelFormat format ); + + MTL::PixelFormat depthTextureFormat() const; + void setDepthTextureFormat( MTL::PixelFormat format ); + + MTL::PixelFormat motionTextureFormat() const; + void setMotionTextureFormat( MTL::PixelFormat format ); + + MTL::PixelFormat outputTextureFormat() const; + void setOutputTextureFormat( MTL::PixelFormat format ); + + NS::UInteger inputWidth() const; + void setInputWidth( NS::UInteger width ); + + NS::UInteger inputHeight() const; + void setInputHeight( NS::UInteger height ); + + NS::UInteger outputWidth() const; + void setOutputWidth( NS::UInteger width ); + + NS::UInteger outputHeight() const; + void setOutputHeight( NS::UInteger height ); + + bool isAutoExposureEnabled() const; + void setAutoExposureEnabled( bool enabled ); + + bool isInputContentPropertiesEnabled() const; + void setInputContentPropertiesEnabled( bool enabled ); + + bool requiresSynchronousInitialization() const; + void setRequiresSynchronousInitialization(bool requiresSynchronousInitialization); + + bool isReactiveMaskTextureEnabled() const; + void setReactiveMaskTextureEnabled( bool enabled ); + + MTL::PixelFormat reactiveMaskTextureFormat() const; + void setReactiveMaskTextureFormat( MTL::PixelFormat pixelFormat ); + + float inputContentMinScale() const; + void setInputContentMinScale( float scale ); + + float inputContentMaxScale() const; + void setInputContentMaxScale( float scale ); + + class TemporalScaler* newTemporalScaler( const MTL::Device* pDevice ) const; + MTL4FX::TemporalScaler* newTemporalScaler( const MTL::Device* pDevice, const MTL4::Compiler* pCompiler) const; + + static float supportedInputContentMinScale( const MTL::Device* pDevice ); + static float supportedInputContentMaxScale( const MTL::Device* pDevice ); + + static bool supportsDevice( const MTL::Device* pDevice ); + static bool supportsMetal4FX( const MTL::Device* pDevice ); + }; + + class FrameInterpolatableScaler : public NS::Copying< FrameInterpolatableScaler > + { + }; + + class TemporalScalerBase : public NS::Referencing< TemporalScaler, FrameInterpolatableScaler > + { + public: + MTL::TextureUsage colorTextureUsage() const; + MTL::TextureUsage depthTextureUsage() const; + MTL::TextureUsage motionTextureUsage() const; + MTL::TextureUsage outputTextureUsage() const; + + NS::UInteger inputContentWidth() const; + void setInputContentWidth( NS::UInteger width ); + + NS::UInteger inputContentHeight() const; + void setInputContentHeight( NS::UInteger height ); + + MTL::Texture* colorTexture() const; + void setColorTexture( MTL::Texture* pTexture ); + + MTL::Texture* depthTexture() const; + void setDepthTexture( MTL::Texture* pTexture ); + + MTL::Texture* motionTexture() const; + void setMotionTexture( MTL::Texture* pTexture ); + + MTL::Texture* outputTexture() const; + void setOutputTexture( MTL::Texture* pTexture ); + + MTL::Texture* exposureTexture() const; + void setExposureTexture( MTL::Texture* pTexture ); + + float preExposure() const; + void setPreExposure( float preExposure ); + + float jitterOffsetX() const; + void setJitterOffsetX( float offset ); + + float jitterOffsetY() const; + void setJitterOffsetY( float offset ); + + float motionVectorScaleX() const; + void setMotionVectorScaleX( float scale ); + + float motionVectorScaleY() const; + void setMotionVectorScaleY( float scale ); + + MTL::Texture* reactiveMaskTexture() const; + void setReactiveMaskTexture( MTL::Texture* reactiveMaskTexture ); + + MTL::TextureUsage reactiveTextureUsage() const; + + bool reset() const; + void setReset( bool reset ); + + bool isDepthReversed() const; + void setDepthReversed( bool depthReversed ); + + MTL::PixelFormat colorTextureFormat() const; + MTL::PixelFormat depthTextureFormat() const; + MTL::PixelFormat motionTextureFormat() const; + MTL::PixelFormat reactiveTextureFormat() const; + MTL::PixelFormat outputTextureFormat() const; + NS::UInteger inputWidth() const; + NS::UInteger inputHeight() const; + NS::UInteger outputWidth() const; + NS::UInteger outputHeight() const; + float inputContentMinScale() const; + float inputContentMaxScale() const; + + MTL::Fence* fence() const; + void setFence( MTL::Fence* pFence ); + }; + + class TemporalScaler : public NS::Referencing< TemporalScaler, TemporalScalerBase > + { + public: + void encodeToCommandBuffer( MTL::CommandBuffer* pCommandBuffer ); + }; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::TemporalScalerDescriptor* MTLFX::TemporalScalerDescriptor::alloc() +{ + return NS::Object::alloc< TemporalScalerDescriptor >( _MTLFX_PRIVATE_CLS( MTLFXTemporalScalerDescriptor ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::TemporalScalerDescriptor* MTLFX::TemporalScalerDescriptor::init() +{ + return NS::Object::init< TemporalScalerDescriptor >(); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalScalerDescriptor::colorTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( colorTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setColorTextureFormat( MTL::PixelFormat format ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setColorTextureFormat_ ), format ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalScalerDescriptor::depthTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( depthTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setDepthTextureFormat( MTL::PixelFormat format ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDepthTextureFormat_ ), format ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalScalerDescriptor::motionTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( motionTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setMotionTextureFormat( MTL::PixelFormat format ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionTextureFormat_ ), format ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalScalerDescriptor::outputTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( outputTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setOutputTextureFormat( MTL::PixelFormat format ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputTextureFormat_ ), format ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalScalerDescriptor::inputWidth() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setInputWidth( NS::UInteger width ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputWidth_ ), width ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalScalerDescriptor::inputHeight() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setInputHeight( NS::UInteger height ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputHeight_ ), height ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalScalerDescriptor::outputWidth() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setOutputWidth( NS::UInteger width ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputWidth_ ), width ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalScalerDescriptor::outputHeight() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setOutputHeight( NS::UInteger height ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputHeight_ ), height ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalScalerDescriptor::isAutoExposureEnabled() const +{ + return Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isAutoExposureEnabled ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setAutoExposureEnabled( bool enabled ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setAutoExposureEnabled_ ), enabled ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalScalerDescriptor::isInputContentPropertiesEnabled() const +{ + return Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isInputContentPropertiesEnabled ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setInputContentPropertiesEnabled( bool enabled ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputContentPropertiesEnabled_ ), enabled ); +} + + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalScalerDescriptor::requiresSynchronousInitialization() const +{ + return Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( requiresSynchronousInitialization ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setRequiresSynchronousInitialization(bool requiresSynchronousInitialization) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setRequiresSynchronousInitialization_ ), requiresSynchronousInitialization ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalScalerDescriptor::isReactiveMaskTextureEnabled() const +{ + return Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isReactiveMaskTextureEnabled ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setReactiveMaskTextureEnabled( bool enabled ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setReactiveMaskTextureEnabled_ ), enabled ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalScalerDescriptor::reactiveMaskTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( reactiveMaskTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setReactiveMaskTextureFormat( MTL::PixelFormat pixelFormat ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setReactiveMaskTextureFormat_ ), pixelFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerDescriptor::inputContentMinScale() const +{ + return Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( inputContentMinScale ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setInputContentMinScale( float scale ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputContentMinScale_ ), scale ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerDescriptor::inputContentMaxScale() const +{ + return Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( inputContentMaxScale ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setInputContentMaxScale( float scale ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputContentMaxScale_ ), scale ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::TemporalScaler* MTLFX::TemporalScalerDescriptor::newTemporalScaler( const MTL::Device* pDevice ) const +{ + return Object::sendMessage< TemporalScaler* >( this, _MTLFX_PRIVATE_SEL( newTemporalScalerWithDevice_ ), pDevice ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL4FX::TemporalScaler* MTLFX::TemporalScalerDescriptor::newTemporalScaler( const MTL::Device* pDevice, const MTL4::Compiler* pCompiler ) const +{ + return Object::sendMessage< MTL4FX::TemporalScaler* >( this, _MTLFX_PRIVATE_SEL( newTemporalScalerWithDevice_compiler_ ), pDevice, pCompiler ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerDescriptor::supportedInputContentMinScale( const MTL::Device* pDevice ) +{ + float scale = 1.0f; + + if ( nullptr != methodSignatureForSelector( _NS_PRIVATE_CLS( MTLFXTemporalScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportedInputContentMinScaleForDevice_ ) ) ) + { + scale = sendMessage< float >( _NS_PRIVATE_CLS( MTLFXTemporalScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportedInputContentMinScaleForDevice_ ), pDevice ); + } + + return scale; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerDescriptor::supportedInputContentMaxScale( const MTL::Device* pDevice ) +{ + float scale = 1.0f; + + if ( nullptr != methodSignatureForSelector( _NS_PRIVATE_CLS( MTLFXTemporalScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportedInputContentMaxScaleForDevice_ ) ) ) + { + scale = sendMessage< float >( _NS_PRIVATE_CLS( MTLFXTemporalScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportedInputContentMaxScaleForDevice_ ), pDevice ); + } + else if ( supportsDevice( pDevice ) ) + { + scale = 2.0f; + } + + return scale; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalScalerDescriptor::supportsDevice( const MTL::Device* pDevice ) +{ + return Object::sendMessageSafe< bool >( _NS_PRIVATE_CLS( MTLFXTemporalScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportsDevice_ ), pDevice ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalScalerDescriptor::supportsMetal4FX( const MTL::Device* pDevice ) +{ + return Object::sendMessageSafe< bool >( _NS_PRIVATE_CLS( MTLFXTemporalScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportsMetal4FX_ ), pDevice ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalScalerBase::colorTextureUsage() const +{ + return Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( colorTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalScalerBase::depthTextureUsage() const +{ + return Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( depthTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalScalerBase::motionTextureUsage() const +{ + return Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( motionTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalScalerBase::outputTextureUsage() const +{ + return Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( outputTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalScalerBase::inputContentWidth() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputContentWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setInputContentWidth( NS::UInteger width ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputContentWidth_ ), width ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalScalerBase::inputContentHeight() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputContentHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setInputContentHeight( NS::UInteger height ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputContentHeight_ ), height ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalScalerBase::colorTexture() const +{ + return Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( colorTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setColorTexture( MTL::Texture* pTexture ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setColorTexture_ ), pTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalScalerBase::depthTexture() const +{ + return Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( depthTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setDepthTexture( MTL::Texture* pTexture ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDepthTexture_ ), pTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalScalerBase::motionTexture() const +{ + return Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( motionTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setMotionTexture( MTL::Texture* pTexture ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionTexture_ ), pTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalScalerBase::outputTexture() const +{ + return Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( outputTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setOutputTexture( MTL::Texture* pTexture ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputTexture_ ), pTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalScalerBase::exposureTexture() const +{ + return Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( exposureTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setExposureTexture( MTL::Texture* pTexture ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setExposureTexture_ ), pTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerBase::preExposure() const +{ + return Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( preExposure ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setPreExposure( float preExposure ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setPreExposure_ ), preExposure ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerBase::jitterOffsetX() const +{ + return Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( jitterOffsetX ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setJitterOffsetX( float offset ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setJitterOffsetX_ ), offset ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerBase::jitterOffsetY() const +{ + return Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( jitterOffsetY ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setJitterOffsetY( float offset ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setJitterOffsetY_ ), offset ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerBase::motionVectorScaleX() const +{ + return Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( motionVectorScaleX ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setMotionVectorScaleX( float scale ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionVectorScaleX_ ), scale ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerBase::motionVectorScaleY() const +{ + return Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( motionVectorScaleY ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setMotionVectorScaleY( float scale ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionVectorScaleY_ ), scale ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalScalerBase::reactiveMaskTexture() const +{ + return Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( reactiveMaskTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setReactiveMaskTexture( MTL::Texture* reactiveMaskTexture ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setReactiveMaskTexture_ ), reactiveMaskTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalScalerBase::reactiveTextureUsage() const +{ + return Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( reactiveTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalScalerBase::reset() const +{ + return Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( reset ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setReset( bool reset ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setReset_ ), reset ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalScalerBase::isDepthReversed() const +{ + return Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isDepthReversed ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setDepthReversed( bool depthReversed ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDepthReversed_ ), depthReversed ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalScalerBase::colorTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( colorTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalScalerBase::depthTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( depthTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalScalerBase::motionTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( motionTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalScalerBase::outputTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( outputTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalScalerBase::inputWidth() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalScalerBase::inputHeight() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalScalerBase::outputWidth() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalScalerBase::outputHeight() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerBase::inputContentMinScale() const +{ + return Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( inputContentMinScale ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerBase::inputContentMaxScale() const +{ + return Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( inputContentMaxScale ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Fence* MTLFX::TemporalScalerBase::fence() const +{ + return Object::sendMessage< MTL::Fence* >( this, _MTLFX_PRIVATE_SEL( fence ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setFence( MTL::Fence* pFence ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setFence_ ), pFence ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScaler::encodeToCommandBuffer( MTL::CommandBuffer* pCommandBuffer ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( encodeToCommandBuffer_ ), pCommandBuffer ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/MetalFX/MetalFX.hpp b/Source/Cxxmlx/metal-cpp/MetalFX/MetalFX.hpp new file mode 100644 index 00000000..20e647e8 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/MetalFX/MetalFX.hpp @@ -0,0 +1,35 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MetalFX.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLFXSpatialScaler.hpp" +#include "MTLFXTemporalScaler.hpp" +#include "MTLFXTemporalDenoisedScaler.hpp" +#include "MTLFXFrameInterpolator.hpp" + +#include "MTL4FXSpatialScaler.hpp" +#include "MTL4FXTemporalScaler.hpp" +#include "MTL4FXTemporalDenoisedScaler.hpp" +#include "MTL4FXFrameInterpolator.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/QuartzCore/CADefines.hpp b/Source/Cxxmlx/metal-cpp/QuartzCore/CADefines.hpp new file mode 100644 index 00000000..b0641de0 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/QuartzCore/CADefines.hpp @@ -0,0 +1,41 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// QuartzCore/CADefines.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "../Foundation/NSDefines.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#define _CA_EXPORT _NS_EXPORT +#define _CA_EXTERN _NS_EXTERN +#define _CA_INLINE _NS_INLINE +#define _CA_PACKED _NS_PACKED + +#define _CA_CONST(type, name) _NS_CONST(type, name) +#define _CA_ENUM(type, name) _NS_ENUM(type, name) +#define _CA_OPTIONS(type, name) _NS_OPTIONS(type, name) + +#define _CA_VALIDATE_SIZE(ns, name) _NS_VALIDATE_SIZE(ns, name) +#define _CA_VALIDATE_ENUM(ns, name) _NS_VALIDATE_ENUM(ns, name) + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/QuartzCore/CAMetalDrawable.hpp b/Source/Cxxmlx/metal-cpp/QuartzCore/CAMetalDrawable.hpp new file mode 100644 index 00000000..0057773a --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/QuartzCore/CAMetalDrawable.hpp @@ -0,0 +1,57 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// QuartzCore/CAMetalDrawable.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "../Metal/MTLDrawable.hpp" +#include "../Metal/MTLTexture.hpp" + +#include "CADefines.hpp" +#include "CAPrivate.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace CA +{ +class MetalDrawable : public NS::Referencing +{ +public: + class MetalLayer* layer() const; + MTL::Texture* texture() const; +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE CA::MetalLayer* CA::MetalDrawable::layer() const +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(layer)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE MTL::Texture* CA::MetalDrawable::texture() const +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(texture)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/QuartzCore/CAMetalLayer.hpp b/Source/Cxxmlx/metal-cpp/QuartzCore/CAMetalLayer.hpp new file mode 100644 index 00000000..53f6857d --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/QuartzCore/CAMetalLayer.hpp @@ -0,0 +1,216 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// QuartzCore/CAMetalDrawable.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "../Metal/MTLPixelFormat.hpp" +#include "../Metal/MTLTexture.hpp" +#include "../Metal/MTLResidencySet.hpp" +#include "../Foundation/NSTypes.hpp" +#include +#include + +#include "CADefines.hpp" +#include "CAMetalDrawable.hpp" +#include "CAPrivate.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace CA +{ + +class MetalLayer : public NS::Referencing +{ +public: + static class MetalLayer* layer(); + + MTL::Device* device() const; + void setDevice(MTL::Device* device); + + MTL::PixelFormat pixelFormat() const; + void setPixelFormat(MTL::PixelFormat pixelFormat); + + bool framebufferOnly() const; + void setFramebufferOnly(bool framebufferOnly); + + CGSize drawableSize() const; + void setDrawableSize(CGSize drawableSize); + + class MetalDrawable* nextDrawable(); + + NS::UInteger maximumDrawableCount() const; + void setMaximumDrawableCount(NS::UInteger maximumDrawableCount); + + bool displaySyncEnabled() const; + void setDisplaySyncEnabled(bool displaySyncEnabled); + + CGColorSpaceRef colorspace() const; + void setColorspace(CGColorSpaceRef colorspace); + + bool allowsNextDrawableTimeout() const; + void setAllowsNextDrawableTimeout(bool allowsNextDrawableTimeout); + + MTL::ResidencySet* residencySet() const; +}; +} // namespace CA + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +_CA_INLINE CA::MetalLayer* CA::MetalLayer::layer() +{ + return Object::sendMessage(_CA_PRIVATE_CLS(CAMetalLayer), _CA_PRIVATE_SEL(layer)); +} +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE MTL::Device* CA::MetalLayer::device() const +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(device)); +} +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE void CA::MetalLayer::setDevice(MTL::Device* device) +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(setDevice_), device); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE MTL::PixelFormat CA::MetalLayer::pixelFormat() const +{ + return Object::sendMessage(this, + _CA_PRIVATE_SEL(pixelFormat)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE void CA::MetalLayer::setPixelFormat(MTL::PixelFormat pixelFormat) +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(setPixelFormat_), + pixelFormat); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE bool CA::MetalLayer::framebufferOnly() const +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(framebufferOnly)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE void CA::MetalLayer::setFramebufferOnly(bool framebufferOnly) +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(setFramebufferOnly_), + framebufferOnly); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE CGSize CA::MetalLayer::drawableSize() const +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(drawableSize)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE void CA::MetalLayer::setDrawableSize(CGSize drawableSize) +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(setDrawableSize_), + drawableSize); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE CA::MetalDrawable* CA::MetalLayer::nextDrawable() +{ + return Object::sendMessage(this, + _CA_PRIVATE_SEL(nextDrawable)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE NS::UInteger CA::MetalLayer::maximumDrawableCount() const +{ + return Object::sendMessage(this, + _CA_PRIVATE_SEL(maximumDrawableCount)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE void CA::MetalLayer::setMaximumDrawableCount(NS::UInteger maximumDrawableCount) +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(setMaximumDrawableCount_), + maximumDrawableCount); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE bool CA::MetalLayer::displaySyncEnabled() const +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(displaySyncEnabled)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE void CA::MetalLayer::setDisplaySyncEnabled(bool displaySyncEnabled) +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(setDisplaySyncEnabled_), + displaySyncEnabled); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE CGColorSpaceRef CA::MetalLayer::colorspace() const +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(colorspace)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE void CA::MetalLayer::setColorspace(CGColorSpaceRef colorspace) +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(setColorspace_), + colorspace); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE bool CA::MetalLayer::allowsNextDrawableTimeout() const +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(allowsNextDrawableTimeout)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE void CA::MetalLayer::setAllowsNextDrawableTimeout(bool allowsNextDrawableTimeout) +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(setAllowsNextDrawableTimeout_), + allowsNextDrawableTimeout); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE MTL::ResidencySet* CA::MetalLayer::residencySet() const +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(residencySet) ); +} diff --git a/Source/Cxxmlx/metal-cpp/QuartzCore/CAPrivate.hpp b/Source/Cxxmlx/metal-cpp/QuartzCore/CAPrivate.hpp new file mode 100644 index 00000000..0b7486a7 --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/QuartzCore/CAPrivate.hpp @@ -0,0 +1,150 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// QuartzCore/CAPrivate.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "CADefines.hpp" + +#include + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#define _CA_PRIVATE_CLS(symbol) (Private::Class::s_k##symbol) +#define _CA_PRIVATE_SEL(accessor) (Private::Selector::s_k##accessor) + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#if defined(CA_PRIVATE_IMPLEMENTATION) + +#ifdef METALCPP_SYMBOL_VISIBILITY_HIDDEN +#define _CA_PRIVATE_VISIBILITY __attribute__((visibility("hidden"))) +#else +#define _CA_PRIVATE_VISIBILITY __attribute__((visibility("default"))) +#endif // METALCPP_SYMBOL_VISIBILITY_HIDDEN + +#define _CA_PRIVATE_IMPORT __attribute__((weak_import)) + +#ifdef __OBJC__ +#define _CA_PRIVATE_OBJC_LOOKUP_CLASS(symbol) ((__bridge void*)objc_lookUpClass(#symbol)) +#define _CA_PRIVATE_OBJC_GET_PROTOCOL(symbol) ((__bridge void*)objc_getProtocol(#symbol)) +#else +#define _CA_PRIVATE_OBJC_LOOKUP_CLASS(symbol) objc_lookUpClass(#symbol) +#define _CA_PRIVATE_OBJC_GET_PROTOCOL(symbol) objc_getProtocol(#symbol) +#endif // __OBJC__ + +#define _CA_PRIVATE_DEF_CLS(symbol) void* s_k##symbol _CA_PRIVATE_VISIBILITY = _CA_PRIVATE_OBJC_LOOKUP_CLASS(symbol) +#define _CA_PRIVATE_DEF_PRO(symbol) void* s_k##symbol _CA_PRIVATE_VISIBILITY = _CA_PRIVATE_OBJC_GET_PROTOCOL(symbol) +#define _CA_PRIVATE_DEF_SEL(accessor, symbol) SEL s_k##accessor _CA_PRIVATE_VISIBILITY = sel_registerName(symbol) +#define _CA_PRIVATE_DEF_STR(type, symbol) \ + _CA_EXTERN type const CA##symbol _CA_PRIVATE_IMPORT; \ + type const CA::symbol = (nullptr != &CA##symbol) ? CA##symbol : nullptr + +#else + +#define _CA_PRIVATE_DEF_CLS(symbol) extern void* s_k##symbol +#define _CA_PRIVATE_DEF_PRO(symbol) extern void* s_k##symbol +#define _CA_PRIVATE_DEF_SEL(accessor, symbol) extern SEL s_k##accessor +#define _CA_PRIVATE_DEF_STR(type, symbol) extern type const CA::symbol + +#endif // CA_PRIVATE_IMPLEMENTATION + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace CA +{ +namespace Private +{ + namespace Class + { + _CA_PRIVATE_DEF_CLS(CAMetalLayer); + } // Class +} // Private +} // CA + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace CA +{ +namespace Private +{ + namespace Protocol + { + + _CA_PRIVATE_DEF_PRO(CAMetalDrawable); + + } // Protocol +} // Private +} // CA + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace CA +{ +namespace Private +{ + namespace Selector + { + _CA_PRIVATE_DEF_SEL(allowsNextDrawableTimeout, + "allowsNextDrawableTimeout"); + _CA_PRIVATE_DEF_SEL(colorspace, + "colorspace"); + _CA_PRIVATE_DEF_SEL(device, + "device"); + _CA_PRIVATE_DEF_SEL(displaySyncEnabled, + "displaySyncEnabled"); + _CA_PRIVATE_DEF_SEL(drawableSize, + "drawableSize"); + _CA_PRIVATE_DEF_SEL(framebufferOnly, + "framebufferOnly"); + _CA_PRIVATE_DEF_SEL(layer, + "layer"); + _CA_PRIVATE_DEF_SEL(maximumDrawableCount, + "maximumDrawableCount"); + _CA_PRIVATE_DEF_SEL(nextDrawable, + "nextDrawable"); + _CA_PRIVATE_DEF_SEL(pixelFormat, + "pixelFormat"); + _CA_PRIVATE_DEF_SEL(residencySet, + "residencySet"); + _CA_PRIVATE_DEF_SEL(setAllowsNextDrawableTimeout_, + "setAllowsNextDrawableTimeout:"); + _CA_PRIVATE_DEF_SEL(setColorspace_, + "setColorspace:"); + _CA_PRIVATE_DEF_SEL(setDevice_, + "setDevice:"); + _CA_PRIVATE_DEF_SEL(setDisplaySyncEnabled_, + "setDisplaySyncEnabled:"); + _CA_PRIVATE_DEF_SEL(setDrawableSize_, + "setDrawableSize:"); + _CA_PRIVATE_DEF_SEL(setFramebufferOnly_, + "setFramebufferOnly:"); + _CA_PRIVATE_DEF_SEL(setMaximumDrawableCount_, + "setMaximumDrawableCount:"); + _CA_PRIVATE_DEF_SEL(setPixelFormat_, + "setPixelFormat:"); + _CA_PRIVATE_DEF_SEL(texture, + "texture"); + } // Class +} // Private +} // CA + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cxxmlx/metal-cpp/QuartzCore/QuartzCore.hpp b/Source/Cxxmlx/metal-cpp/QuartzCore/QuartzCore.hpp new file mode 100644 index 00000000..681003ad --- /dev/null +++ b/Source/Cxxmlx/metal-cpp/QuartzCore/QuartzCore.hpp @@ -0,0 +1,28 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// QuartzCore/QuartzCore.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "CAMetalDrawable.hpp" +#include "CAMetalLayer.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/Source/Cmlx/metal-cpp/README.md b/Source/Cxxmlx/metal-cpp/README.md similarity index 100% rename from Source/Cmlx/metal-cpp/README.md rename to Source/Cxxmlx/metal-cpp/README.md diff --git a/Source/Cmlx/metal-cpp/SingleHeader/MakeSingleHeader.py b/Source/Cxxmlx/metal-cpp/SingleHeader/MakeSingleHeader.py similarity index 100% rename from Source/Cmlx/metal-cpp/SingleHeader/MakeSingleHeader.py rename to Source/Cxxmlx/metal-cpp/SingleHeader/MakeSingleHeader.py diff --git a/Source/Cmlx/mlx b/Source/Cxxmlx/mlx similarity index 100% rename from Source/Cmlx/mlx rename to Source/Cxxmlx/mlx diff --git a/Source/Cmlx/mlx-conditional/compiled_conditional.cpp b/Source/Cxxmlx/mlx-conditional/compiled_conditional.cpp similarity index 100% rename from Source/Cmlx/mlx-conditional/compiled_conditional.cpp rename to Source/Cxxmlx/mlx-conditional/compiled_conditional.cpp diff --git a/Source/Cmlx/mlx-generated/arange.cpp b/Source/Cxxmlx/mlx-generated/arange.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/arange.cpp rename to Source/Cxxmlx/mlx-generated/arange.cpp diff --git a/Source/Cmlx/mlx-generated/binary.cpp b/Source/Cxxmlx/mlx-generated/binary.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/binary.cpp rename to Source/Cxxmlx/mlx-generated/binary.cpp diff --git a/Source/Cmlx/mlx-generated/binary_ops.cpp b/Source/Cxxmlx/mlx-generated/binary_ops.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/binary_ops.cpp rename to Source/Cxxmlx/mlx-generated/binary_ops.cpp diff --git a/Source/Cmlx/mlx-generated/binary_two.cpp b/Source/Cxxmlx/mlx-generated/binary_two.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/binary_two.cpp rename to Source/Cxxmlx/mlx-generated/binary_two.cpp diff --git a/Source/Cmlx/mlx-generated/compiled_preamble.cpp b/Source/Cxxmlx/mlx-generated/compiled_preamble.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/compiled_preamble.cpp rename to Source/Cxxmlx/mlx-generated/compiled_preamble.cpp diff --git a/Source/Cmlx/mlx-generated/conv.cpp b/Source/Cxxmlx/mlx-generated/conv.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/conv.cpp rename to Source/Cxxmlx/mlx-generated/conv.cpp diff --git a/Source/Cmlx/mlx-generated/copy.cpp b/Source/Cxxmlx/mlx-generated/copy.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/copy.cpp rename to Source/Cxxmlx/mlx-generated/copy.cpp diff --git a/Source/Cmlx/mlx-generated/cuda/cuda_jit_sources.h b/Source/Cxxmlx/mlx-generated/cuda/cuda_jit_sources.h similarity index 100% rename from Source/Cmlx/mlx-generated/cuda/cuda_jit_sources.h rename to Source/Cxxmlx/mlx-generated/cuda/cuda_jit_sources.h diff --git a/Source/Cmlx/mlx-generated/fft.cpp b/Source/Cxxmlx/mlx-generated/fft.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/fft.cpp rename to Source/Cxxmlx/mlx-generated/fft.cpp diff --git a/Source/Cmlx/mlx-generated/fp_quantized.cpp b/Source/Cxxmlx/mlx-generated/fp_quantized.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/fp_quantized.cpp rename to Source/Cxxmlx/mlx-generated/fp_quantized.cpp diff --git a/Source/Cmlx/mlx-generated/fp_quantized_nax.cpp b/Source/Cxxmlx/mlx-generated/fp_quantized_nax.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/fp_quantized_nax.cpp rename to Source/Cxxmlx/mlx-generated/fp_quantized_nax.cpp diff --git a/Source/Cmlx/mlx-generated/gather.cpp b/Source/Cxxmlx/mlx-generated/gather.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/gather.cpp rename to Source/Cxxmlx/mlx-generated/gather.cpp diff --git a/Source/Cmlx/mlx-generated/gather_axis.cpp b/Source/Cxxmlx/mlx-generated/gather_axis.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/gather_axis.cpp rename to Source/Cxxmlx/mlx-generated/gather_axis.cpp diff --git a/Source/Cmlx/mlx-generated/gather_front.cpp b/Source/Cxxmlx/mlx-generated/gather_front.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/gather_front.cpp rename to Source/Cxxmlx/mlx-generated/gather_front.cpp diff --git a/Source/Cmlx/mlx-generated/gemm.cpp b/Source/Cxxmlx/mlx-generated/gemm.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/gemm.cpp rename to Source/Cxxmlx/mlx-generated/gemm.cpp diff --git a/Source/Cmlx/mlx-generated/gemm_nax.cpp b/Source/Cxxmlx/mlx-generated/gemm_nax.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/gemm_nax.cpp rename to Source/Cxxmlx/mlx-generated/gemm_nax.cpp diff --git a/Source/Cmlx/mlx-generated/gemv_masked.cpp b/Source/Cxxmlx/mlx-generated/gemv_masked.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/gemv_masked.cpp rename to Source/Cxxmlx/mlx-generated/gemv_masked.cpp diff --git a/Source/Cmlx/mlx-generated/hadamard.cpp b/Source/Cxxmlx/mlx-generated/hadamard.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/hadamard.cpp rename to Source/Cxxmlx/mlx-generated/hadamard.cpp diff --git a/Source/Cmlx/mlx-generated/logsumexp.cpp b/Source/Cxxmlx/mlx-generated/logsumexp.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/logsumexp.cpp rename to Source/Cxxmlx/mlx-generated/logsumexp.cpp diff --git a/Source/Cmlx/mlx-generated/masked_scatter.cpp b/Source/Cxxmlx/mlx-generated/masked_scatter.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/masked_scatter.cpp rename to Source/Cxxmlx/mlx-generated/masked_scatter.cpp diff --git a/Source/Cxxmlx/mlx-generated/metal/arange.h b/Source/Cxxmlx/mlx-generated/metal/arange.h new file mode 100644 index 00000000..5448fe9a --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/arange.h @@ -0,0 +1,9 @@ +// Copyright © 2023-2024 Apple Inc. +template +[[kernel]] void arange( + constant const T& start, + constant const T& step, + device T* out, + uint index [[thread_position_in_grid]]) { + out[index] = start + index * step; +} diff --git a/Source/Cmlx/mlx-generated/metal/arg_reduce.metal b/Source/Cxxmlx/mlx-generated/metal/arg_reduce.metal similarity index 100% rename from Source/Cmlx/mlx-generated/metal/arg_reduce.metal rename to Source/Cxxmlx/mlx-generated/metal/arg_reduce.metal diff --git a/Source/Cxxmlx/mlx-generated/metal/atomic.h b/Source/Cxxmlx/mlx-generated/metal/atomic.h new file mode 100644 index 00000000..93952c2c --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/atomic.h @@ -0,0 +1,345 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include +#include + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// Atomic utils +/////////////////////////////////////////////////////////////////////////////// + +#pragma METAL internals : enable +template +constexpr constant bool is_metal_atomic = _disjunction< + is_same, + is_same, + is_same, + is_same>::value; + +#pragma METAL internals : disable + +template +struct mlx_atomic { + atomic val; +}; + +template +struct mlx_atomic>> { + atomic val; +}; + +/////////////////////////////////////////////////////////////////////////////// +// Native metal atomics +/////////////////////////////////////////////////////////////////////////////// + +template , bool> = true> +METAL_FUNC T +mlx_atomic_load_explicit(device mlx_atomic* object, size_t offset) { + return atomic_load_explicit(&(object[offset].val), memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void +mlx_atomic_store_explicit(device mlx_atomic* object, T val, size_t offset) { + atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_and_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_or_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_min_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_max_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_add_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_mul_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + T expected = mlx_atomic_load_explicit(object, offset); + while (!mlx_atomic_compare_exchange_weak_explicit( + object, &expected, val * expected, offset)) { + } +} + +template , bool> = true> +METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit( + device mlx_atomic* object, + thread T* expected, + T val, + size_t offset) { + return atomic_compare_exchange_weak_explicit( + &(object[offset].val), + expected, + val, + memory_order_relaxed, + memory_order_relaxed); +} + +// Specialization for float since it does not atomic_fetch_min_explicit +template <> +METAL_FUNC void mlx_atomic_fetch_min_explicit( + device mlx_atomic* object, + float val, + size_t offset) { + float expected = mlx_atomic_load_explicit(object, offset); + while (val < expected) { + if (mlx_atomic_compare_exchange_weak_explicit( + object, &expected, val, offset)) { + return; + } + } +} + +// Specialization for float since it does not atomic_fetch_max_explicit +template <> +METAL_FUNC void mlx_atomic_fetch_max_explicit( + device mlx_atomic* object, + float val, + size_t offset) { + float expected = mlx_atomic_load_explicit(object, offset); + while (val > expected) { + if (mlx_atomic_compare_exchange_weak_explicit( + object, &expected, val, offset)) { + return; + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Custom atomics +/////////////////////////////////////////////////////////////////////////////// + +namespace { + +template +constexpr constant uint packing_size = sizeof(uint) / sizeof(T); + +template +union uint_or_packed { + T val[packing_size]; + uint bits; +}; + +template +struct mlx_atomic_update_helper { + uint operator()(uint_or_packed init, T update, size_t elem_offset) { + Op op; + init.val[elem_offset] = op(update, init.val[elem_offset]); + return init.bits; + } +}; + +template +METAL_FUNC void mlx_atomic_update_and_store( + device mlx_atomic* object, + T update, + size_t offset) { + size_t pack_offset = offset / packing_size; + size_t elem_offset = offset % packing_size; + + mlx_atomic_update_helper helper; + uint_or_packed expected; + expected.bits = + atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed); + + while (Op::condition(update, expected.val[elem_offset]) && + !mlx_atomic_compare_exchange_weak_explicit( + object, + &(expected.bits), + helper(expected, update, elem_offset), + pack_offset)) { + } +} + +template +struct __None { + static bool condition(T a, T b) { +#pragma unused(a) +#pragma unused(b) + return true; + } + + T operator()(T a, T b) { +#pragma unused(b) + return a; + } +}; + +template +struct __Add { + static bool condition(T a, T b) { +#pragma unused(a) +#pragma unused(b) + return true; + } + + T operator()(T a, T b) { + return a + b; + } +}; + +template +struct __Mul { + static bool condition(T a, T b) { +#pragma unused(a) + return b != 0; + } + + T operator()(T a, T b) { + return a * b; + } +}; + +template +struct __Max { + static bool condition(T a, T b) { + return a > b; + } + + T operator()(T a, T b) { + return max(a, b); + } +}; + +template +struct __Min { + static bool condition(T a, T b) { + return a < b; + } + + T operator()(T a, T b) { + return min(a, b); + } +}; + +} // namespace + +template , bool> = true> +METAL_FUNC T +mlx_atomic_load_explicit(device mlx_atomic* object, size_t offset) { + size_t pack_offset = offset / sizeof(T); + size_t elem_offset = offset % sizeof(T); + uint_or_packed packed_val; + packed_val.bits = + atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed); + return packed_val.val[elem_offset]; +} + +template , bool> = true> +METAL_FUNC void +mlx_atomic_store_explicit(device mlx_atomic* object, T val, size_t offset) { + mlx_atomic_update_and_store>(object, val, offset); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_and_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + size_t pack_offset = offset / packing_size; + size_t elem_offset = offset % packing_size; + uint_or_packed identity; + identity.bits = __UINT32_MAX__; + identity.val[elem_offset] = val; + + atomic_fetch_and_explicit( + &(object[pack_offset].val), identity.bits, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_or_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + size_t pack_offset = offset / packing_size; + size_t elem_offset = offset % packing_size; + uint_or_packed identity; + identity.bits = 0; + identity.val[elem_offset] = val; + + atomic_fetch_or_explicit( + &(object[pack_offset].val), identity.bits, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_min_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + mlx_atomic_update_and_store>(object, val, offset); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_max_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + mlx_atomic_update_and_store>(object, val, offset); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_add_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + mlx_atomic_update_and_store>(object, val, offset); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_mul_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + mlx_atomic_update_and_store>(object, val, offset); +} + +template , bool> = true> +METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit( + device mlx_atomic* object, + thread uint* expected, + uint val, + size_t offset) { + return atomic_compare_exchange_weak_explicit( + &(object[offset].val), + expected, + val, + memory_order_relaxed, + memory_order_relaxed); +} diff --git a/Source/Cxxmlx/mlx-generated/metal/bf16.h b/Source/Cxxmlx/mlx-generated/metal/bf16.h new file mode 100644 index 00000000..aa3c3c78 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/bf16.h @@ -0,0 +1,16 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +using namespace metal; + +typedef bfloat bfloat16_t; +inline uint16_t bfloat16_to_uint16(const bfloat16_t x) { + return as_type(x); +} + +inline bfloat16_t uint16_to_bfloat16(const uint16_t x) { + return as_type(x); +} diff --git a/Source/Cxxmlx/mlx-generated/metal/bf16_math.h b/Source/Cxxmlx/mlx-generated/metal/bf16_math.h new file mode 100644 index 00000000..0643fb3e --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/bf16_math.h @@ -0,0 +1,380 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +/////////////////////////////////////////////////////////////////////////////// +// Metal math for bfloat16 +/////////////////////////////////////////////////////////////////////////////// + +/* + +Following the Metal Shading Language Specification (Metal 3.1) + +"bfloat is an extended itypeing point type that only allows implicit conversion + to a type of greater itypeing point rank. While bfloat can be implicitly + converted to itype, it cannot be implicitly converted to half, and neither + itype nor half can be implicitly converted to bfloat." + +Further, as far as I can tell, the stdlib math/simd functions are not defined +for bfloat and calling with an argument of type bfloat will result in that +argument getting implicitly converted to itype which then returns an output +that is (likely) a itype which cannot be implicitly converted into a bfloat + +This leads to situations where +bfloat a = 5.0bf; +bfloat b = metal::abs(a); // this will throw an error since abs return itype +bfloat c = static_cast(metal::abs(a)); // this is fine + +For the moment, I will be adding overloaded instantiations of the math +functions to accordingly automatically handle the casting + +*/ + +#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \ + \ + METAL_FUNC otype abs(itype x) { \ + return static_cast(__metal_fabs(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype acos(itype x) { \ + return static_cast(__metal_acos(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype acosh(itype x) { \ + return static_cast(__metal_acosh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype asin(itype x) { \ + return static_cast(__metal_asin(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype asinh(itype x) { \ + return static_cast(__metal_asinh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype atan(itype y_over_x) { \ + return static_cast( \ + __metal_atan(static_cast(y_over_x), mfast)); \ + } \ + METAL_FUNC otype atan2(itype y, itype x) { \ + return static_cast( \ + __metal_atan2(static_cast(y), static_cast(x), mfast)); \ + } \ + METAL_FUNC otype atanh(itype x) { \ + return static_cast(__metal_atanh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype ceil(itype x) { \ + return static_cast(__metal_ceil(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype cos(itype x) { \ + return static_cast(__metal_cos(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype cosh(itype x) { \ + return static_cast(__metal_cosh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype cospi(itype x) { \ + return static_cast(__metal_cospi(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype divide(itype x, itype y) { \ + return static_cast( \ + __metal_divide(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype exp(itype x) { \ + return static_cast(__metal_exp(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype exp10(itype x) { \ + return static_cast(__metal_exp10(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype exp2(itype x) { \ + return static_cast(__metal_exp2(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype fabs(itype x) { \ + return static_cast(__metal_fabs(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype fdim(itype x, itype y) { \ + ctype t = static_cast(x - y); \ + return static_cast(select(t, ctype(0), t < ctype(0) || x == y)); \ + } \ + METAL_FUNC otype floor(itype x) { \ + return static_cast(__metal_floor(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype fma(itype x, itype y, itype z) { \ + return static_cast(__metal_fma( \ + static_cast(x), static_cast(y), static_cast(z))); \ + } \ + METAL_FUNC otype fmax(itype x, itype y) { \ + return static_cast( \ + __metal_fmax(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype fmax3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmax3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmedian3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype fmin(itype x, itype y) { \ + return static_cast( \ + __metal_fmin(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype fmin3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmin3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype fmod(itype x, itype y) { \ + return static_cast( \ + __metal_fmod(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype fract(itype x) { \ + return static_cast(__metal_fract(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype frexp(itype x, thread int& exp) { \ + return static_cast(__metal_frexp(static_cast(x), &exp)); \ + } \ + METAL_FUNC otype ldexp(itype x, int k) { \ + return static_cast(__metal_ldexp(static_cast(x), k, mfast)); \ + } \ + METAL_FUNC otype log(itype x) { \ + return static_cast(__metal_log(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype log10(itype x) { \ + return static_cast(__metal_log10(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype log2(itype x) { \ + return static_cast(__metal_log2(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype max(itype x, itype y) { \ + return static_cast( \ + __metal_fmax(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype max3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmax3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype median3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmedian3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype min(itype x, itype y) { \ + return static_cast( \ + __metal_fmin(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype min3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmin3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype nextafter(itype x, itype y) { \ + return static_cast( \ + __metal_nextafter(static_cast(x), static_cast(y))); \ + } \ + METAL_FUNC otype pow(itype x, itype y) { \ + return static_cast( \ + __metal_pow(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype powr(itype x, itype y) { \ + return static_cast( \ + __metal_powr(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype rint(itype x) { \ + return static_cast(__metal_rint(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype round(itype x) { \ + return static_cast(__metal_round(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype rsqrt(itype x) { \ + return static_cast(__metal_rsqrt(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype sin(itype x) { \ + return static_cast(__metal_sin(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype sinh(itype x) { \ + return static_cast(__metal_sinh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype sinpi(itype x) { \ + return static_cast(__metal_sinpi(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype sqrt(itype x) { \ + return static_cast(__metal_sqrt(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype tan(itype x) { \ + return static_cast(__metal_tan(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype tanh(itype x) { \ + return static_cast(__metal_tanh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype tanpi(itype x) { \ + return static_cast(__metal_tanpi(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype trunc(itype x) { \ + return static_cast(__metal_trunc(static_cast(x), mfast)); \ + } + +namespace metal { + +instantiate_metal_math_funcs( + bfloat16_t, + bfloat16_t, + float, + __METAL_MAYBE_FAST_MATH__); + +namespace fast { + +instantiate_metal_math_funcs( + bfloat16_t, + bfloat16_t, + float, + __METAL_FAST_MATH__); + +} // namespace fast + +namespace precise { + +instantiate_metal_math_funcs( + bfloat16_t, + bfloat16_t, + float, + __METAL_PRECISE_MATH__); + +} // namespace precise + +} // namespace metal + +/////////////////////////////////////////////////////////////////////////////// +// Metal simd for bfloat16 +/////////////////////////////////////////////////////////////////////////////// + +#define instantiate_metal_simd_comm_funcs( \ + itype, otype, ctype, itype_to_ctype, ctype_to_otype) \ + \ + METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \ + return ctype_to_otype( \ + __metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \ + } \ + \ + METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \ + return ctype_to_otype( \ + __metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_and_fill_down( \ + itype data, itype filling_data, ushort delta, ushort modulo) { \ + return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \ + itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_and_fill_down( \ + itype data, itype filling_data, ushort delta) { \ + return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \ + itype_to_ctype(data), \ + itype_to_ctype(filling_data), \ + delta, \ + __metal_get_simdgroup_size(ushort()))); \ + } \ + \ + METAL_FUNC otype simd_shuffle_and_fill_up( \ + itype data, itype filling_data, ushort delta, ushort modulo) { \ + return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \ + itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_and_fill_up( \ + itype data, itype filling_data, ushort delta) { \ + return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \ + itype_to_ctype(data), \ + itype_to_ctype(filling_data), \ + delta, \ + __metal_get_simdgroup_size(ushort()))); \ + } \ + \ + METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_down(itype_to_ctype(data), delta)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_up(itype_to_ctype(data), delta)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \ + } + +#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \ + \ + METAL_FUNC otype simd_max(itype data) { \ + return static_cast(__metal_simd_max(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_min(itype data) { \ + return static_cast(__metal_simd_min(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \ + return static_cast( \ + __metal_simd_prefix_exclusive_product(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \ + return static_cast( \ + __metal_simd_prefix_exclusive_sum(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \ + return static_cast( \ + __metal_simd_prefix_inclusive_product(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \ + return static_cast( \ + __metal_simd_prefix_inclusive_sum(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_product(itype data) { \ + return static_cast(__metal_simd_product(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_sum(itype data) { \ + return static_cast(__metal_simd_sum(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_xor(itype data) { \ + return static_cast(__metal_simd_xor(static_cast(data))); \ + } + +namespace metal { + +instantiate_metal_simd_comm_funcs( + bfloat16_t, + bfloat16_t, + uint16_t, + bfloat16_to_uint16, + uint16_to_bfloat16); +instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float); + +} // namespace metal diff --git a/Source/Cxxmlx/mlx-generated/metal/binary.h b/Source/Cxxmlx/mlx-generated/metal/binary.h new file mode 100644 index 00000000..f1df8853 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/binary.h @@ -0,0 +1,199 @@ +// Copyright © 2024 Apple Inc. + +template +[[kernel]] void binary_ss( + device const T* a, + device const T* b, + device U* c, + uint index [[thread_position_in_grid]]) { + c[index] = Op()(a[0], b[0]); +} + +template ::n> +[[kernel]] void binary_sv( + device const T* a, + device const T* b, + device U* c, + constant uint& size, + uint index [[thread_position_in_grid]]) { + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } + } +} + +template ::n> +[[kernel]] void binary_vs( + device const T* a, + device const T* b, + device U* c, + constant uint& size, + uint index [[thread_position_in_grid]]) { + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } + } +} + +template ::n> +[[kernel]] void binary_vv( + device const T* a, + device const T* b, + device U* c, + constant uint& size, + uint index [[thread_position_in_grid]]) { + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } + } +} + +template ::n> +[[kernel]] void binary_sv2( + device const T* a, + device const T* b, + device U* c, + constant int64_t& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } + } +} + +template ::n> +[[kernel]] void binary_vs2( + device const T* a, + device const T* b, + device U* c, + constant int64_t& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } + } +} + +template ::n> +[[kernel]] void binary_vv2( + device const T* a, + device const T* b, + device U* c, + constant int64_t& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } + } +} + +template +[[kernel]] void binary_g_nd1( + device const T* a, + device const T* b, + device U* c, + constant const int64_t& a_stride, + constant const int64_t& b_stride, + uint index [[thread_position_in_grid]]) { + auto a_idx = elem_to_loc_1(index, a_stride); + auto b_idx = elem_to_loc_1(index, b_stride); + c[index] = Op()(a[a_idx], b[b_idx]); +} + +template +[[kernel]] void binary_g_nd2( + device const T* a, + device const T* b, + device U* c, + constant const int64_t a_strides[2], + constant const int64_t b_strides[2], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; + c[out_idx] = Op()(a[a_idx], b[b_idx]); +} + +template +[[kernel]] void binary_g_nd3( + device const T* a, + device const T* b, + device U* c, + constant const int64_t a_strides[3], + constant const int64_t b_strides[3], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); + c[out_idx] = Op()(a[a_idx], b[b_idx]); +} + +template < + typename T, + typename U, + typename Op, + int N = 1, + typename IdxT = int64_t> +[[kernel]] void binary_g( + device const T* a, + device const T* b, + device U* c, + constant const int* shape, + constant const int64_t* a_strides, + constant const int64_t* b_strides, + constant const int& ndim, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = elem_to_loc_2_nd( + {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); + auto xshape = shape[ndim - 1]; + IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); + IdxT a_xstride = a_strides[ndim - 1]; + IdxT b_xstride = b_strides[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + c[out_idx++] = Op()(a[idx.x], b[idx.y]); + idx.x += a_xstride; + idx.y += b_xstride; + } +} diff --git a/Source/Cxxmlx/mlx-generated/metal/binary_ops.h b/Source/Cxxmlx/mlx-generated/metal/binary_ops.h new file mode 100644 index 00000000..4e3d881f --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/binary_ops.h @@ -0,0 +1,330 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +constant mlx::os_log logger("mlx", "binary_ops"); + +struct Add { + template + T operator()(T x, T y) { + return x + y; + } +}; + +struct FloorDivide { + template + T operator()(T x, T y) { + return x / y; + } + template <> + float operator()(float x, float y) { + return trunc(x / y); + } + template <> + half operator()(half x, half y) { + return trunc(x / y); + } + template <> + bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { + return trunc(x / y); + } +}; + +struct Divide { + template + T operator()(T x, T y) { + return x / y; + } +}; + +struct Remainder { + template + metal::enable_if_t & !metal::is_signed_v, T> + operator()(T x, T y) { + return x % y; + } + template + metal::enable_if_t & metal::is_signed_v, T> + operator()(T x, T y) { + auto r = x % y; + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } + template + metal::enable_if_t, T> operator()(T x, T y) { + T r = fmod(x, y); + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + return x % y; + } +}; + +struct Equal { + template + bool operator()(T x, T y) { + return x == y; + } +}; + +struct NaNEqual { + template + bool operator()(T x, T y) { + return x == y || (metal::isnan(x) && metal::isnan(y)); + } + template <> + bool operator()(complex64_t x, complex64_t y) { + return x == y || + (metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) && + metal::isnan(y.imag)) || + (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) || + (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag); + } +}; + +struct Greater { + template + bool operator()(T x, T y) { + return x > y; + } +}; + +struct GreaterEqual { + template + bool operator()(T x, T y) { + return x >= y; + } +}; + +struct Less { + template + bool operator()(T x, T y) { + return x < y; + } +}; + +struct LessEqual { + template + bool operator()(T x, T y) { + return x <= y; + } +}; + +struct LogAddExp { + template + T operator()(T x, T y) { + if (metal::isnan(x) || metal::isnan(y)) { + return metal::numeric_limits::quiet_NaN(); + } + constexpr T inf = metal::numeric_limits::infinity(); + T maxval = metal::max(x, y); + T minval = metal::min(x, y); + return (minval == -inf || maxval == inf) + ? maxval + : (maxval + log1p(metal::exp(minval - maxval))); + }; + + complex64_t operator()(complex64_t x, complex64_t y) { + if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) || + metal::isnan(y.imag)) { + return metal::numeric_limits::quiet_NaN(); + } + constexpr float inf = metal::numeric_limits::infinity(); + complex64_t maxval = x > y ? x : y; + complex64_t minval = x < y ? x : y; + if (minval.real == -inf || maxval.real == inf) + return maxval; + float m = metal::exp(minval.real - maxval.real); + complex64_t dexp{ + m * metal::cos(minval.imag - maxval.imag), + m * metal::sin(minval.imag - maxval.imag), + }; + return maxval + log1p(dexp); + } +}; + +struct Maximum { + template + metal::enable_if_t, T> operator()(T x, T y) { + return metal::max(x, y); + } + + template + metal::enable_if_t, T> operator()(T x, T y) { + if (metal::isnan(x)) { + return x; + } + return x > y ? x : y; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + if (metal::isnan(x.real) || metal::isnan(x.imag)) { + return x; + } + return x > y ? x : y; + } +}; + +struct Minimum { + template + metal::enable_if_t, T> operator()(T x, T y) { + return metal::min(x, y); + } + + template + metal::enable_if_t, T> operator()(T x, T y) { + if (metal::isnan(x)) { + return x; + } + return x < y ? x : y; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + if (metal::isnan(x.real) || metal::isnan(x.imag)) { + return x; + } + return x < y ? x : y; + } +}; + +struct Multiply { + template + T operator()(T x, T y) { + return x * y; + } +}; + +struct NotEqual { + template + bool operator()(T x, T y) { + return x != y; + } + template <> + bool operator()(complex64_t x, complex64_t y) { + return x.real != y.real || x.imag != y.imag; + } +}; + +struct Power { + template + metal::enable_if_t, T> operator()(T base, T exp) { + return metal::pow(base, exp); + } + + template + metal::enable_if_t, T> operator()(T base, T exp) { + T res = 1; + // Undefined to raise integer to negative power + if (exp < 0) { + logger.log_debug( + "int pow exp<0 (base=%ld exp=%ld)", (long)base, (long)exp); + return 0; + } + + while (exp) { + if (exp & 1) { + res *= base; + } + exp >>= 1; + base *= base; + } + return res; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + if (x.real == 0 && x.imag == 0) { + if (metal::isnan(y.real) || metal::isnan(y.imag)) { + auto nan = metal::numeric_limits::quiet_NaN(); + return {nan, nan}; + } + return {0.0, 0.0}; + } + auto x_theta = metal::atan2(x.imag, x.real); + auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); + auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); + auto phase = y.imag * x_ln_r + y.real * x_theta; + return {mag * metal::cos(phase), mag * metal::sin(phase)}; + } +}; + +struct Subtract { + template + T operator()(T x, T y) { + return x - y; + } +}; + +struct LogicalAnd { + template + T operator()(T x, T y) { + return x && y; + }; +}; + +struct LogicalOr { + template + T operator()(T x, T y) { + return x || y; + }; +}; + +struct BitwiseAnd { + template + T operator()(T x, T y) { + return x & y; + }; +}; + +struct BitwiseOr { + template + T operator()(T x, T y) { + return x | y; + }; +}; + +struct BitwiseXor { + template + T operator()(T x, T y) { + return x ^ y; + }; +}; + +struct LeftShift { + template + T operator()(T x, T y) { + return x << y; + }; +}; + +struct RightShift { + template + T operator()(T x, T y) { + return x >> y; + }; +}; + +struct ArcTan2 { + template + T operator()(T y, T x) { + return metal::precise::atan2(y, x); + } +}; + +struct DivMod { + template + metal::array operator()(T x, T y) { + return {FloorDivide{}(x, y), Remainder{}(x, y)}; + }; +}; diff --git a/Source/Cxxmlx/mlx-generated/metal/binary_two.h b/Source/Cxxmlx/mlx-generated/metal/binary_two.h new file mode 100644 index 00000000..4455e4ca --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/binary_two.h @@ -0,0 +1,244 @@ +// Copyright © 2024 Apple Inc. + +template +[[kernel]] void binary_ss( + device const T* a, + device const T* b, + device U* c, + device U* d, + uint index [[thread_position_in_grid]]) { + auto out = Op()(a[0], b[0]); + c[index] = out[0]; + d[index] = out[1]; +} + +template ::n> +[[kernel]] void binary_sv( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant uint& size, + uint index [[thread_position_in_grid]]) { + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } +} + +template ::n> +[[kernel]] void binary_vs( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant uint& size, + uint index [[thread_position_in_grid]]) { + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } +} + +template ::n> +[[kernel]] void binary_vv( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant uint& size, + uint index [[thread_position_in_grid]]) { + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } +} + +template ::n> +[[kernel]] void binary_sv2( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant int64_t& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } +} + +template ::n> +[[kernel]] void binary_vs2( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant int64_t& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } +} + +template ::n> +[[kernel]] void binary_vv2( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant int64_t& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } +} + +template +[[kernel]] void binary_g_nd1( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant const int64_t& a_stride, + constant const int64_t& b_stride, + uint index [[thread_position_in_grid]]) { + auto a_idx = elem_to_loc_1(index, a_stride); + auto b_idx = elem_to_loc_1(index, b_stride); + auto out = Op()(a[a_idx], b[b_idx]); + c[index] = out[0]; + d[index] = out[1]; +} + +template +[[kernel]] void binary_g_nd2( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant const int64_t a_strides[2], + constant const int64_t b_strides[2], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; + auto out = Op()(a[a_idx], b[b_idx]); + c[out_idx] = out[0]; + d[out_idx] = out[1]; +} + +template +[[kernel]] void binary_g_nd3( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant const int64_t a_strides[3], + constant const int64_t b_strides[3], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); + auto out = Op()(a[a_idx], b[b_idx]); + c[out_idx] = out[0]; + d[out_idx] = out[1]; +} + +template < + typename T, + typename U, + typename Op, + int N = 1, + typename IdxT = int64_t> +[[kernel]] void binary_g( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant const int* shape, + constant const int64_t* a_strides, + constant const int64_t* b_strides, + constant const int& ndim, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = elem_to_loc_2_nd( + {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); + auto xshape = shape[ndim - 1]; + IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); + IdxT a_xstride = a_strides[ndim - 1]; + IdxT b_xstride = b_strides[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + auto out = Op()(a[idx.x], b[idx.y]); + c[out_idx] = out[0]; + d[out_idx++] = out[1]; + idx.x += a_xstride; + idx.y += b_xstride; + } +} diff --git a/Source/Cxxmlx/mlx-generated/metal/cexpf.h b/Source/Cxxmlx/mlx-generated/metal/cexpf.h new file mode 100644 index 00000000..b45fe6a2 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/cexpf.h @@ -0,0 +1,134 @@ +// Copyright © 2025 Apple Inc. +// Copyright © 2008-2013 NVIDIA Corporation +// Copyright © 2013 Filipe RNC Maia +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Forked from +// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h + +// TODO: We should use thrust::exp but the thrust header in old CUDA versions +// can not be used in JIT. + +#pragma once + +#include + +using ieee_float_shape_type = union { + float value; + uint32_t word; +}; + +inline void get_float_word(thread uint32_t& i, float d) { + ieee_float_shape_type gf_u; + gf_u.value = (d); + (i) = gf_u.word; +} + +inline void get_float_word(thread int32_t& i, float d) { + ieee_float_shape_type gf_u; + gf_u.value = (d); + (i) = gf_u.word; +} + +inline void set_float_word(thread float& d, uint32_t i) { + ieee_float_shape_type sf_u; + sf_u.word = (i); + (d) = sf_u.value; +} + +inline float frexp_expf(float x, thread int* expt) { + const uint32_t k = 235; + const float kln2 = 162.88958740F; + + float exp_x; + uint32_t hx; + + exp_x = metal::exp(x - kln2); + get_float_word(hx, exp_x); + *expt = (hx >> 23) - (0x7f + 127) + k; + set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23)); + return exp_x; +} + +inline complex64_t ldexp_cexpf(complex64_t z, int expt) { + float x, y, exp_x, scale1, scale2; + int ex_expt, half_expt; + + x = z.real; + y = z.imag; + exp_x = frexp_expf(x, &ex_expt); + expt += ex_expt; + + half_expt = expt / 2; + set_float_word(scale1, (0x7f + half_expt) << 23); + half_expt = expt - half_expt; + set_float_word(scale2, (0x7f + half_expt) << 23); + + return complex64_t{ + metal::cos(y) * exp_x * scale1 * scale2, + metal::sin(y) * exp_x * scale1 * scale2}; +} + +inline complex64_t cexpf(const thread complex64_t& z) { + float x, y, exp_x; + uint32_t hx, hy; + + const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074; + + x = z.real; + y = z.imag; + + get_float_word(hy, y); + hy &= 0x7fffffff; + + /* cexp(x + I 0) = exp(x) + I 0 */ + if (hy == 0) { + return complex64_t{metal::exp(x), y}; + } + get_float_word(hx, x); + /* cexp(0 + I y) = cos(y) + I sin(y) */ + if ((hx & 0x7fffffff) == 0) { + return complex64_t{metal::cos(y), metal::sin(y)}; + } + if (hy >= 0x7f800000) { + if ((hx & 0x7fffffff) != 0x7f800000) { + /* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */ + return complex64_t{y - y, y - y}; + } else if (hx & 0x80000000) { + /* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */ + return complex64_t{0.0, 0.0}; + } else { + /* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */ + return complex64_t{x, y - y}; + } + } + + if (hx >= exp_ovfl && hx <= cexp_ovfl) { + /* + * x is between 88.7 and 192, so we must scale to avoid + * overflow in expf(x). + */ + return ldexp_cexpf(z, 0); + } else { + /* + * Cases covered here: + * - x < exp_ovfl and exp(x) won't overflow (common case) + * - x > cexp_ovfl, so exp(x) * s overflows for all s > 0 + * - x = +-Inf (generated by exp()) + * - x = NaN (spurious inexact exception from y) + */ + exp_x = metal::exp(x); + return complex64_t{exp_x * metal::cos(y), exp_x * metal::sin(y)}; + } +} diff --git a/Source/Cxxmlx/mlx-generated/metal/complex.h b/Source/Cxxmlx/mlx-generated/metal/complex.h new file mode 100644 index 00000000..6e391483 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/complex.h @@ -0,0 +1,173 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +using namespace metal; + +struct complex64_t; + +template +static constexpr constant bool can_convert_to_complex64 = + !is_same_v && is_convertible_v; + +template +static constexpr constant bool can_convert_from_complex64 = + !is_same_v && + (is_convertible_v || is_convertible_v); + +struct complex64_t { + float real; + float imag; + + // Constructors + constexpr complex64_t(float real, float imag) : real(real), imag(imag) {}; + constexpr complex64_t() : real(0), imag(0) {}; + constexpr complex64_t() threadgroup : real(0), imag(0) {}; + + // Conversions to complex64_t + template < + typename T, + typename = typename enable_if>::type> + constexpr complex64_t(T x) thread : real(x), imag(0) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr complex64_t(T x) threadgroup : real(x), imag(0) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr complex64_t(T x) device : real(x), imag(0) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr complex64_t(T x) constant : real(x), imag(0) {} + + // Conversions from complex64_t + template < + typename T, + typename = typename enable_if>::type> + constexpr operator T() const thread { + return static_cast(real); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr operator T() const threadgroup { + return static_cast(real); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr operator T() const device { + return static_cast(real); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr operator T() const constant { + return static_cast(real); + } +}; + +constexpr complex64_t operator-(complex64_t x) { + return {-x.real, -x.imag}; +} + +constexpr bool operator>=(complex64_t a, complex64_t b) { + return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag); +} + +constexpr bool operator>(complex64_t a, complex64_t b) { + return (a.real > b.real) || (a.real == b.real && a.imag > b.imag); +} + +constexpr bool operator<=(complex64_t a, complex64_t b) { + return operator>=(b, a); +} + +constexpr bool operator<(complex64_t a, complex64_t b) { + return operator>(b, a); +} + +constexpr bool operator==(complex64_t a, complex64_t b) { + return a.real == b.real && a.imag == b.imag; +} + +constexpr complex64_t operator+(complex64_t a, complex64_t b) { + return {a.real + b.real, a.imag + b.imag}; +} + +constexpr thread complex64_t& operator+=(thread complex64_t& a, complex64_t b) { + a.real += b.real; + a.imag += b.imag; + return a; +} + +constexpr threadgroup complex64_t& operator+=( + threadgroup complex64_t& a, + complex64_t b) { + a.real += b.real; + a.imag += b.imag; + return a; +} + +constexpr device complex64_t& operator+=(device complex64_t& a, complex64_t b) { + a.real += b.real; + a.imag += b.imag; + return a; +} + +constexpr complex64_t operator+(float a, complex64_t b) { + return {a + b.real, b.imag}; +} +constexpr complex64_t operator+(complex64_t a, float b) { + return {a.real + b, a.imag}; +} + +constexpr complex64_t operator-(complex64_t a, complex64_t b) { + return {a.real - b.real, a.imag - b.imag}; +} +constexpr complex64_t operator-(float a, complex64_t b) { + return {a - b.real, -b.imag}; +} +constexpr complex64_t operator-(complex64_t a, float b) { + return {a.real - b, a.imag}; +} + +constexpr complex64_t operator*(complex64_t a, complex64_t b) { + return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real}; +} + +constexpr complex64_t operator/(complex64_t a, complex64_t b) { + auto denom = b.real * b.real + b.imag * b.imag; + auto x = a.real * b.real + a.imag * b.imag; + auto y = a.imag * b.real - a.real * b.imag; + return {x / denom, y / denom}; +} + +constexpr complex64_t operator/(float a, complex64_t b) { + auto denom = b.real * b.real + b.imag * b.imag; + auto x = a * b.real; + auto y = -a * b.imag; + return {x / denom, y / denom}; +} + +constexpr complex64_t operator%(complex64_t a, complex64_t b) { + auto real = a.real - (b.real * static_cast(a.real / b.real)); + auto imag = a.imag - (b.imag * static_cast(a.imag / b.imag)); + if (real != 0 && (real < 0 != b.real < 0)) { + real += b.real; + } + if (imag != 0 && (imag < 0 != b.imag < 0)) { + imag += b.imag; + } + return {real, imag}; +} diff --git a/Source/Cmlx/mlx-generated/metal/conv.metal b/Source/Cxxmlx/mlx-generated/metal/conv.metal similarity index 100% rename from Source/Cmlx/mlx-generated/metal/conv.metal rename to Source/Cxxmlx/mlx-generated/metal/conv.metal diff --git a/Source/Cxxmlx/mlx-generated/metal/copy.h b/Source/Cxxmlx/mlx-generated/metal/copy.h new file mode 100644 index 00000000..cf22347e --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/copy.h @@ -0,0 +1,276 @@ +// Copyright © 2024 Apple Inc. + +template ::n> +[[kernel]] void copy_s( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant uint& size, + uint index [[thread_position_in_grid]]) { + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + dst[index + i] = static_cast(src[0]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[index + i] = static_cast(src[0]); + } + } +} + +template ::n> +[[kernel]] void copy_v( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant uint& size, + uint index [[thread_position_in_grid]]) { + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + dst[index + i] = static_cast(src[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[index + i] = static_cast(src[index + i]); + } + } +} + +template ::n> +[[kernel]] void copy_s2( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant int64_t& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + dst[offset + i] = static_cast(src[0]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[offset + i] = static_cast(src[0]); + } + } +} + +template ::n> +[[kernel]] void copy_v2( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant int64_t& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } + } +} + +template +[[kernel]] void copy_g_nd1( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t& src_stride [[buffer(3)]], + uint index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_1(index, src_stride); + dst[index] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_g_nd2( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto src_idx = elem_to_loc_2(index, src_strides); + IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y; + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_g_nd3( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto src_idx = elem_to_loc_3(index, src_strides); + IdxT dst_idx = + index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_g( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int& ndim [[buffer(5)]], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto src_idx = elem_to_loc( + {N * index.x, index.y, index.z}, src_shape, src_strides, ndim); + if (N == 1) { + IdxT dst_idx = + index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); + dst[dst_idx] = static_cast(src[src_idx]); + return; + } + auto xshape = src_shape[ndim - 1]; + IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); + auto src_xstride = src_strides[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + dst[dst_idx + i] = static_cast(src[src_idx]); + src_idx += src_xstride; + } +} + +template +[[kernel]] void copy_gg_nd1( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t& src_stride [[buffer(3)]], + constant const int64_t& dst_stride [[buffer(4)]], + uint index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_1(index, src_stride); + auto dst_idx = elem_to_loc_1(index, dst_stride); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_gg_nd2( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + uint2 index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_2(index, src_strides); + auto dst_idx = elem_to_loc_2(index, dst_strides); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_gg_nd3( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + uint3 index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_3(index, src_strides); + auto dst_idx = elem_to_loc_3(index, dst_strides); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_gg( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + constant const int& ndim [[buffer(5)]], + uint3 index [[thread_position_in_grid]]) { + auto idx = elem_to_loc_2_nd( + {N * index.x, index.y, index.z}, + src_shape, + src_strides, + dst_strides, + ndim); + if (N == 1) { + dst[idx.y] = static_cast(src[idx.x]); + return; + } + IdxT src_xstride = src_strides[ndim - 1]; + IdxT dst_xstride = dst_strides[ndim - 1]; + auto xshape = src_shape[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + dst[idx.y] = static_cast(src[idx.x]); + idx.x += src_xstride; + idx.y += dst_xstride; + } +} + +template +[[kernel]] void copy_gg_dynamic_nd1( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t& src_stride [[buffer(3)]], + constant const int64_t& dst_stride [[buffer(4)]], + constant const int64_t& src_offset [[buffer(6)]], + constant const int64_t& dst_offset [[buffer(7)]], + uint index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_1(index, src_stride); + auto dst_idx = elem_to_loc_1(index, dst_stride); + dst[dst_idx + dst_offset] = src[src_idx + src_offset]; +} + +template +[[kernel]] void copy_gg_dynamic_nd2( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + constant const int64_t& src_offset [[buffer(6)]], + constant const int64_t& dst_offset [[buffer(7)]], + uint2 index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_2(index, src_strides); + auto dst_idx = elem_to_loc_2(index, dst_strides); + dst[dst_idx + dst_offset] = src[src_idx + src_offset]; +} + +template +[[kernel]] void copy_gg_dynamic_nd3( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + constant const int64_t& src_offset [[buffer(6)]], + constant const int64_t& dst_offset [[buffer(7)]], + uint3 index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_3(index, src_strides); + auto dst_idx = elem_to_loc_3(index, dst_strides); + dst[dst_idx + dst_offset] = src[src_idx + src_offset]; +} + +template +[[kernel]] void copy_gg_dynamic( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + constant const int& ndim [[buffer(5)]], + constant const int64_t& src_offset [[buffer(6)]], + constant const int64_t& dst_offset [[buffer(7)]], + uint3 index [[thread_position_in_grid]]) { + src += src_offset; + dst += dst_offset; + auto idx = elem_to_loc_2_nd( + {N * index.x, index.y, index.z}, + src_shape, + src_strides, + dst_strides, + ndim); + if (N == 1) { + dst[idx.y] = src[idx.x]; + return; + } + IdxT src_xstride = src_strides[ndim - 1]; + IdxT dst_xstride = dst_strides[ndim - 1]; + auto xshape = src_shape[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + dst[idx.y] = src[idx.x]; + idx.x += src_xstride; + idx.y += dst_xstride; + } +} diff --git a/Source/Cxxmlx/mlx-generated/metal/defines.h b/Source/Cxxmlx/mlx-generated/metal/defines.h new file mode 100644 index 00000000..c369adb7 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/defines.h @@ -0,0 +1,24 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#if defined __METAL__ || defined MLX_METAL_JIT +#define MTL_CONST constant +#else +#define MTL_CONST +#endif + +static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4; +static MTL_CONST constexpr int REDUCE_N_READS = 4; +static MTL_CONST constexpr int REDUCE_N_WRITES = 4; +static MTL_CONST constexpr int SOFTMAX_N_READS = 4; +static MTL_CONST constexpr int RMS_N_READS = 4; +static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096; + +// Instantiate a templated kernel. +// Extra args are used as template parameters: +// e.g. instantiate_kernel(binary_int, binary, a, b) -> +// [[host_name(binary_int)]] [kernel] binary +#define instantiate_kernel(name, func, ...) \ + template [[host_name( \ + name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>; diff --git a/Source/Cmlx/mlx-generated/metal/erf.h b/Source/Cxxmlx/mlx-generated/metal/erf.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/erf.h rename to Source/Cxxmlx/mlx-generated/metal/erf.h diff --git a/Source/Cxxmlx/mlx-generated/metal/expm1f.h b/Source/Cxxmlx/mlx-generated/metal/expm1f.h new file mode 100644 index 00000000..68224e17 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/expm1f.h @@ -0,0 +1,90 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +// Original license copied below: +// Copyright (c) 2015-2023 Norbert Juffa +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* Compute exponential base e minus 1. Maximum ulp error = 0.997458 + + i = rint(a/log(2)), f = a-i*log(2). Then expm1(a) = 2**i * (expm1(f)+1) - 1. + Compute r = expm1(f). Then expm1(a)= 2 * (0.5 * 2**i * r + 0.5 * 2**i - 0.5). + With t = 0.5*2**i, expm1(a) = 2*(r * t + t-0.5). However, for best accuracy, + when i == 1, expm1(a)= 2*(r + 0.5), and when i == 0, expm1(a) = r. + + NOTE: Scale factor b is only applied if i < 0 or i > 1 (should be power of 2) +*/ +float expm1f_scaled_unchecked(float a, float b) { + float f, j, r, s, t, u, v, x, y; + int i; + + // exp(a) = 2**i * exp(f); i = rintf (a / log(2)) + j = fma(1.442695f, a, 12582912.f); // 0x1.715476p0, 0x1.8p23 + j = j - 12582912.0f; // 0x1.8p23 + i = (int)j; + f = fma(j, -6.93145752e-1f, a); + + // approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2] + s = f * f; + if (a == 0.0f) + s = a; // ensure -0 is passed through + // err = 0.997458 ulp1 = 11081805 + r = 1.97350979e-4f; // 0x1.9de000p-13 + r = fma(r, f, 1.39309070e-3f); // 0x1.6d30bcp-10 + r = fma(r, f, 8.33343994e-3f); // 0x1.1111f6p-7 + r = fma(r, f, 4.16668020e-2f); // 0x1.55559ep-5 + r = fma(r, f, 1.66666716e-1f); // 0x1.55555cp-3 + r = fma(r, f, 4.99999970e-1f); // 0x1.fffffep-2 + u = (j == 1) ? (f + 0.5f) : f; + v = fma(r, s, u); + s = 0.5f * b; + t = ldexp(s, i); + y = t - s; + x = (t - y) - s; // double-float canonicalization of difference + r = fma(v, t, x) + y; + r = r + r; + if (j == 0) + r = v; + if (j == 1) + r = v + v; + return r; +} + +/* Compute exponential base e minus 1. max ulp err = 0.99746 */ +float expm1f(float a) { + float r; + + r = expm1f_scaled_unchecked(a, 1.0f); + /* handle severe overflow and underflow */ + if (abs(a - 1.0f) > 88.0f) { + r = pow(2, a); + r = fma(r, r, -1.0f); + } + return r; +} diff --git a/Source/Cmlx/mlx-generated/metal/fft.h b/Source/Cxxmlx/mlx-generated/metal/fft.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/fft.h rename to Source/Cxxmlx/mlx-generated/metal/fft.h diff --git a/Source/Cxxmlx/mlx-generated/metal/fft/radix.h b/Source/Cxxmlx/mlx-generated/metal/fft/radix.h new file mode 100644 index 00000000..bd61eef6 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/fft/radix.h @@ -0,0 +1,328 @@ +// Copyright © 2024 Apple Inc. + +/* Radix kernels + +We provide optimized, single threaded Radix codelets +for n=2,3,4,5,6,7,8,10,11,12,13. + +For n=2,3,4,5,6 we hand write the codelets. +For n=8,10,12 we combine smaller codelets. +For n=7,11,13 we use Rader's algorithm which decomposes +them into (n-1)=6,10,12 codelets. */ + +#pragma once + +#include +#include +#include + +METAL_FUNC float2 complex_mul(float2 a, float2 b) { + return float2(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); +} + +// Complex mul followed by conjugate +METAL_FUNC float2 complex_mul_conj(float2 a, float2 b) { + return float2(a.x * b.x - a.y * b.y, -a.x * b.y - a.y * b.x); +} + +// Compute an FFT twiddle factor +METAL_FUNC float2 get_twiddle(int k, int p) { + float theta = -2.0f * k * M_PI_F / p; + + float2 twiddle = {metal::fast::cos(theta), metal::fast::sin(theta)}; + return twiddle; +} + +METAL_FUNC void radix2(thread float2* x, thread float2* y) { + y[0] = x[0] + x[1]; + y[1] = x[0] - x[1]; +} + +METAL_FUNC void radix3(thread float2* x, thread float2* y) { + float pi_2_3 = -0.8660254037844387; + + float2 a_1 = x[1] + x[2]; + float2 a_2 = x[1] - x[2]; + + y[0] = x[0] + a_1; + float2 b_1 = x[0] - 0.5 * a_1; + float2 b_2 = pi_2_3 * a_2; + + float2 b_2_j = {-b_2.y, b_2.x}; + y[1] = b_1 + b_2_j; + y[2] = b_1 - b_2_j; +} + +METAL_FUNC void radix4(thread float2* x, thread float2* y) { + float2 z_0 = x[0] + x[2]; + float2 z_1 = x[0] - x[2]; + float2 z_2 = x[1] + x[3]; + float2 z_3 = x[1] - x[3]; + float2 z_3_i = {z_3.y, -z_3.x}; + + y[0] = z_0 + z_2; + y[1] = z_1 + z_3_i; + y[2] = z_0 - z_2; + y[3] = z_1 - z_3_i; +} + +METAL_FUNC void radix5(thread float2* x, thread float2* y) { + float2 root_5_4 = 0.5590169943749475; + float2 sin_2pi_5 = 0.9510565162951535; + float2 sin_1pi_5 = 0.5877852522924731; + + float2 a_1 = x[1] + x[4]; + float2 a_2 = x[2] + x[3]; + float2 a_3 = x[1] - x[4]; + float2 a_4 = x[2] - x[3]; + + float2 a_5 = a_1 + a_2; + float2 a_6 = root_5_4 * (a_1 - a_2); + float2 a_7 = x[0] - a_5 / 4; + float2 a_8 = a_7 + a_6; + float2 a_9 = a_7 - a_6; + float2 a_10 = sin_2pi_5 * a_3 + sin_1pi_5 * a_4; + float2 a_11 = sin_1pi_5 * a_3 - sin_2pi_5 * a_4; + float2 a_10_j = {a_10.y, -a_10.x}; + float2 a_11_j = {a_11.y, -a_11.x}; + + y[0] = x[0] + a_5; + y[1] = a_8 + a_10_j; + y[2] = a_9 + a_11_j; + y[3] = a_9 - a_11_j; + y[4] = a_8 - a_10_j; +} + +METAL_FUNC void radix6(thread float2* x, thread float2* y) { + float sin_pi_3 = 0.8660254037844387; + float2 a_1 = x[2] + x[4]; + float2 a_2 = x[0] - a_1 / 2; + float2 a_3 = sin_pi_3 * (x[2] - x[4]); + float2 a_4 = x[5] + x[1]; + float2 a_5 = x[3] - a_4 / 2; + float2 a_6 = sin_pi_3 * (x[5] - x[1]); + float2 a_7 = x[0] + a_1; + + float2 a_3_i = {a_3.y, -a_3.x}; + float2 a_6_i = {a_6.y, -a_6.x}; + float2 a_8 = a_2 + a_3_i; + float2 a_9 = a_2 - a_3_i; + float2 a_10 = x[3] + a_4; + float2 a_11 = a_5 + a_6_i; + float2 a_12 = a_5 - a_6_i; + + y[0] = a_7 + a_10; + y[1] = a_8 - a_11; + y[2] = a_9 + a_12; + y[3] = a_7 - a_10; + y[4] = a_8 + a_11; + y[5] = a_9 - a_12; +} + +METAL_FUNC void radix7(thread float2* x, thread float2* y) { + // Rader's algorithm + float2 inv = {1 / 6.0, -1 / 6.0}; + + // fft + float2 in1[6] = {x[1], x[3], x[2], x[6], x[4], x[5]}; + radix6(in1, y + 1); + + y[0] = y[1] + x[0]; + + // b_q + y[1] = complex_mul_conj(y[1], float2(-1, 0)); + y[2] = complex_mul_conj(y[2], float2(2.44013336, -1.02261879)); + y[3] = complex_mul_conj(y[3], float2(2.37046941, -1.17510629)); + y[4] = complex_mul_conj(y[4], float2(0, -2.64575131)); + y[5] = complex_mul_conj(y[5], float2(2.37046941, 1.17510629)); + y[6] = complex_mul_conj(y[6], float2(-2.44013336, -1.02261879)); + + // ifft + radix6(y + 1, x + 1); + + y[1] = x[1] * inv + x[0]; + y[5] = x[2] * inv + x[0]; + y[4] = x[3] * inv + x[0]; + y[6] = x[4] * inv + x[0]; + y[2] = x[5] * inv + x[0]; + y[3] = x[6] * inv + x[0]; +} + +METAL_FUNC void radix8(thread float2* x, thread float2* y) { + float cos_pi_4 = 0.7071067811865476; + float2 w_0 = {cos_pi_4, -cos_pi_4}; + float2 w_1 = {-cos_pi_4, -cos_pi_4}; + float2 temp[8] = {x[0], x[2], x[4], x[6], x[1], x[3], x[5], x[7]}; + radix4(temp, x); + radix4(temp + 4, x + 4); + + y[0] = x[0] + x[4]; + y[4] = x[0] - x[4]; + float2 x_5 = complex_mul(x[5], w_0); + y[1] = x[1] + x_5; + y[5] = x[1] - x_5; + float2 x_6 = {x[6].y, -x[6].x}; + y[2] = x[2] + x_6; + y[6] = x[2] - x_6; + float2 x_7 = complex_mul(x[7], w_1); + y[3] = x[3] + x_7; + y[7] = x[3] - x_7; +} + +template +METAL_FUNC void radix10(thread float2* x, thread float2* y) { + float2 w[4]; + w[0] = {0.8090169943749475, -0.5877852522924731}; + w[1] = {0.30901699437494745, -0.9510565162951535}; + w[2] = {-w[1].x, w[1].y}; + w[3] = {-w[0].x, w[0].y}; + + if (raders_perm) { + float2 temp[10] = { + x[0], x[3], x[4], x[8], x[2], x[1], x[7], x[9], x[6], x[5]}; + radix5(temp, x); + radix5(temp + 5, x + 5); + } else { + float2 temp[10] = { + x[0], x[2], x[4], x[6], x[8], x[1], x[3], x[5], x[7], x[9]}; + radix5(temp, x); + radix5(temp + 5, x + 5); + } + + y[0] = x[0] + x[5]; + y[5] = x[0] - x[5]; + for (int t = 1; t < 5; t++) { + float2 a = complex_mul(x[t + 5], w[t - 1]); + y[t] = x[t] + a; + y[t + 5] = x[t] - a; + } +} + +METAL_FUNC void radix11(thread float2* x, thread float2* y) { + // Raders Algorithm + float2 inv = {1 / 10.0, -1 / 10.0}; + + // fft + radix10(x + 1, y + 1); + + y[0] = y[1] + x[0]; + + // b_q + y[1] = complex_mul_conj(y[1], float2(-1, 0)); + y[2] = complex_mul_conj(y[2], float2(0.955301878, -3.17606649)); + y[3] = complex_mul_conj(y[3], float2(2.63610556, 2.01269656)); + y[4] = complex_mul_conj(y[4], float2(2.54127802, 2.13117479)); + y[5] = complex_mul_conj(y[5], float2(2.07016210, 2.59122150)); + y[6] = complex_mul_conj(y[6], float2(0, -3.31662479)); + y[7] = complex_mul_conj(y[7], float2(2.07016210, -2.59122150)); + y[8] = complex_mul_conj(y[8], float2(-2.54127802, 2.13117479)); + y[9] = complex_mul_conj(y[9], float2(2.63610556, -2.01269656)); + y[10] = complex_mul_conj(y[10], float2(-0.955301878, -3.17606649)); + + // ifft + radix10(y + 1, x + 1); + + y[1] = x[1] * inv + x[0]; + y[6] = x[2] * inv + x[0]; + y[3] = x[3] * inv + x[0]; + y[7] = x[4] * inv + x[0]; + y[9] = x[5] * inv + x[0]; + y[10] = x[6] * inv + x[0]; + y[5] = x[7] * inv + x[0]; + y[8] = x[8] * inv + x[0]; + y[4] = x[9] * inv + x[0]; + y[2] = x[10] * inv + x[0]; +} + +template +METAL_FUNC void radix12(thread float2* x, thread float2* y) { + float2 w[6]; + float sin_pi_3 = 0.8660254037844387; + w[0] = {sin_pi_3, -0.5}; + w[1] = {0.5, -sin_pi_3}; + w[2] = {0, -1}; + w[3] = {-0.5, -sin_pi_3}; + w[4] = {-sin_pi_3, -0.5}; + + if (raders_perm) { + float2 temp[12] = { + x[0], + x[3], + x[2], + x[11], + x[8], + x[9], + x[1], + x[7], + x[5], + x[10], + x[4], + x[6]}; + radix6(temp, x); + radix6(temp + 6, x + 6); + } else { + float2 temp[12] = { + x[0], + x[2], + x[4], + x[6], + x[8], + x[10], + x[1], + x[3], + x[5], + x[7], + x[9], + x[11]}; + radix6(temp, x); + radix6(temp + 6, x + 6); + } + + y[0] = x[0] + x[6]; + y[6] = x[0] - x[6]; + for (int t = 1; t < 6; t++) { + float2 a = complex_mul(x[t + 6], w[t - 1]); + y[t] = x[t] + a; + y[t + 6] = x[t] - a; + } +} + +METAL_FUNC void radix13(thread float2* x, thread float2* y) { + // Raders Algorithm + float2 inv = {1 / 12.0, -1 / 12.0}; + + // fft + radix12(x + 1, y + 1); + + y[0] = y[1] + x[0]; + + // b_q + y[1] = complex_mul_conj(y[1], float2(-1, 0)); + y[2] = complex_mul_conj(y[2], float2(3.07497206, -1.88269669)); + y[3] = complex_mul_conj(y[3], float2(3.09912468, 1.84266823)); + y[4] = complex_mul_conj(y[4], float2(3.45084438, -1.04483161)); + y[5] = complex_mul_conj(y[5], float2(0.91083583, 3.48860690)); + y[6] = complex_mul_conj(y[6], float2(-3.60286363, 0.139189267)); + y[7] = complex_mul_conj(y[7], float2(3.60555128, 0)); + y[8] = complex_mul_conj(y[8], float2(3.60286363, 0.139189267)); + y[9] = complex_mul_conj(y[9], float2(0.91083583, -3.48860690)); + y[10] = complex_mul_conj(y[10], float2(-3.45084438, -1.04483161)); + y[11] = complex_mul_conj(y[11], float2(3.09912468, -1.84266823)); + y[12] = complex_mul_conj(y[12], float2(-3.07497206, -1.88269669)); + + // ifft + radix12(y + 1, x + 1); + + y[1] = x[1] * inv + x[0]; + y[7] = x[2] * inv + x[0]; + y[10] = x[3] * inv + x[0]; + y[5] = x[4] * inv + x[0]; + y[9] = x[5] * inv + x[0]; + y[11] = x[6] * inv + x[0]; + y[12] = x[7] * inv + x[0]; + y[6] = x[8] * inv + x[0]; + y[3] = x[9] * inv + x[0]; + y[8] = x[10] * inv + x[0]; + y[4] = x[11] * inv + x[0]; + y[2] = x[12] * inv + x[0]; +} \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/fft/readwrite.h b/Source/Cxxmlx/mlx-generated/metal/fft/readwrite.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/fft/readwrite.h rename to Source/Cxxmlx/mlx-generated/metal/fft/readwrite.h diff --git a/Source/Cxxmlx/mlx-generated/metal/fp4.h b/Source/Cxxmlx/mlx-generated/metal/fp4.h new file mode 100644 index 00000000..25642f20 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/fp4.h @@ -0,0 +1,48 @@ +#pragma once + +struct fp4_e2m1 { + fp4_e2m1(float x) { + if (metal::isnan(x)) { + bits = 0x7; + return; + } + + const uint8_t sign_bit = (metal::signbit(x)) ? 0x8 : 0x0; + x = metal::abs(x); + + if (x > 5.0f) { + bits = 0x7; + } else if (x >= 3.5f) { + bits = 0x6; + } else if (x > 2.5f) { + bits = 0x5; + } else if (x >= 1.75f) { + bits = 0x4; + } else if (x > 1.25f) { + bits = 0x3; + } else if (x >= 0.75f) { + bits = 0x2; + } else if (x > 0.25f) { + bits = 0x1; + } else { + bits = 0x0; + } + bits |= sign_bit; + } + + operator float16_t() { + half converted = as_type(ushort((bits & 7) << 9)); + converted *= 16384.0; + return bits & 8 ? -converted : converted; + } + + operator float() { + return static_cast(this->operator float16_t()); + } + + operator bfloat16_t() { + return static_cast(this->operator float16_t()); + } + + uint8_t bits; +}; diff --git a/Source/Cxxmlx/mlx-generated/metal/fp8.h b/Source/Cxxmlx/mlx-generated/metal/fp8.h new file mode 100644 index 00000000..60d34be6 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/fp8.h @@ -0,0 +1,80 @@ +#pragma once + +struct fp8_e4m3 { + template + fp8_e4m3(T f) { + // From PyTorch + // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148 + uint32_t fp8_max = 543 << 21; + uint32_t denorm_mask = 141 << 23; + uint32_t f_bits = as_type(static_cast(f)); + uint32_t sign = f_bits & 0x80000000; + f_bits ^= sign; + if (f_bits >= fp8_max) { + // Default behavior saturates to min/max + bits = 0x7E; + } else { + if (f_bits < (121 << 23)) { + f_bits = as_type( + as_type(f_bits) + as_type(denorm_mask)); + bits = static_cast(f_bits - denorm_mask); + } else { + // resulting mantissa is odd + uint8_t mant_odd = (f_bits >> 20) & 1; + f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF; + f_bits += mant_odd; + bits = static_cast(f_bits >> 20); + } + } + bits |= static_cast(sign >> 24); + } + + operator float16_t() { + uint16_t v = (bits & 127) << 7; + half converted = as_type(v); + converted *= 256.0; + auto sign = bits & 128; + return (sign ? -converted : converted); + } + + operator bfloat16_t() { + return static_cast(this->operator float16_t()); + } + + operator float() { + return static_cast(this->operator float16_t()); + } + + uint8_t bits; +}; + +struct fp8_e8m0 { + fp8_e8m0(float x) { + if (!metal::isfinite(x)) { + bits = 0xFF; + return; + } + if (x < 0.0f) { + bits = 0x00; + return; + } + float le = metal::log2(x); + int n = int(metal::round(le)); + + n = n < -127 ? -127 : n; + n = n > 127 ? 127 : n; + bits = static_cast(n + 127); + } + + operator bfloat16_t() { + uint16_t out = (bits == 0 ? 0x40 : (static_cast(bits) << 7)); + return as_type(out); + } + + operator float() { + uint32_t out = (bits == 0 ? 0x400000 : (static_cast(bits) << 23)); + return as_type(out); + } + + uint8_t bits; +}; diff --git a/Source/Cmlx/mlx-generated/metal/fp_quantized.h b/Source/Cxxmlx/mlx-generated/metal/fp_quantized.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/fp_quantized.h rename to Source/Cxxmlx/mlx-generated/metal/fp_quantized.h diff --git a/Source/Cmlx/mlx-generated/metal/fp_quantized_nax.h b/Source/Cxxmlx/mlx-generated/metal/fp_quantized_nax.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/fp_quantized_nax.h rename to Source/Cxxmlx/mlx-generated/metal/fp_quantized_nax.h diff --git a/Source/Cmlx/mlx-generated/metal/gemv.metal b/Source/Cxxmlx/mlx-generated/metal/gemv.metal similarity index 100% rename from Source/Cmlx/mlx-generated/metal/gemv.metal rename to Source/Cxxmlx/mlx-generated/metal/gemv.metal diff --git a/Source/Cmlx/mlx-generated/metal/gemv_masked.h b/Source/Cxxmlx/mlx-generated/metal/gemv_masked.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/gemv_masked.h rename to Source/Cxxmlx/mlx-generated/metal/gemv_masked.h diff --git a/Source/Cmlx/mlx-generated/metal/hadamard.h b/Source/Cxxmlx/mlx-generated/metal/hadamard.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/hadamard.h rename to Source/Cxxmlx/mlx-generated/metal/hadamard.h diff --git a/Source/Cmlx/mlx-generated/metal/indexing/gather.h b/Source/Cxxmlx/mlx-generated/metal/indexing/gather.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/indexing/gather.h rename to Source/Cxxmlx/mlx-generated/metal/indexing/gather.h diff --git a/Source/Cxxmlx/mlx-generated/metal/indexing/gather_axis.h b/Source/Cxxmlx/mlx-generated/metal/indexing/gather_axis.h new file mode 100644 index 00000000..bf490ade --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/indexing/gather_axis.h @@ -0,0 +1,44 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +template +[[kernel]] void gather_axis( + const device T* src [[buffer(0)]], + const device IdxT* indices [[buffer(1)]], + device T* out [[buffer(2)]], + const constant int* shape [[buffer(3)]], + const constant int64_t* src_strides [[buffer(4)]], + const constant int64_t* idx_strides [[buffer(5)]], + const constant size_t& ndim [[buffer(6)]], + const constant int& axis [[buffer(7)]], + const constant int& axis_size [[buffer(8)]], + const constant size_t& src_ax_stride [[buffer(9)]], + const constant size_t& idx_ax_stride [[buffer(10)]], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + LocT elem_idx = index.z * static_cast(grid_dim.x); + LocT out_idx = elem_idx * grid_dim.y + index.x; + + LocT idx_loc = index.y * static_cast(idx_ax_stride); + if (IdxC) { + idx_loc += out_idx; + } else { + idx_loc += elem_to_loc(elem_idx + index.x, shape, idx_strides, ndim); + } + + auto idx_val = indices[idx_loc]; + if (is_signed_v) { + idx_val = (idx_val < 0) ? idx_val + axis_size : idx_val; + } + + LocT src_idx = idx_val * static_cast(src_ax_stride); + if (SrcC) { + src_idx += elem_idx * axis_size + index.x; + } else { + src_idx += elem_to_loc(elem_idx + index.x, shape, src_strides, ndim); + } + + out_idx += index.y * static_cast(grid_dim.x); + out[out_idx] = src[src_idx]; +} diff --git a/Source/Cmlx/mlx-generated/metal/indexing/gather_front.h b/Source/Cxxmlx/mlx-generated/metal/indexing/gather_front.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/indexing/gather_front.h rename to Source/Cxxmlx/mlx-generated/metal/indexing/gather_front.h diff --git a/Source/Cxxmlx/mlx-generated/metal/indexing/indexing.h b/Source/Cxxmlx/mlx-generated/metal/indexing/indexing.h new file mode 100644 index 00000000..2a4b4f92 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/indexing/indexing.h @@ -0,0 +1,23 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include + +template +struct Indices { + const array buffers; + const constant int* shapes; + const constant int64_t* strides; + const constant bool* row_contiguous; + const int ndim; +}; + +template +METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) { + if (is_unsigned_v) { + return idx; + } else { + return (idx < 0) ? idx + size : idx; + } +} diff --git a/Source/Cxxmlx/mlx-generated/metal/indexing/masked_scatter.h b/Source/Cxxmlx/mlx-generated/metal/indexing/masked_scatter.h new file mode 100644 index 00000000..2ba54740 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/indexing/masked_scatter.h @@ -0,0 +1,41 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +constant mlx::os_log logger("mlx", "masked_assign"); + +template +[[kernel]] void masked_assign_impl( + const device bool* mask [[buffer(0)]], + const device uint* scatter_offsets [[buffer(1)]], + const device T* src [[buffer(2)]], + device T* out [[buffer(3)]], + const constant int* src_shapes [[buffer(4)]], + const constant int64_t* src_strides [[buffer(5)]], + const constant int& src_ndim [[buffer(6)]], + const constant int64_t& src_batch_size [[buffer(7)]], + const constant int64_t& mask_batch_size [[buffer(8)]], + uint idx [[thread_position_in_grid]]) { + const bool mask_value = mask[idx]; + if (!mask_value) { + return; + } + + const uint src_index = scatter_offsets[idx]; + if (src_index >= src_batch_size) { + logger.log_debug("Out of bound read from src"); + return; + } + + const uint batch_idx = idx / mask_batch_size; + + if (src_contiguous) { + out[idx] = src[batch_idx * src_batch_size + src_index]; + } else { + out[idx] = src[elem_to_loc( + batch_idx * src_batch_size + src_index, + src_shapes, + src_strides, + src_ndim)]; + } +} diff --git a/Source/Cmlx/mlx-generated/metal/indexing/scatter.h b/Source/Cxxmlx/mlx-generated/metal/indexing/scatter.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/indexing/scatter.h rename to Source/Cxxmlx/mlx-generated/metal/indexing/scatter.h diff --git a/Source/Cxxmlx/mlx-generated/metal/indexing/scatter_axis.h b/Source/Cxxmlx/mlx-generated/metal/indexing/scatter_axis.h new file mode 100644 index 00000000..73fd7ab4 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/indexing/scatter_axis.h @@ -0,0 +1,52 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +template < + typename T, + typename IdxT, + typename LocT, + typename Op, + bool UpdC, + bool IdxC> +[[kernel]] void scatter_axis( + const device T* upd [[buffer(0)]], + const device IdxT* indices [[buffer(1)]], + device mlx_atomic* out [[buffer(2)]], + const constant int* shape [[buffer(3)]], + const constant int64_t* upd_strides [[buffer(4)]], + const constant int64_t* idx_strides [[buffer(5)]], + const constant size_t& ndim [[buffer(6)]], + const constant int& axis [[buffer(7)]], + const constant int& out_axis_size [[buffer(8)]], + const constant size_t& upd_ax_stride [[buffer(9)]], + const constant size_t& idx_ax_stride [[buffer(10)]], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + Op op; + + LocT elem_idx = index.z * static_cast(grid_dim.x); + + LocT idx_loc = index.y * static_cast(idx_ax_stride); + if (IdxC) { + idx_loc += elem_idx * grid_dim.y + index.x; + } else { + idx_loc += elem_to_loc(elem_idx + index.x, shape, idx_strides, ndim); + } + + auto idx_val = indices[idx_loc]; + if (is_signed_v) { + idx_val = (idx_val < 0) ? idx_val + out_axis_size : idx_val; + } + + LocT upd_idx = index.y * static_cast(upd_ax_stride); + if (UpdC) { + upd_idx += elem_idx * grid_dim.y + index.x; + } else { + upd_idx += elem_to_loc(elem_idx + index.x, shape, upd_strides, ndim); + } + + LocT out_idx = elem_idx * static_cast(out_axis_size) + + idx_val * grid_dim.x + index.x; + op.atomic_update(out, upd[upd_idx], out_idx); +} diff --git a/Source/Cmlx/mlx-generated/metal/layer_norm.metal b/Source/Cxxmlx/mlx-generated/metal/layer_norm.metal similarity index 100% rename from Source/Cmlx/mlx-generated/metal/layer_norm.metal rename to Source/Cxxmlx/mlx-generated/metal/layer_norm.metal diff --git a/Source/Cxxmlx/mlx-generated/metal/logging.h b/Source/Cxxmlx/mlx-generated/metal/logging.h new file mode 100644 index 00000000..7b3ee046 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/logging.h @@ -0,0 +1,26 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#if defined(__METAL_VERSION__) && (__METAL_VERSION__ >= 320) +#include + +namespace mlx { +using os_log = metal::os_log; +} // namespace mlx + +#else + +namespace mlx { +struct os_log { + constexpr os_log(constant char*, constant char*) constant {} + + template + void log_debug(constant char*, Args...) const {} + + template + void log_debug(constant char*, Args...) const constant {} +}; +} // namespace mlx + +#endif \ No newline at end of file diff --git a/Source/Cxxmlx/mlx-generated/metal/logsumexp.h b/Source/Cxxmlx/mlx-generated/metal/logsumexp.h new file mode 100644 index 00000000..c746050b --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/logsumexp.h @@ -0,0 +1,140 @@ +// Copyright © 2025 Apple Inc. + +template +[[kernel]] void logsumexp( + const device T* in, + device T* out, + constant int& axis_size, + uint gid [[threadgroup_position_in_grid]], + uint _lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + int lid = _lid; + + constexpr int SIMD_SIZE = 32; + + threadgroup AccT local_max[SIMD_SIZE]; + threadgroup AccT local_normalizer[SIMD_SIZE]; + + AccT ld[N_READS]; + + in += gid * size_t(axis_size) + lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + ld[i] = AccT(in[i]); + } + } else { + for (int i = 0; i < N_READS; i++) { + ld[i] = + ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits::min; + } + } + if (simd_group_id == 0) { + local_max[simd_lane_id] = Limits::min; + local_normalizer[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Get the max + AccT maxval = Limits::finite_min; + for (int i = 0; i < N_READS; i++) { + maxval = (maxval < ld[i]) ? ld[i] : maxval; + } + maxval = simd_max(maxval); + if (simd_lane_id == 0) { + local_max[simd_group_id] = maxval; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + maxval = simd_max(local_max[simd_lane_id]); + if (simd_lane_id == 0) { + local_max[0] = maxval; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + maxval = local_max[0]; + + // Compute exp(x_i - maxval) and store the partial sums in local_normalizer + AccT normalizer = 0; + for (int i = 0; i < N_READS; i++) { + normalizer += fast::exp(ld[i] - maxval); + } + normalizer = simd_sum(normalizer); + if (simd_lane_id == 0) { + local_normalizer[simd_group_id] = normalizer; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + normalizer = simd_sum(local_normalizer[simd_lane_id]); + if (simd_lane_id == 0) { + out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); + } + } +} + +template +[[kernel]] void logsumexp_looped( + const device T* in, + device T* out, + constant int& axis_size, + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + in += gid * size_t(axis_size); + + constexpr int SIMD_SIZE = 32; + + threadgroup AccT local_max[SIMD_SIZE]; + threadgroup AccT local_normalizer[SIMD_SIZE]; + + // Get the max and the normalizer in one go + AccT prevmax; + AccT maxval = Limits::finite_min; + AccT normalizer = 0; + for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); + r++) { + int offset = r * lsize * N_READS + lid * N_READS; + AccT vals[N_READS]; + if (offset + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + vals[i] = AccT(in[offset + i]); + } + } else { + for (int i = 0; i < N_READS; i++) { + vals[i] = + (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; + } + } + prevmax = maxval; + for (int i = 0; i < N_READS; i++) { + maxval = (maxval < vals[i]) ? vals[i] : maxval; + } + normalizer *= fast::exp(prevmax - maxval); + for (int i = 0; i < N_READS; i++) { + normalizer += fast::exp(vals[i] - maxval); + } + } + prevmax = maxval; + maxval = simd_max(maxval); + normalizer *= fast::exp(prevmax - maxval); + normalizer = simd_sum(normalizer); + + prevmax = maxval; + if (simd_lane_id == 0) { + local_max[simd_group_id] = maxval; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + maxval = simd_max(local_max[simd_lane_id]); + normalizer *= fast::exp(prevmax - maxval); + if (simd_lane_id == 0) { + local_normalizer[simd_group_id] = normalizer; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + normalizer = simd_sum(local_normalizer[simd_lane_id]); + + if (lid == 0) { + out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); + } +} diff --git a/Source/Cxxmlx/mlx-generated/metal/quantized.h b/Source/Cxxmlx/mlx-generated/metal/quantized.h new file mode 100644 index 00000000..5ac4c6e1 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/quantized.h @@ -0,0 +1,2508 @@ +// Copyright © 2023-2024 Apple Inc. + +#include +#include + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const + +MLX_MTL_CONST int SIMD_SIZE = 32; +MLX_MTL_CONST int QUAD_SIZE = 4; + +template +inline constexpr short get_pack_factor() { + return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); +} + +template +inline constexpr short get_bytes_per_pack() { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3); +} + +template +inline U load_vector(const device T* x, thread U* x_thread) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U sum = 0; + + if (bits == 2) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 4.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 64.0f; + } + } + + else if (bits == 3) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 8.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 2.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 128.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 32.0f; + } + } + + else if (bits == 4) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; + } + } + + else if (bits == 5) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + + else if (bits == 6) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 64.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 4.0f; + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + sum += x[i]; + x_thread[i] = x[i]; + } + } + + return sum; +} + +template +inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U sum = 0; + + if (bits == 2) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 4.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 64.0f; + } + } + + else if (bits == 3) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 8.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 2.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 128.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 32.0f; + } + } + + else if (bits == 4) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; + } + } + + else if (bits == 5) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + + else if (bits == 6) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 64.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 4.0f; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + sum += x[i]; + x_thread[i] = x[i]; + } + } + + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; + } + + return sum; +} + +template +inline U qdot( + const device uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U accum = 0; + + if (bits == 2) { + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * (w[i] & 0x0c) + + x_thread[4 * i + 2] * (w[i] & 0x30) + + x_thread[4 * i + 3] * (w[i] & 0xc0)); + } + } + + else if (bits == 3) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 3 * i; + + accum += (w[0] & 0x07) * x_thread[0]; + accum += (w[0] & 0x38) * x_thread[1]; + accum += (w[0] & 0xc0) * x_thread[2]; + accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); + + accum += (w[1] & 0x0e) * x_thread[3]; + accum += (w[1] & 0x70) * x_thread[4]; + accum += (w[1] & 0x80) * x_thread[5]; + accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); + + accum += (w[2] & 0x1c) * x_thread[6]; + accum += (w[2] & 0xe0) * x_thread[7]; + } + } + + else if (bits == 4) { + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); + } + } + + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + + else if (bits == 6) { + for (int i = 0; i < (values_per_thread / 4); i++) { + x_thread += 4 * i; + w += 3 * i; + + accum += (w[0] & 0x3f) * x_thread[0]; + + accum += (w[0] & 0xc0) * x_thread[1]; + accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); + + accum += (w[1] & 0xf0) * x_thread[2]; + accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); + + accum += (w[2] & 0xfc) * x_thread[3]; + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + accum += x_thread[i] * w[i]; + } + } + + return scale * accum + sum * bias; +} + +template +inline U qdot_safe( + const device uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum, + int N) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U accum = 0; + + if (bits == 2) { + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * (w[i] & 0x0c) + + x_thread[4 * i + 2] * (w[i] & 0x30) + + x_thread[4 * i + 3] * (w[i] & 0xc0)); + } + } + + else if (bits == 3) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 3 * i; + + accum += (w[0] & 0x07) * x_thread[0]; + accum += (w[0] & 0x38) * x_thread[1]; + accum += (w[0] & 0xc0) * x_thread[2]; + accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); + + accum += (w[1] & 0x0e) * x_thread[3]; + accum += (w[1] & 0x70) * x_thread[4]; + accum += (w[1] & 0x80) * x_thread[5]; + accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); + + accum += (w[2] & 0x1c) * x_thread[6]; + accum += (w[2] & 0xe0) * x_thread[7]; + } + } + + else if (bits == 4) { + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); + } + } + + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + + else if (bits == 6) { + for (int i = 0; i < (N / 4); i++) { + x_thread += 4 * i; + w += 3 * i; + + accum += (w[0] & 0x3f) * x_thread[0]; + + accum += (w[0] & 0xc0) * x_thread[1]; + accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); + + accum += (w[1] & 0xf0) * x_thread[2]; + accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); + + accum += (w[2] & 0xfc) * x_thread[3]; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + accum += x_thread[i] * w[i]; + } + } + + return scale * accum + sum * bias; +} + +template +inline void +qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + if (bits == 2) { + U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; + for (int i = 0; i < (values_per_thread / 4); i++) { + result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias); + result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias); + result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias); + result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias); + } + } + + else if (bits == 3) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t w0 = w[3 * i]; + uint8_t w1 = w[3 * i + 1]; + uint8_t w2 = w[3 * i + 2]; + + result[8 * i] += x * ((w0 & 0x7) * scale + bias); + result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias); + result[8 * i + 2] += + x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias); + result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias); + result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias); + result[8 * i + 5] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias); + result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias); + result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias); + } + } + + else if (bits == 4) { + U s[2] = {scale, scale / 16.0f}; + for (int i = 0; i < (values_per_thread / 2); i++) { + result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); + result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); + } + } + + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t w0 = w[5 * i]; + uint8_t w1 = w[5 * i + 1]; + uint8_t w2 = w[5 * i + 2]; + uint8_t w3 = w[5 * i + 3]; + uint8_t w4 = w[5 * i + 4]; + result[8 * i] += x * ((w0 & 0x1f) * scale + bias); + result[8 * i + 1] += + x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias); + result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias); + result[8 * i + 3] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias); + result[8 * i + 4] += + x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias); + result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias); + result[8 * i + 6] += + x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias); + result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias); + } + } + + else if (bits == 6) { + for (int i = 0; i < (values_per_thread / 4); i++) { + uint8_t w0 = w[3 * i]; + uint8_t w1 = w[3 * i + 1]; + uint8_t w2 = w[3 * i + 2]; + + result[4 * i] += x * ((w0 & 0x3f) * scale + bias); + result[4 * i + 1] += + x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias); + result[4 * i + 2] += + x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias); + result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias); + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + result[i] += x * (scale * w[i] + bias); + } + } +} + +template +inline void +dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + if (bits == 2) { + U s[4] = { + scale, + scale / static_cast(4.0f), + scale / static_cast(16.0f), + scale / static_cast(64.0f)}; + for (int i = 0; i < (N / 4); i++) { + w_local[4 * i] = s[0] * (w[i] & 0x03) + bias; + w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias; + w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias; + w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias; + } + } + + else if (bits == 3) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 3 * i; + + w_local[0] = (w[0] & 0x7) * scale + bias; + w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias; + w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias; + w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias; + w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias; + w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; + w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias; + w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias; + } + } + + else if (bits == 4) { + U s[2] = {scale, scale / static_cast(16.0f)}; + for (int i = 0; i < (N / 2); i++) { + w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias; + w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias; + } + } + + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 5 * i; + + w_local[0] = (w[0] & 0x1f) * scale + bias; + w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; + w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias; + } + } + + else if (bits == 6) { + for (int i = 0; i < (N / 4); i++) { + w_local += 4 * i; + w += 3 * i; + w_local[0] = (w[0] & 0x3f) * scale + bias; + w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; + w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; + w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + w_local[i] = scale * w[i] + bias; + } + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short group_size, + short bits> +struct QuantizedBlockLoader { + static_assert( + BCOLS <= group_size, + "The group size should be larger than the columns"); + static_assert( + group_size % BCOLS == 0, + "The group size should be divisible by the columns"); + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + MLX_MTL_CONST short pack_factor = get_pack_factor(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); + MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; + MLX_MTL_CONST short n_reads = + (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; + MLX_MTL_CONST short group_steps = group_size / BCOLS; + + const int src_ld; + const int tile_stride; + short group_step_cnt; + const int group_stride; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + const device T* scales; + const device T* biases; + + QuantizedBlockLoader( + const device uint8_t* src_, + const device T* scales_, + const device T* biases_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride( + reduction_dim ? BCOLS_PACKED * bytes_per_pack + : BROWS * src_ld * bytes_per_pack / pack_factor), + group_step_cnt(0), + group_stride(BROWS * src_ld / group_size), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(n_reads * thread_idx / BCOLS_PACKED), + bj((n_reads * thread_idx) % BCOLS_PACKED), + dst(dst_ + bi * dst_ld + bj * pack_factor), + src(src_ + bi * src_ld * bytes_per_pack / pack_factor + + bj * bytes_per_pack), + scales(scales_ + bi * src_ld / group_size), + biases(biases_ + bi * src_ld / group_size) {} + + void load_unsafe() const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + src + i * bytes_per_pack, scale, bias, dst + i * pack_factor); + } + } + + void load_safe(short2 src_tile_dim) const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + if (reduction_dim == 1 && bi >= src_tile_dim.x) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + if (reduction_dim == 0 && bi >= src_tile_dim.y) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i * bytes_per_pack), + scale, + bias, + dst + i * pack_factor); + } + } + + void next() { + src += tile_stride; + if (reduction_dim == 1) { + if (group_steps > 1) { + group_step_cnt++; + if (group_step_cnt == group_steps) { + group_step_cnt = 0; + scales++; + biases++; + } + } else { + scales++; + biases++; + } + } else { + scales += group_stride; + biases += group_stride; + } + } +}; + +template +METAL_FUNC void qmv_quad_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]]) { + constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE; + constexpr int pack_factor = 32 / bits; + constexpr int values_per_thread = D / QUAD_SIZE; + constexpr int packs_per_thread = values_per_thread / pack_factor; + constexpr int scale_step_per_thread = group_size / values_per_thread; + constexpr int results_per_quadgroup = 8; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_quadgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid; + + w += out_row * in_vec_size_w + quad_lid * packs_per_thread; + scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; + biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; + x += tid.x * in_vec_size + quad_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + U sum = load_vector(x, x_thread); + + for (int row = 0; row < results_per_quadgroup; row++) { + auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd); + const device T* sl = scales + row * in_vec_size_g * quads_per_simd; + const device T* bl = biases + row * in_vec_size_g * quads_per_simd; + + U s = sl[0]; + U b = bl[0]; + if (row * quads_per_simd + out_row < out_vec_size) { + result[row] += qdot(wl, x_thread, s, b, sum); + } + } + + for (int row = 0; row < results_per_quadgroup; row++) { + result[row] = quad_sum(result[row]); + if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) { + y[row * quads_per_simd] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void qmv_fast_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int packs_per_thread = bits == 2 ? 1 : 2; + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + + const device uint8_t* ws = (const device uint8_t*)w; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + for (int k = 0; k < in_vec_size; k += block_size) { + U sum = load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += qdot(wl, x_thread, s, b, sum); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + biases += block_size / group_size; + x += block_size; + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void qmv_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int packs_per_thread = 1; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + + const device uint8_t* ws = (const device uint8_t*)w; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); + + if (out_row >= out_vec_size) { + return; + } + + // In this case we need to properly guard all our reads because there isn't + // even 1 tile in the matrix + if (out_vec_size < (num_simdgroups * results_per_simdgroup)) { + ws += + out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + U sum = load_vector(x, x_thread); + + for (int row = 0; + row < results_per_simdgroup && out_row + row < out_vec_size; + row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += + qdot(wl, x_thread, s, b, sum); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + biases += block_size / group_size; + x += block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + if (remaining > 0) { + U sum = load_vector_safe( + x, x_thread, remaining); + + for (int row = 0; + row < results_per_simdgroup && out_row + row < out_vec_size; + row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += qdot_safe( + wl, x_thread, s, b, sum, remaining); + } + } + + for (int row = 0; + row < results_per_simdgroup && out_row + row < out_vec_size; + row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } + + // In this case the last tile is moved back to redo some output values + else { + ws += used_out_row * in_vec_size_w + + simd_lid * packs_per_thread * bytes_per_pack; + scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + used_out_row; + + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + U sum = load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += + qdot(wl, x_thread, s, b, sum); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + biases += block_size / group_size; + x += block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + if (remaining > 0) { + U sum = load_vector_safe( + x, x_thread, remaining); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += qdot_safe( + wl, x_thread, s, b, sum, remaining); + } + } + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } +} + +template +METAL_FUNC void qvm_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + const int in_vec_size, + const int out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + constexpr int num_simdgroups = 2; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int tn = 32 / pack_factor; + constexpr int block_size = SIMD_SIZE; + + using W_T = + typename ConditionalType::type; + const device W_T* ws = (const device W_T*)w; + + typedef float U; + typedef struct { + W_T wi[tn * bytes_per_pack]; + } vec_w; + + thread vec_w w_local; + thread U result[tn * pack_factor] = {0}; + thread U scale = 1; + thread U bias = 0; + thread U x_local = 0; + + // Adjust positions + const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor; + const int out_vec_size_g = out_vec_size / group_size; + int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid); + ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w; + scales += out_col / group_size + simd_lid * out_vec_size_g; + biases += out_col / group_size + simd_lid * out_vec_size_g; + x += tid.x * in_vec_size + simd_lid; + y += tid.x * out_vec_size + out_col; + + if (out_col >= out_vec_size) { + return; + } + + // Loop over in_vec in blocks of block_size + int remaining = in_vec_size % block_size; + if (remaining == 0) { + for (int i = 0; i < in_vec_size; i += block_size) { + x_local = *x; + scale = *scales; + bias = *biases; + w_local = *((device vec_w*)ws); + qouter( + (thread uint8_t*)&w_local, x_local, scale, bias, result); + + x += block_size; + scales += block_size * out_vec_size_g; + biases += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; + } + } else { + for (int i = block_size; i < in_vec_size; i += block_size) { + x_local = *x; + scale = *scales; + bias = *biases; + w_local = *((device vec_w*)ws); + + qouter( + (thread uint8_t*)&w_local, x_local, scale, bias, result); + + x += block_size; + scales += block_size * out_vec_size_g; + biases += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; + } + if (static_cast(simd_lid) < remaining) { + x_local = *x; + scale = *scales; + bias = *biases; + w_local = *((device vec_w*)ws); + } else { + x_local = 0; + scale = 0; + bias = 0; + } + qouter( + (thread uint8_t*)&w_local, x_local, scale, bias, result); + } + +// Accumulate in the simdgroup +#pragma clang loop unroll(full) + for (int k = 0; k < tn * pack_factor; k++) { + result[k] = simd_sum(result[k]); + } + + // Store the result + if (simd_lid == 0) { +#pragma clang loop unroll(full) + for (int k = 0; k < tn * pack_factor; k++) { + y[k] = static_cast(result[k]); + } + } +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 32, + const int BK = 32, + const int BN = 32> +METAL_FUNC void qmm_t_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + // Instantiate the appropriate BlockMMA and Loader + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BN, + BK, + BK_padded, + 1, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + // Set the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = (const device uint8_t*)w; + + x += y_row * static_cast(K); + wl += y_col * K_w; + scales += y_col * K_g; + biases += y_col * K_g; + y += y_row * static_cast(N) + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + const short num_outs = min(BN, N - y_col); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM || num_outs < BN) { + mma_op.store_result_safe(y, N, short2(num_outs, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 32, + const int BK = 32, + const int BN = 32> +METAL_FUNC void qmm_n_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + // Instantiate the appropriate BlockMMA and Loader + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = mlx::steel:: + BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BK, + BN, + BN_padded, + 0, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + auto wl = (const device uint8_t*)w; + + // Set the block + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + x += y_row * static_cast(K); + wl += y_col * bytes_per_pack / pack_factor; + scales += y_col / group_size; + biases += y_col / group_size; + y += y_row * static_cast(N) + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, num_els)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, BM)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM) { + mma_op.store_result_safe(y, N, short2(BN, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device T*& scales, + const device T*& biases, + device T*& y, + int output_stride, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int64_t* b_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx = tid.z; + uint32_t w_idx = tid.z; + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + biases += w_idx * b_strides[0]; + } else { + ulong3 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + biases += idx.z; + } + y += tid.z * output_stride; +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device T*& scales, + const device T*& biases, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T*& y, + int output_stride, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int64_t* b_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx; + uint32_t w_idx; + if (batch_ndims == 1) { + x_idx = lhs_indices[tid.z * lhs_strides[0]]; + w_idx = rhs_indices[tid.z * rhs_strides[0]]; + } else { + ulong2 idx = elem_to_loc_broadcast( + tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); + x_idx = lhs_indices[idx.x]; + w_idx = rhs_indices[idx.y]; + } + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + biases += w_idx * b_strides[0]; + } else { + ulong3 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + biases += idx.z; + } + y += tid.z * output_stride; +} + +template +[[kernel]] void affine_qmv_quad( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant int64_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant int64_t* w_strides [[buffer(12)]], + const constant int64_t* s_strides [[buffer(13)]], + const constant int64_t* b_strides [[buffer(14)]], + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + qmv_quad_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + quad_gid, + quad_lid); +} + +template +[[kernel]] void affine_qmv_fast( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant int64_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant int64_t* w_strides [[buffer(12)]], + const constant int64_t* s_strides [[buffer(13)]], + const constant int64_t* b_strides [[buffer(14)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + qmv_fast_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void affine_qmv( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant int64_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant int64_t* w_strides [[buffer(12)]], + const constant int64_t* s_strides [[buffer(13)]], + const constant int64_t* b_strides [[buffer(14)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + qmv_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void affine_qvm( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant int64_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant int64_t* w_strides [[buffer(12)]], + const constant int64_t* s_strides [[buffer(13)]], + const constant int64_t* b_strides [[buffer(14)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + qvm_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void affine_qvm_split_k( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant int64_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant int64_t* w_strides [[buffer(12)]], + const constant int64_t* s_strides [[buffer(13)]], + const constant int64_t* b_strides [[buffer(14)]], + const constant int& final_block_size [[buffer(15)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + + // When (in_vec_size % split_k != 0) the final block needs to be smaller + int in_vec_size_adj = + tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size; + + qvm_impl( + w, + scales, + biases, + x, + y, + in_vec_size_adj, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const bool batched, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void affine_qmm_t( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& K [[buffer(5)]], + const constant int& N [[buffer(6)]], + const constant int& M [[buffer(7)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant int64_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant int64_t* w_strides [[buffer(13)]], + const constant int64_t* s_strides [[buffer(14)]], + const constant int64_t* b_strides [[buffer(15)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + qmm_t_impl( + w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool batched, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void affine_qmm_n( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& K [[buffer(5)]], + const constant int& N [[buffer(6)]], + const constant int& M [[buffer(7)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant int64_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant int64_t* w_strides [[buffer(13)]], + const constant int64_t* s_strides [[buffer(14)]], + const constant int64_t* b_strides [[buffer(15)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + + qmm_n_impl( + w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void affine_gather_qmv_fast( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& x_batch_ndims [[buffer(9)]], + const constant int* x_shape [[buffer(10)]], + const constant int64_t* x_strides [[buffer(11)]], + const constant int& w_batch_ndims [[buffer(12)]], + const constant int* w_shape [[buffer(13)]], + const constant int64_t* w_strides [[buffer(14)]], + const constant int64_t* s_strides [[buffer(15)]], + const constant int64_t* b_strides [[buffer(16)]], + const constant int& batch_ndims [[buffer(17)]], + const constant int* batch_shape [[buffer(18)]], + const constant int64_t* lhs_strides [[buffer(19)]], + const constant int64_t* rhs_strides [[buffer(20)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmv_fast_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void affine_gather_qmv( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& x_batch_ndims [[buffer(9)]], + const constant int* x_shape [[buffer(10)]], + const constant int64_t* x_strides [[buffer(11)]], + const constant int& w_batch_ndims [[buffer(12)]], + const constant int* w_shape [[buffer(13)]], + const constant int64_t* w_strides [[buffer(14)]], + const constant int64_t* s_strides [[buffer(15)]], + const constant int64_t* b_strides [[buffer(16)]], + const constant int& batch_ndims [[buffer(17)]], + const constant int* batch_shape [[buffer(18)]], + const constant int64_t* lhs_strides [[buffer(19)]], + const constant int64_t* rhs_strides [[buffer(20)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmv_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void affine_gather_qvm( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& x_batch_ndims [[buffer(9)]], + const constant int* x_shape [[buffer(10)]], + const constant int64_t* x_strides [[buffer(11)]], + const constant int& w_batch_ndims [[buffer(12)]], + const constant int* w_shape [[buffer(13)]], + const constant int64_t* w_strides [[buffer(14)]], + const constant int64_t* s_strides [[buffer(15)]], + const constant int64_t* b_strides [[buffer(16)]], + const constant int& batch_ndims [[buffer(17)]], + const constant int* batch_shape [[buffer(18)]], + const constant int64_t* lhs_strides [[buffer(19)]], + const constant int64_t* rhs_strides [[buffer(20)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qvm_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void affine_gather_qmm_t( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& K [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& M [[buffer(9)]], + const constant int& x_batch_ndims [[buffer(10)]], + const constant int* x_shape [[buffer(11)]], + const constant int64_t* x_strides [[buffer(12)]], + const constant int& w_batch_ndims [[buffer(13)]], + const constant int* w_shape [[buffer(14)]], + const constant int64_t* w_strides [[buffer(15)]], + const constant int64_t* s_strides [[buffer(16)]], + const constant int64_t* b_strides [[buffer(17)]], + const constant int& batch_ndims [[buffer(18)]], + const constant int* batch_shape [[buffer(19)]], + const constant int64_t* lhs_strides [[buffer(20)]], + const constant int64_t* rhs_strides [[buffer(21)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmm_t_impl( + w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void affine_gather_qmm_n( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& K [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& M [[buffer(9)]], + const constant int& x_batch_ndims [[buffer(10)]], + const constant int* x_shape [[buffer(11)]], + const constant int64_t* x_strides [[buffer(12)]], + const constant int& w_batch_ndims [[buffer(13)]], + const constant int* w_shape [[buffer(14)]], + const constant int64_t* w_strides [[buffer(15)]], + const constant int64_t* s_strides [[buffer(16)]], + const constant int64_t* b_strides [[buffer(17)]], + const constant int& batch_ndims [[buffer(18)]], + const constant int* batch_shape [[buffer(19)]], + const constant int64_t* lhs_strides [[buffer(20)]], + const constant int64_t* rhs_strides [[buffer(21)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmm_n_impl( + w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + int group_size, + int bits, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose> +[[kernel]] void affine_gather_qmm_rhs( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + const device uint32_t* indices [[buffer(4)]], + device T* y [[buffer(5)]], + const constant int& M [[buffer(6)]], + const constant int& N [[buffer(7)]], + const constant int& K [[buffer(8)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + using mma_t = mlx::steel::BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + false, + transpose, + BK_padded, + transpose ? BK_padded : BN_padded>; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + transpose ? BN : BK, + transpose ? BK : BN, + transpose ? BK_padded : BN_padded, + transpose, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; + + // Compute the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int N_w = N * bytes_per_pack / pack_factor; + const int N_g = N / group_size; + const int K_it = K / BK; + const size_t stride_w = transpose ? N * K_w : K * N_w; + const size_t stride_s = transpose ? N * K_g : K * N_g; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + const size_t y_row_long = size_t(y_row); + const size_t y_col_long = size_t(y_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); + const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); + + // Calculate the final tiles in the case that K is not aligned + const int k_remain = K - K_it * BK; + const short2 tile_x = short2(k_remain, tgp_bm); + const short2 tile_w = + transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + // Move x and output to the correct block + auto wl = (const device uint8_t*)w; + x += y_row_long * K; + y += y_row_long * N + y_col_long; + wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; + scales += transpose ? y_col_long * K_g : y_col / group_size; + biases += transpose ? y_col_long * K_g : y_col / group_size; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = indices[y_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (indices[y_row + n] != index) { + offset_next = n; + index_next = indices[y_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id); + thread loader_w_t loader_w( + wl + index * stride_w, + scales + index * stride_s, + biases + index * stride_s, + transpose ? K : N, + Ws, + simd_group_id, + simd_lane_id); + + // Matrices are all aligned check nothing + if (align_M && align_N) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } else { + // Tile aligned so check outside of the hot loop + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + + // Nothing aligned so check both rows and cols + else { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + } + } +} + +template +[[kernel]] void affine_quantize( + const device T* w [[buffer(0)]], + device uint8_t* out [[buffer(1)]], + device T* scales [[buffer(2)]], + device T* biases [[buffer(3)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + constexpr float eps = 1e-7; + constexpr int simd_size = 32; + constexpr float n_bins = (1 << bits) - 1; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int values_per_reduce = group_size / simd_size; + constexpr int writes_per_reduce = pack_factor / values_per_reduce; + constexpr int writes_per_pack = + writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor; + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + + static_assert( + group_size % simd_size == 0, + "Group size must be divisible by simd size."); + + size_t offset = index.x + grid_dim.x * size_t(index.y); + size_t in_index = offset * values_per_reduce; + size_t out_index = power_of_2_bits + ? offset * writes_per_pack + : offset * bytes_per_pack / writes_per_reduce; + + float w_thread[values_per_reduce]; + float w_min = Limits::max; + float w_max = 0; + +#pragma clang loop unroll(full) + for (int i = 0; i < values_per_reduce; i++) { + float val = w[in_index + i]; + w_thread[i] = val; + w_min = min(w_min, val); + w_max = max(w_max, val); + } + + w_min = simd_min(w_min); + w_max = simd_max(w_max); + + float scale = max((w_max - w_min) / n_bins, eps); + bool side = abs(w_min) > abs(w_max); + scale = side ? scale : -scale; + float edge = side ? w_min : w_max; + float q0 = round(edge / scale); + bool at_zero = q0 == 0.0f; + scale = at_zero ? scale : edge / q0; + float bias = at_zero ? 0 : edge; + + // Write out the scales and biases + size_t gindex = in_index / group_size; + if (in_index % group_size == 0) { + scales[gindex] = static_cast(scale); + biases[gindex] = static_cast(bias); + } + + using OutType = metal::conditional_t; + OutType output = 0; + +#pragma clang loop unroll(full) + for (int i = 0; i < values_per_reduce; i++) { + uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins); + if (bits == 8) { + output = val; + } else { + output |= val << (bits * (i % pack_factor)); + } + + if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) { + out[out_index + i / pack_factor] = output; + output = 0; + } else { +#pragma clang loop unroll(full) + for (int j = 1; j < writes_per_reduce; j++) { + uint8_t sval = simd_shuffle_down(val, j); + output |= static_cast(sval) + << (bits * (j * values_per_reduce + i)); + } + } + } + if (bits == 3 || bits == 6) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { + out[out_index] = output & 0xff; + out[out_index + 1] = (output & 0xff00) >> 8; + out[out_index + 2] = (output & 0xff0000) >> 16; + } + } else if (bits == 5) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { + out[out_index] = output & 0xff; + out[out_index + 1] = (output & 0xff00) >> 8; + out[out_index + 2] = (output & 0xff0000) >> 16; + out[out_index + 3] = (output & 0xff000000) >> 24; + out[out_index + 4] = (output & 0xff00000000) >> 32; + } + } else { + if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) { + out[out_index / writes_per_reduce] = output; + } + } +} + +template +[[kernel]] void affine_dequantize( + const device uint8_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + device T* out [[buffer(3)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + size_t offset = index.x + grid_dim.x * size_t(index.y); + size_t oindex = offset * pack_factor; + size_t gindex = oindex / group_size; + T scale = scales[gindex]; + T bias = biases[gindex]; + + out += oindex; + + if (bits == 3) { + w += offset * bytes_per_pack; + out[0] = (w[0] & 0x7) * scale + bias; + out[1] = ((w[0] & 0x38) >> 3) * scale + bias; + out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias; + out[3] = ((w[1] & 0xe) >> 1) * scale + bias; + out[4] = ((w[1] & 0x70) >> 4) * scale + bias; + out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; + out[6] = ((w[2] & 0x1c) >> 2) * scale + bias; + out[7] = ((w[2] & 0xe0) >> 5) * scale + bias; + } else if (bits == 5) { + w += offset * bytes_per_pack; + out[0] = (w[0] & 0x1f) * scale + bias; + out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + out[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; + out[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + out[7] = ((w[4] & 0xf8) >> 3) * scale + bias; + } else if (bits == 6) { + w += offset * bytes_per_pack; + out[0] = (w[0] & 0x3f) * scale + bias; + out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; + out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; + out[3] = ((w[2] >> 2) & 0x3f) * scale + bias; + } else { + uint val = w[offset]; +#pragma clang loop unroll(full) + for (int i = 0; i < pack_factor; i++) { + uint8_t d; + if (bits == 2) { + d = (val >> (bits * i)) & 0x03; + } else if (bits == 4) { + d = (val >> (bits * i)) & 0x0f; + } else if (bits == 8) { + d = val; + } + out[i] = scale * d + bias; + } + } +} diff --git a/Source/Cxxmlx/mlx-generated/metal/quantized_nax.h b/Source/Cxxmlx/mlx-generated/metal/quantized_nax.h new file mode 100644 index 00000000..c26ff646 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/quantized_nax.h @@ -0,0 +1,1705 @@ +// Copyright © 2023-2024 Apple Inc. + +#include +#include + +using namespace metal; +using namespace mlx::steel; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const + +MLX_MTL_CONST int SIMD_SIZE = 32; +MLX_MTL_CONST int QUAD_SIZE = 4; + +template +inline constexpr short get_pack_factor() { + return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); +} + +template +inline constexpr short get_bytes_per_pack() { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3); +} + +template +inline U load_vector(const device T* x, thread U* x_thread) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U sum = 0; + + if (bits == 2) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 4.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 64.0f; + } + } + + else if (bits == 3) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 8.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 2.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 128.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 32.0f; + } + } + + else if (bits == 4) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; + } + } + + else if (bits == 5) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + + else if (bits == 6) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 64.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 4.0f; + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + sum += x[i]; + x_thread[i] = x[i]; + } + } + + return sum; +} + +template +inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U sum = 0; + + if (bits == 2) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 4.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 64.0f; + } + } + + else if (bits == 3) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 8.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 2.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 128.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 32.0f; + } + } + + else if (bits == 4) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; + } + } + + else if (bits == 5) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + + else if (bits == 6) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 64.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 4.0f; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + sum += x[i]; + x_thread[i] = x[i]; + } + } + + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; + } + + return sum; +} + +template +inline U qdot( + const device uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U accum = 0; + + if (bits == 2) { + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * (w[i] & 0x0c) + + x_thread[4 * i + 2] * (w[i] & 0x30) + + x_thread[4 * i + 3] * (w[i] & 0xc0)); + } + } + + else if (bits == 3) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 3 * i; + + accum += (w[0] & 0x07) * x_thread[0]; + accum += (w[0] & 0x38) * x_thread[1]; + accum += (w[0] & 0xc0) * x_thread[2]; + accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); + + accum += (w[1] & 0x0e) * x_thread[3]; + accum += (w[1] & 0x70) * x_thread[4]; + accum += (w[1] & 0x80) * x_thread[5]; + accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); + + accum += (w[2] & 0x1c) * x_thread[6]; + accum += (w[2] & 0xe0) * x_thread[7]; + } + } + + else if (bits == 4) { + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); + } + } + + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + + else if (bits == 6) { + for (int i = 0; i < (values_per_thread / 4); i++) { + x_thread += 4 * i; + w += 3 * i; + + accum += (w[0] & 0x3f) * x_thread[0]; + + accum += (w[0] & 0xc0) * x_thread[1]; + accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); + + accum += (w[1] & 0xf0) * x_thread[2]; + accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); + + accum += (w[2] & 0xfc) * x_thread[3]; + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + accum += x_thread[i] * w[i]; + } + } + + return scale * accum + sum * bias; +} + +template +inline U qdot_safe( + const device uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum, + int N) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U accum = 0; + + if (bits == 2) { + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * (w[i] & 0x0c) + + x_thread[4 * i + 2] * (w[i] & 0x30) + + x_thread[4 * i + 3] * (w[i] & 0xc0)); + } + } + + else if (bits == 3) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 3 * i; + + accum += (w[0] & 0x07) * x_thread[0]; + accum += (w[0] & 0x38) * x_thread[1]; + accum += (w[0] & 0xc0) * x_thread[2]; + accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); + + accum += (w[1] & 0x0e) * x_thread[3]; + accum += (w[1] & 0x70) * x_thread[4]; + accum += (w[1] & 0x80) * x_thread[5]; + accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); + + accum += (w[2] & 0x1c) * x_thread[6]; + accum += (w[2] & 0xe0) * x_thread[7]; + } + } + + else if (bits == 4) { + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); + } + } + + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + + else if (bits == 6) { + for (int i = 0; i < (N / 4); i++) { + x_thread += 4 * i; + w += 3 * i; + + accum += (w[0] & 0x3f) * x_thread[0]; + + accum += (w[0] & 0xc0) * x_thread[1]; + accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); + + accum += (w[1] & 0xf0) * x_thread[2]; + accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); + + accum += (w[2] & 0xfc) * x_thread[3]; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + accum += x_thread[i] * w[i]; + } + } + + return scale * accum + sum * bias; +} + +template +inline void +qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + if (bits == 2) { + U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; + for (int i = 0; i < (values_per_thread / 4); i++) { + result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias); + result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias); + result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias); + result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias); + } + } + + else if (bits == 3) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t w0 = w[3 * i]; + uint8_t w1 = w[3 * i + 1]; + uint8_t w2 = w[3 * i + 2]; + + result[8 * i] += x * ((w0 & 0x7) * scale + bias); + result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias); + result[8 * i + 2] += + x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias); + result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias); + result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias); + result[8 * i + 5] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias); + result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias); + result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias); + } + } + + else if (bits == 4) { + U s[2] = {scale, scale / 16.0f}; + for (int i = 0; i < (values_per_thread / 2); i++) { + result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); + result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); + } + } + + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t w0 = w[5 * i]; + uint8_t w1 = w[5 * i + 1]; + uint8_t w2 = w[5 * i + 2]; + uint8_t w3 = w[5 * i + 3]; + uint8_t w4 = w[5 * i + 4]; + result[8 * i] += x * ((w0 & 0x1f) * scale + bias); + result[8 * i + 1] += + x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias); + result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias); + result[8 * i + 3] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias); + result[8 * i + 4] += + x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias); + result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias); + result[8 * i + 6] += + x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias); + result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias); + } + } + + else if (bits == 6) { + for (int i = 0; i < (values_per_thread / 4); i++) { + uint8_t w0 = w[3 * i]; + uint8_t w1 = w[3 * i + 1]; + uint8_t w2 = w[3 * i + 2]; + + result[4 * i] += x * ((w0 & 0x3f) * scale + bias); + result[4 * i + 1] += + x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias); + result[4 * i + 2] += + x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias); + result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias); + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + result[i] += x * (scale * w[i] + bias); + } + } +} + +template +inline void +dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + if (bits == 2) { + U s[4] = { + scale, + scale / static_cast(4.0f), + scale / static_cast(16.0f), + scale / static_cast(64.0f)}; + for (int i = 0; i < (N / 4); i++) { + w_local[4 * i] = s[0] * (w[i] & 0x03) + bias; + w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias; + w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias; + w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias; + } + } + + else if (bits == 3) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 3 * i; + + w_local[0] = (w[0] & 0x7) * scale + bias; + w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias; + w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias; + w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias; + w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias; + w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; + w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias; + w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias; + } + } + + else if (bits == 4) { + U s[2] = {scale, scale / static_cast(16.0f)}; + for (int i = 0; i < (N / 2); i++) { + w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias; + w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias; + } + } + + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 5 * i; + + w_local[0] = (w[0] & 0x1f) * scale + bias; + w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; + w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias; + } + } + + else if (bits == 6) { + for (int i = 0; i < (N / 4); i++) { + w_local += 4 * i; + w += 3 * i; + w_local[0] = (w[0] & 0x3f) * scale + bias; + w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; + w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; + w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + w_local[i] = scale * w[i] + bias; + } + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short group_size, + short bits> +struct QuantizedBlockLoader { + static_assert( + BCOLS <= group_size, + "The group size should be larger than the columns"); + static_assert( + group_size % BCOLS == 0, + "The group size should be divisible by the columns"); + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + MLX_MTL_CONST short pack_factor = get_pack_factor(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); + MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; + MLX_MTL_CONST short n_reads = + (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; + MLX_MTL_CONST short group_steps = group_size / BCOLS; + + const int src_ld; + const int tile_stride; + short group_step_cnt; + const int group_stride; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + const device T* scales; + const device T* biases; + + QuantizedBlockLoader( + const device uint8_t* src_, + const device T* scales_, + const device T* biases_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride( + reduction_dim ? BCOLS_PACKED * bytes_per_pack + : BROWS * src_ld * bytes_per_pack / pack_factor), + group_step_cnt(0), + group_stride(BROWS * src_ld / group_size), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(n_reads * thread_idx / BCOLS_PACKED), + bj((n_reads * thread_idx) % BCOLS_PACKED), + dst(dst_ + bi * dst_ld + bj * pack_factor), + src(src_ + bi * src_ld * bytes_per_pack / pack_factor + + bj * bytes_per_pack), + scales(scales_ + bi * src_ld / group_size), + biases(biases_ + bi * src_ld / group_size) {} + + void load_unsafe() const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + src + i * bytes_per_pack, scale, bias, dst + i * pack_factor); + } + } + + void load_safe(short2 src_tile_dim) const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + if (reduction_dim == 1 && bi >= src_tile_dim.x) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + if (reduction_dim == 0 && bi >= src_tile_dim.y) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i * bytes_per_pack), + scale, + bias, + dst + i * pack_factor); + } + } + + void next() { + src += tile_stride; + if (reduction_dim == 1) { + if (group_steps > 1) { + group_step_cnt++; + if (group_step_cnt == group_steps) { + group_step_cnt = 0; + scales++; + biases++; + } + } else { + scales++; + biases++; + } + } else { + scales += group_stride; + biases += group_stride; + } + } +}; + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short bits> +struct QuantizedBlockLoader< + T, + BROWS, + BCOLS, + dst_ld, + reduction_dim, + tgp_size, + 32, + bits> { + MLX_MTL_CONST short group_size = 32; + + static_assert( + BCOLS % group_size == 0, + "The group size should be divisible by the columns"); + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + MLX_MTL_CONST short pack_factor = get_pack_factor(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); + MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; + MLX_MTL_CONST short n_reads = + (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; + MLX_MTL_CONST short n_groups = BCOLS / group_size; + + static_assert( + (BCOLS_PACKED / n_reads) == n_groups, + "Other configurations are not yet supported"); + + const int src_ld; + const int tile_stride; + const int group_stride; + + const short thread_idx; + const short bi; + const short bj; + + const short group_id; + + threadgroup T* dst; + const device uint8_t* src; + const device T* scales; + const device T* biases; + + QuantizedBlockLoader( + const device uint8_t* src_, + const device T* scales_, + const device T* biases_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride( + reduction_dim ? BCOLS_PACKED * bytes_per_pack + : BROWS * src_ld * bytes_per_pack / pack_factor), + group_stride(BROWS * src_ld / group_size), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(n_reads * thread_idx / BCOLS_PACKED), + bj((n_reads * thread_idx) % BCOLS_PACKED), + group_id((bj * pack_factor) / group_size), + dst(dst_ + bi * dst_ld + bj * pack_factor), + src(src_ + bi * src_ld * bytes_per_pack / pack_factor + + bj * bytes_per_pack), + scales(scales_ + bi * src_ld / group_size + group_id), + biases(biases_ + bi * src_ld / group_size + group_id) {} + + void load_unsafe() const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + src + i * bytes_per_pack, scale, bias, dst + i * pack_factor); + } + } + + void load_safe(short2 src_tile_dim) const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + if (reduction_dim == 1 && bi >= src_tile_dim.x) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + if (reduction_dim == 0 && bi >= src_tile_dim.y) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i * bytes_per_pack), + scale, + bias, + dst + i * pack_factor); + } + } + + void next() { + src += tile_stride; + if (reduction_dim == 1) { + // if (group_steps > 1) { + // group_step_cnt++; + // if (group_step_cnt == group_steps) { + // group_step_cnt = 0; + // scales++; + // biases++; + // } + // } else { + scales += n_groups; + biases += n_groups; + // } + } else { + scales += n_groups * group_stride; + biases += n_groups * group_stride; + } + } +}; + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device T*& scales, + const device T*& biases, + device T*& y, + int output_stride, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int64_t* b_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx = tid.z; + uint32_t w_idx = tid.z; + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + biases += w_idx * b_strides[0]; + } else { + ulong3 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + biases += idx.z; + } + y += tid.z * output_stride; +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device T*& scales, + const device T*& biases, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T*& y, + int output_stride, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int64_t* b_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx; + uint32_t w_idx; + if (batch_ndims == 1) { + x_idx = lhs_indices[tid.z * lhs_strides[0]]; + w_idx = rhs_indices[tid.z * rhs_strides[0]]; + } else { + ulong2 idx = elem_to_loc_broadcast( + tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); + x_idx = lhs_indices[idx.x]; + w_idx = rhs_indices[idx.y]; + } + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + biases += w_idx * b_strides[0]; + } else { + ulong3 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + biases += idx.z; + } + y += tid.z * output_stride; +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +METAL_FUNC void qmm_t_nax_tgp_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + using loader_w_t = QuantizedBlockLoader< + T, + BN, + BK, + BK_padded, + 1, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + // Set the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = (const device uint8_t*)w; + + x += y_row * static_cast(K); + wl += y_col * K_w; + scales += y_col * K_g; + biases += y_col * K_g; + y += y_row * static_cast(N) + y_col; + + // Make the weight loader + loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + const short tm = SM * (simd_gid / WN); + const short tn = SN * (simd_gid % WN); + + constexpr bool transpose_a = false; + constexpr bool transpose_b = true; + + const short sgp_sm = min(SM, short(M - (y_row + tm))); + const bool is_unaligned_sm = (sgp_sm != SM); + + const short sgp_sn = aligned_N ? SN : min(SN, short(N - (y_col + tn))); + + const short tgp_bn = aligned_N ? BN : min(BN, int(N - (y_col))); + const bool is_unaligned_bn = aligned_N ? false : (tgp_bn != BN); + + using AccumType = float; + + using ASubTile = NAXSubTile; + using BSubTile = NAXSubTile; + using DSubTile = NAXSubTile; + + NAXTile Dtile; + + Dtile.clear(); + + x += tm * K; + + dispatch_bool(!is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(aligned_N || !is_unaligned_bn, [&](auto kAlignedN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if constexpr (kAlignedN.value) { + loader_w.load_unsafe(); + } else { + loader_w.load_safe(short2(BK, tgp_bn)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + if constexpr (kAlignedM.value) { + Atile.load(x + kk1, K); + } else { + Atile.load_safe(x + kk1, K, short2(SK, sgp_sm)); + } + + Btile.template load(Ws + tn * BK_padded + kk1); + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + x += BK; + loader_w.next(); + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + + if constexpr (kAlignedM.value && kAlignedN.value) { + Dtile.store(y + tm * N + tn, N); + } else if (kAlignedM.value && sgp_sn == SN) { + Dtile.store(y + tm * N + tn, N); + } else { + Dtile.store_safe(y + tm * N + tn, N, short2(sgp_sn, sgp_sm)); + } + }); + }); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +METAL_FUNC void qmm_n_nax_tgp_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + (void)M; + + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + using loader_w_t = QuantizedBlockLoader< + T, + BK, + BN, + BN_padded, + 0, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + // Set the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = (const device uint8_t*)w; + + x += y_row * static_cast(K); + wl += y_col * K_w; + scales += y_col * K_g; + biases += y_col * K_g; + y += y_row * static_cast(N) + y_col; + + // Make the x loader and mma operation + // const short num_els = min(BM, M - y_row); + // const short num_outs = min(BN, N - y_col); + loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + const short tm = SM * (simd_gid / WN); + const short tn = SN * (simd_gid % WN); + + const short ldb_tgp = BN_padded; + + constexpr bool transpose_a = false; + constexpr bool transpose_b = false; + + using AccumType = float; + + using ASubTile = NAXSubTile; + using BSubTile = NAXSubTile; + using DSubTile = NAXSubTile; + + NAXTile Dtile; + + Dtile.clear(); + + x += tm * K; + + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + Atile.load(x + kk1, K); + Btile.template load(Ws + tn + kk1 * ldb_tgp); + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + x += BK; + loader_w.next(); + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + + Dtile.store(y + tm * N + tn, N); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const bool batched, + const int BM = 64, + const int BK = 32, + const int BN = 64, + const int WM = 2, + const int WN = 2> +[[kernel]] void affine_qmm_t_nax( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& K [[buffer(5)]], + const constant int& N [[buffer(6)]], + const constant int& M [[buffer(7)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant int64_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant int64_t* w_strides [[buffer(13)]], + const constant int64_t* s_strides [[buffer(14)]], + const constant int64_t* b_strides [[buffer(15)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Ws[BN * BK_padded]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + qmm_t_nax_tgp_impl( + w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool batched, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +[[kernel]] void affine_qmm_n_nax( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& K [[buffer(5)]], + const constant int& N [[buffer(6)]], + const constant int& M [[buffer(7)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant int64_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant int64_t* w_strides [[buffer(13)]], + const constant int64_t* s_strides [[buffer(14)]], + const constant int64_t* b_strides [[buffer(15)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Ws[BK * BN_padded]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + + qmm_n_nax_tgp_impl( + w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +[[kernel]] void affine_gather_qmm_t_nax( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& K [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& M [[buffer(9)]], + const constant int& x_batch_ndims [[buffer(10)]], + const constant int* x_shape [[buffer(11)]], + const constant int64_t* x_strides [[buffer(12)]], + const constant int& w_batch_ndims [[buffer(13)]], + const constant int* w_shape [[buffer(14)]], + const constant int64_t* w_strides [[buffer(15)]], + const constant int64_t* s_strides [[buffer(16)]], + const constant int64_t* b_strides [[buffer(17)]], + const constant int& batch_ndims [[buffer(18)]], + const constant int* batch_shape [[buffer(19)]], + const constant int64_t* lhs_strides [[buffer(20)]], + const constant int64_t* rhs_strides [[buffer(21)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Ws[BN * BK_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmm_t_nax_tgp_impl( + w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +[[kernel]] void affine_gather_qmm_n_nax( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& K [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& M [[buffer(9)]], + const constant int& x_batch_ndims [[buffer(10)]], + const constant int* x_shape [[buffer(11)]], + const constant int64_t* x_strides [[buffer(12)]], + const constant int& w_batch_ndims [[buffer(13)]], + const constant int* w_shape [[buffer(14)]], + const constant int64_t* w_strides [[buffer(15)]], + const constant int64_t* s_strides [[buffer(16)]], + const constant int64_t* b_strides [[buffer(17)]], + const constant int& batch_ndims [[buffer(18)]], + const constant int* batch_shape [[buffer(19)]], + const constant int64_t* lhs_strides [[buffer(20)]], + const constant int64_t* rhs_strides [[buffer(21)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Ws[BK * BN_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmm_n_nax_tgp_impl( + w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + int group_size, + int bits, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose> +[[kernel]] void affine_gather_qmm_rhs_nax( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + const device uint32_t* indices [[buffer(4)]], + device T* y [[buffer(5)]], + const constant int& M [[buffer(6)]], + const constant int& N [[buffer(7)]], + const constant int& K [[buffer(8)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + using loader_w_t = QuantizedBlockLoader< + T, + transpose ? BN : BK, + transpose ? BK : BN, + transpose ? BK_padded : BN_padded, + transpose, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; + + // Compute the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int N_w = N * bytes_per_pack / pack_factor; + const int N_g = N / group_size; + const int K_it = K / BK; + const size_t stride_w = transpose ? N * K_w : K * N_w; + const size_t stride_s = transpose ? N * K_g : K * N_g; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + const size_t y_row_long = size_t(y_row); + const size_t y_col_long = size_t(y_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); + const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); + + // Calculate the final tiles in the case that K is not aligned + const int k_remain = K - K_it * BK; + const short2 tile_w = + transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + // Move x and output to the correct block + auto wl = (const device uint8_t*)w; + x += y_row_long * K; + y += y_row_long * N + y_col_long; + wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; + scales += transpose ? y_col_long * K_g : y_col / group_size; + biases += transpose ? y_col_long * K_g : y_col / group_size; + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + const short tm = SM * (simd_group_id / WN); + const short tn = SN * (simd_group_id % WN); + + const short sgp_sm = + align_M ? SM : min(SM, short(max(0, (M - (y_row + tm))))); + const short sgp_sn = + align_N ? SN : min(SN, short(max(0, (N - (y_col + tn))))); + + const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); + const bool is_unaligned_bn = align_N ? false : (tgp_bn != BN); + + constexpr short BR = transpose ? TN : TK; + constexpr short BC = transpose ? TK : TN; + + using AccumType = float; + + using ASubTile = NAXSubTile; + using BSubTile = NAXSubTile; + using DSubTile = NAXSubTile; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = indices[y_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (indices[y_row + n] != index) { + offset_next = n; + index_next = indices[y_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + NAXTile Dtile; + + Dtile.clear(); + + const device T* xn = x + tm * K; + + // Prepare threadgroup loading operations + thread loader_w_t loader_w( + wl + index * stride_w, + scales + index * stride_s, + biases + index * stride_s, + transpose ? K : N, + Ws, + simd_group_id, + simd_lane_id); + + dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(align_N || !is_unaligned_bn, [&](auto kAlignedN) { + for (int k = 0; k < K_it; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if constexpr (kAlignedN.value) { + loader_w.load_unsafe(); + } else { + loader_w.load_safe( + transpose ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + if constexpr (kAlignedM.value) { + Atile.load(xn + kk1, K); + } else { + Atile.load_safe(xn + kk1, K, short2(SK, sgp_sm)); + } + + if constexpr (transpose) { + Btile.template load(Ws + tn * BK_padded + kk1); + } else { + Btile.template load(Ws + tn + kk1 * BN_padded); + } + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + xn += BK; + loader_w.next(); + } + + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_w.load_safe(tile_w); + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + const short psk = min(int(SK), max(0, (BK - kk1))); + Atile.load_safe(xn + kk1, K, short2(psk, sgp_sm)); + + if constexpr (transpose) { + Btile.template load(Ws + tn * BK_padded + kk1); + } else { + Btile.template load(Ws + tn + kk1 * BN_padded); + } + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + const short m_lo_lim = min(int(sgp_sm), max(0, offset - tm)); + const short m_hi_lim = min(int(sgp_sm), max(0, offset_next - tm)); + + // Store results to device memory + if constexpr (kAlignedN.value) { + if (m_lo_lim == 0 && m_hi_lim == SM) { + Dtile.store(y + tm * N + tn, N); + } else { + Dtile.store_slice( + y + tm * N + tn, N, short2(0, m_lo_lim), short2(SN, m_hi_lim)); + } + } else { + Dtile.store_slice( + y + tm * N + tn, + N, + short2(0, m_lo_lim), + short2(sgp_sn, m_hi_lim)); + } + }); + }); + } +} \ No newline at end of file diff --git a/Source/Cxxmlx/mlx-generated/metal/quantized_utils.h b/Source/Cxxmlx/mlx-generated/metal/quantized_utils.h new file mode 100644 index 00000000..38253f8f --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/quantized_utils.h @@ -0,0 +1,90 @@ +// Copyright © 2023-2024 Apple Inc. + +#include +#include + +template +METAL_FUNC void gemm_loop_aligned( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const int k_iterations) { + for (int k = 0; k < k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup memory + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } +} + +template < + bool rows_aligned, + bool cols_aligned, + bool transpose, + typename T, + typename mma_t, + typename loader_a_t, + typename loader_b_t> +METAL_FUNC void gemm_loop_unaligned( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const int k_iterations, + const short tgp_bm, + const short tgp_bn, + const short tgp_bk) { + for (int k = 0; k < k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup memory + if (rows_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(short2(tgp_bk, tgp_bm)); + } + if (cols_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe( + transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } +} + +template +METAL_FUNC void gemm_loop_finalize( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const short2 tile_a, + const short2 tile_b) { + loader_a.load_safe(tile_a); + loader_b.load_safe(tile_b); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); +} diff --git a/Source/Cmlx/mlx-generated/metal/random.metal b/Source/Cxxmlx/mlx-generated/metal/random.metal similarity index 100% rename from Source/Cmlx/mlx-generated/metal/random.metal rename to Source/Cxxmlx/mlx-generated/metal/random.metal diff --git a/Source/Cmlx/mlx-generated/metal/reduce.h b/Source/Cxxmlx/mlx-generated/metal/reduce.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/reduce.h rename to Source/Cxxmlx/mlx-generated/metal/reduce.h diff --git a/Source/Cmlx/mlx-generated/metal/reduce_utils.h b/Source/Cxxmlx/mlx-generated/metal/reduce_utils.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/reduce_utils.h rename to Source/Cxxmlx/mlx-generated/metal/reduce_utils.h diff --git a/Source/Cxxmlx/mlx-generated/metal/reduction/ops.h b/Source/Cxxmlx/mlx-generated/metal/reduction/ops.h new file mode 100644 index 00000000..11d8e83a --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/reduction/ops.h @@ -0,0 +1,275 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +#define DEFINE_SIMD_REDUCE() \ + template = true> \ + T simd_reduce(T val) { \ + return simd_reduce_impl(val); \ + } \ + \ + template = true> \ + T simd_reduce(T val) { \ + for (short i = simd_size / 2; i > 0; i /= 2) { \ + val = operator()(val, simd_shuffle_down(val, i)); \ + } \ + return val; \ + } + +static constant constexpr const uint8_t simd_size = 32; + +union bool4_or_uint { + bool4 b; + unsigned int i; +}; + +struct None { + template + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { + mlx_atomic_store_explicit(out, val, offset); + } +}; + +template +struct And { + DEFINE_SIMD_REDUCE() + + bool simd_reduce_impl(bool val) { + return simd_all(val); + } + + static constexpr constant bool init = true; + + void atomic_update( + device mlx_atomic* out, + bool val, + int elem_idx, + size_t offset = 0) { + if (!val) { + bool4_or_uint update; + update.b = {true, true, true, true}; + update.b[elem_idx] = false; + mlx_atomic_fetch_and_explicit(out, update.i, offset); + } + } + + void + atomic_update(device mlx_atomic* out, bool val, size_t offset = 0) { + if (!val) { + mlx_atomic_store_explicit(out, val, offset); + } + } + + // Non atomic update + void update(device bool* out, bool val) { + *out &= val; + } + + // Operator + bool operator()(bool a, bool b) { + return a && b; + } +}; + +template +struct Or { + DEFINE_SIMD_REDUCE() + + bool simd_reduce_impl(bool val) { + return simd_any(val); + } + + static constexpr constant bool init = false; + + void atomic_update( + device mlx_atomic* out, + bool val, + int elem_idx, + size_t offset = 0) { + if (val) { + bool4_or_uint update; + update.b = {false, false, false, false}; + update.b[elem_idx] = true; + mlx_atomic_fetch_or_explicit(out, update.i, offset); + } + } + + void + atomic_update(device mlx_atomic* out, bool val, size_t offset = 0) { + if (val) { + mlx_atomic_store_explicit(out, val, offset); + } + } + + // Non atomic update + void update(device bool* out, bool val) { + *out |= val; + } + + // Operator + bool operator()(bool a, bool b) { + return a || b; + } +}; + +template +struct Sum { + DEFINE_SIMD_REDUCE() + + template + T simd_reduce_impl(T val) { + return simd_sum(val); + } + + static constexpr constant U init = U(0); + + template + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { + mlx_atomic_fetch_add_explicit(out, val, offset); + } + + // Operator + U operator()(U a, U b) { + return a + b; + } +}; + +template +struct Prod { + DEFINE_SIMD_REDUCE() + + template + T simd_reduce_impl(T val) { + return simd_product(val); + } + + static constexpr constant U init = U(1); + + template + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { + mlx_atomic_fetch_mul_explicit(out, val, offset); + } + + // Operator + U operator()(U a, U b) { + return a * b; + } +}; + +template +struct Min { + DEFINE_SIMD_REDUCE() + + template + metal::enable_if_t, T> simd_reduce_impl(T val) { + return simd_min(val); + } + + template + metal::enable_if_t, T> simd_reduce_impl(T val) { + if (simd_any(val != val)) { + return static_cast(NAN); + } + return simd_min(val); + } + + static constexpr constant U init = Limits::max; + + template + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { + mlx_atomic_fetch_min_explicit(out, val, offset); + } + + // Operator + template + metal::enable_if_t, T> operator()(T a, T b) { + return a < b ? a : b; + } + + template + metal::enable_if_t, T> operator()(T a, T b) { + if (metal::isnan(a) || metal::isnan(b)) { + return static_cast(NAN); + } else { + return a < b ? a : b; + } + } + + template <> + complex64_t operator()(complex64_t a, complex64_t b) { + bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real); + bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag); + + if (!real_is_nan && !imag_is_nan) { + return a < b ? a : b; + } else if (real_is_nan && !imag_is_nan) { + return complex64_t( + static_cast(NAN), a.imag < b.imag ? a.imag : b.imag); + } else if (!real_is_nan && imag_is_nan) { + return complex64_t( + a.real < b.real ? a.real : b.real, static_cast(NAN)); + } else { + return complex64_t(static_cast(NAN), static_cast(NAN)); + } + }; +}; +template +struct Max { + DEFINE_SIMD_REDUCE() + + template + metal::enable_if_t, T> simd_reduce_impl(T val) { + return simd_max(val); + } + + template + metal::enable_if_t, T> simd_reduce_impl(T val) { + if (simd_any(val != val)) { + return static_cast(NAN); + } + return simd_max(val); + } + + static constexpr constant U init = Limits::min; + + template + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { + mlx_atomic_fetch_max_explicit(out, val, offset); + } + + // Operator + template + metal::enable_if_t, T> operator()(T a, T b) { + return a > b ? a : b; + } + + template + metal::enable_if_t, T> operator()(T a, T b) { + if (metal::isnan(a) || metal::isnan(b)) { + return static_cast(NAN); + } else { + return a > b ? a : b; + } + } + + template <> + complex64_t operator()(complex64_t a, complex64_t b) { + bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real); + bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag); + + if (!real_is_nan && !imag_is_nan) { + return a > b ? a : b; + } else if (real_is_nan && !imag_is_nan) { + return complex64_t( + static_cast(NAN), a.imag > b.imag ? a.imag : b.imag); + } else if (!real_is_nan && imag_is_nan) { + return complex64_t( + a.real > b.real ? a.real : b.real, static_cast(NAN)); + } else { + return complex64_t(static_cast(NAN), static_cast(NAN)); + } + } +}; diff --git a/Source/Cxxmlx/mlx-generated/metal/reduction/reduce_all.h b/Source/Cxxmlx/mlx-generated/metal/reduction/reduce_all.h new file mode 100644 index 00000000..e0d08392 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/reduction/reduce_all.h @@ -0,0 +1,66 @@ +// Copyright © 2023-2024 Apple Inc. + +template < + typename T, + typename U, + typename Op, + typename IdxT = int64_t, + int N_READS = REDUCE_N_READS> +[[kernel]] void all_reduce( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& in_size [[buffer(2)]], + const constant size_t& row_size [[buffer(3)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + Op op; + threadgroup U shared_vals[simd_size]; + + U total = Op::init; + IdxT start_idx = gid.y * IdxT(row_size); + IdxT actual_row = + (start_idx + row_size <= in_size) ? row_size : in_size - start_idx; + IdxT blocks = actual_row / (lsize.x * N_READS); + int extra = actual_row - blocks * (lsize.x * N_READS); + extra -= lid.x * N_READS; + start_idx += lid.x * N_READS; + in += start_idx; + + if (extra >= N_READS) { + blocks++; + extra = 0; + } + + for (IdxT b = 0; b < blocks; b++) { + for (int i = 0; i < N_READS; i++) { + total = op(static_cast(in[i]), total); + } + in += lsize.x * N_READS; + } + if (extra > 0) { + for (int i = 0; i < extra; i++) { + total = op(static_cast(in[i]), total); + } + } + + // Reduction within simd group + total = op.simd_reduce(total); + if (simd_per_group > 1) { + if (simd_lane_id == 0) { + shared_vals[simd_group_id] = total; + } + + // Reduction within thread group + threadgroup_barrier(mem_flags::mem_threadgroup); + total = lid.x < simd_per_group ? shared_vals[lid.x] : op.init; + total = op.simd_reduce(total); + } + + if (lid.x == 0) { + out[gid.y] = total; + } +} diff --git a/Source/Cxxmlx/mlx-generated/metal/reduction/reduce_col.h b/Source/Cxxmlx/mlx-generated/metal/reduction/reduce_col.h new file mode 100644 index 00000000..c109faf0 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/reduction/reduce_col.h @@ -0,0 +1,398 @@ +// Copyright © 2023-2024 Apple Inc. + +template +[[kernel]] void col_reduce_small( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant int64_t& reduction_stride [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant int64_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant int64_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + const constant size_t& non_col_reductions [[buffer(10)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]]) { + constexpr int n_reads = 4; + Op op; + LoopedElemToLoc 2)> loop(reduce_ndim); + const device T* row; + + U totals[n_reads]; + for (int i = 0; i < n_reads; i++) { + totals[i] = Op::init; + } + + IdxT column = IdxT(gid.x) * lsize.x * n_reads + lid.x * n_reads; + if (column >= reduction_stride) { + return; + } + bool safe = column + n_reads <= reduction_stride; + + IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); + IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); + in += in_idx + column; + + IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size); + loop.next(lid.y, reduce_shape, reduce_strides); + for (IdxT r = lid.y; r < total_rows; r += lsize.y) { + row = in + loop.location(); + if (safe) { + for (int i = 0; i < n_reads; i++) { + totals[i] = op(static_cast(row[i]), totals[i]); + } + } else { + U vals[n_reads]; + for (int i = 0; i < n_reads; i++) { + vals[i] = + (column + i < reduction_stride) ? static_cast(row[i]) : op.init; + } + for (int i = 0; i < n_reads; i++) { + totals[i] = op(vals[i], totals[i]); + } + } + loop.next(lsize.y, reduce_shape, reduce_strides); + } + + if (lsize.y > 1) { + // lsize.y should be <= 8 + threadgroup U shared_vals[32 * 8 * n_reads]; + for (int i = 0; i < n_reads; i++) { + shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (lid.y == 0) { + for (int i = 0; i < n_reads; i++) { + totals[i] = shared_vals[lid.x * n_reads + i]; + } + for (uint j = 1; j < lsize.y; j++) { + for (int i = 0; i < n_reads; i++) { + totals[i] = + op(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i], + totals[i]); + } + } + } + } + + if (lid.y == 0) { + out += out_idx * IdxT(reduction_stride) + column; + if (safe) { + for (int i = 0; i < n_reads; i++) { + out[i] = totals[i]; + } + } else { + for (int i = 0; column + i < reduction_stride; i++) { + out[i] = totals[i]; + } + } + } +} + +template +[[kernel]] void col_reduce_longcolumn( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& reduction_stride [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant int64_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant int64_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + const constant size_t& non_col_reductions [[buffer(10)]], + const constant size_t& out_size [[buffer(11)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]]) { + Op op; + LoopedElemToLoc 2)> loop(reduce_ndim); + const device T* row; + + IdxT out_idx = gid.x + gsize.x * IdxT(gid.y); + IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); + in += in_idx + lid.x; + + U total = Op::init; + IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size); + loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides); + for (IdxT r = gid.z * lsize.y + lid.y; r < total_rows; + r += lsize.y * gsize.z) { + row = in + loop.location(); + total = op(static_cast(*row), total); + loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides); + } + + threadgroup U shared_vals[32 * 32]; + shared_vals[lid.y * lsize.x + lid.x] = total; + threadgroup_barrier(mem_flags::mem_threadgroup); + if (lid.y == 0) { + for (uint i = 1; i < lsize.y; i++) { + total = op(total, shared_vals[i * lsize.x + lid.x]); + } + out[gid.z * IdxT(out_size) + out_idx * IdxT(reduction_stride) + lid.x] = + total; + } +} + +/** + * Our approach is the following simple looped approach: + * 1. Each thread keeps running totals for BN / n_simdgroups outputs. + * 2. Load a tile BM, BN in registers and accumulate in the running totals + * 3. Move ahead by BM steps until the column axis and the non column + * reductions are exhausted. + * 6. If BM == 32 then transpose in SM and simd reduce the running totals. + * Otherwise write in shared memory and BN threads accumulate the running + * totals with a loop. + * 7. Write them to the output + */ +template < + typename T, + typename U, + typename Op, + typename IdxT, + int NDIMS, + int BM, + int BN> +[[kernel]] void col_reduce_looped( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant int64_t& reduction_stride [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant int64_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant int64_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + const constant size_t& non_col_reductions [[buffer(10)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + Op op; + constexpr int n_simdgroups = 8; + constexpr short tgp_size = n_simdgroups * simd_size; + constexpr short n_reads = (BM * BN) / tgp_size; + constexpr short n_read_blocks = BN / n_reads; + + threadgroup U shared_vals[BN * BM]; + U totals[n_reads]; + LoopedElemToLoc 2)> loop(reduce_ndim); + const device T* row; + + for (int i = 0; i < n_reads; i++) { + totals[i] = Op::init; + } + + short lid = simd_group_id * simd_size + simd_lane_id; + short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks); + IdxT column = BN * gid.x + offset.x; + bool safe = column + n_reads <= reduction_stride; + + IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); + IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); + in += in_idx + column; + + IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size); + loop.next(offset.y, reduce_shape, reduce_strides); + for (IdxT r = offset.y; r < total; r += BM) { + row = in + loop.location(); + + if (safe) { + for (int i = 0; i < n_reads; i++) { + totals[i] = op(static_cast(row[i]), totals[i]); + } + } else { + U vals[n_reads]; + for (int i = 0; i < n_reads; i++) { + vals[i] = + (column + i < reduction_stride) ? static_cast(row[i]) : op.init; + } + for (int i = 0; i < n_reads; i++) { + totals[i] = op(vals[i], totals[i]); + } + } + + loop.next(BM, reduce_shape, reduce_strides); + } + + // We can use a simd reduction to accumulate across BM so each thread writes + // the partial output to SM and then each simdgroup does BN / n_simdgroups + // accumulations. + if (BM == 32) { + constexpr int n_outputs = BN / n_simdgroups; + static_assert( + BM != 32 || n_outputs == n_reads, + "The tile should be selected such that n_outputs == n_reads"); + for (int i = 0; i < n_reads; i++) { + shared_vals[offset.y * BN + offset.x + i] = totals[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + short2 out_offset(simd_group_id * n_outputs, simd_lane_id); + for (int i = 0; i < n_outputs; i++) { + totals[i] = + op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]); + } + + // Write the output. + if (simd_lane_id == 0) { + IdxT out_column = BN * gid.x + out_offset.x; + out += out_idx * IdxT(reduction_stride) + out_column; + if (out_column + n_outputs <= reduction_stride) { + for (int i = 0; i < n_outputs; i++) { + out[i] = totals[i]; + } + } else { + for (int i = 0; out_column + i < reduction_stride; i++) { + out[i] = totals[i]; + } + } + } + } + + // Each thread holds n_reads partial results. We write them all out to shared + // memory and threads with offset.y == 0 aggregate the columns and write the + // outputs. + else { + short x_block = offset.x / n_reads; + for (int i = 0; i < n_reads; i++) { + shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (offset.y == 0) { + for (int i = 0; i < n_reads; i++) { + for (int j = 1; j < BM; j++) { + totals[i] = + op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]); + } + } + } + + // Write the output. + if (offset.y == 0) { + out += out_idx * IdxT(reduction_stride) + column; + if (safe) { + for (int i = 0; i < n_reads; i++) { + out[i] = totals[i]; + } + } else { + for (int i = 0; column + i < reduction_stride; i++) { + out[i] = totals[i]; + } + } + } + } +} + +template < + typename T, + typename U, + typename Op, + typename IdxT, + int NDIMS, + int BM, + int BN> +[[kernel]] void col_reduce_2pass( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant int64_t& reduction_stride [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant int64_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant int64_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + const constant size_t& non_col_reductions [[buffer(10)]], + const constant size_t& out_size [[buffer(11)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + Op op; + constexpr int n_simdgroups = 8; + constexpr short tgp_size = n_simdgroups * simd_size; + constexpr short n_reads = (BM * BN) / tgp_size; + constexpr short n_read_blocks = BN / n_reads; + constexpr int n_outputs = BN / n_simdgroups; + constexpr short outer_blocks = 32; + static_assert(BM == 32, "BM should be equal to 32"); + + threadgroup U shared_vals[BN * BM]; + U totals[n_reads]; + LoopedElemToLoc 2)> loop(reduce_ndim); + const device T* row; + + for (int i = 0; i < n_reads; i++) { + totals[i] = Op::init; + } + + short lid = simd_group_id * simd_size + simd_lane_id; + short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks); + IdxT column = BN * gid.x + offset.x; + bool safe = column + n_reads <= reduction_stride; + + IdxT full_idx = gid.y + gsize.y * IdxT(gid.z); + IdxT block_idx = full_idx / IdxT(out_size); + IdxT out_idx = full_idx % IdxT(out_size); + IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); + in += in_idx + column; + + IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size); + loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides); + for (IdxT r = offset.y + block_idx * BM; r < total; r += outer_blocks * BM) { + row = in + loop.location(); + + if (safe) { + for (int i = 0; i < n_reads; i++) { + totals[i] = op(static_cast(row[i]), totals[i]); + } + } else { + U vals[n_reads]; + for (int i = 0; i < n_reads; i++) { + vals[i] = + (column + i < reduction_stride) ? static_cast(row[i]) : op.init; + } + for (int i = 0; i < n_reads; i++) { + totals[i] = op(vals[i], totals[i]); + } + } + + loop.next(outer_blocks * BM, reduce_shape, reduce_strides); + } + + // We can use a simd reduction to accumulate across BM so each thread writes + // the partial output to SM and then each simdgroup does BN / n_simdgroups + // accumulations. + for (int i = 0; i < n_reads; i++) { + shared_vals[offset.y * BN + offset.x + i] = totals[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + short2 out_offset(simd_group_id * n_outputs, simd_lane_id); + for (int i = 0; i < n_outputs; i++) { + totals[i] = + op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]); + } + + // Write the output. + if (simd_lane_id == 0) { + IdxT out_column = BN * gid.x + out_offset.x; + out += full_idx * IdxT(reduction_stride) + out_column; + if (out_column + n_outputs <= reduction_stride) { + for (int i = 0; i < n_outputs; i++) { + out[i] = totals[i]; + } + } else { + for (int i = 0; out_column + i < reduction_stride; i++) { + out[i] = totals[i]; + } + } + } +} diff --git a/Source/Cxxmlx/mlx-generated/metal/reduction/reduce_init.h b/Source/Cxxmlx/mlx-generated/metal/reduction/reduce_init.h new file mode 100644 index 00000000..604efa78 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/reduction/reduce_init.h @@ -0,0 +1,8 @@ +// Copyright © 2023-2024 Apple Inc. + +template +[[kernel]] void init_reduce( + device T* out [[buffer(0)]], + uint tid [[thread_position_in_grid]]) { + out[tid] = Op::init; +} diff --git a/Source/Cxxmlx/mlx-generated/metal/reduction/reduce_row.h b/Source/Cxxmlx/mlx-generated/metal/reduction/reduce_row.h new file mode 100644 index 00000000..936d75bb --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/reduction/reduce_row.h @@ -0,0 +1,369 @@ +// Copyright © 2023-2024 Apple Inc. + +// Row reduction utilities +// - `per_thread_row_reduce` collaborative partial reduction in the threadgroup +// - `threadgroup_reduce` collaborative reduction in the threadgroup such that +// lid.x == 0 holds the reduced value +// - `thread_reduce` simple loop and reduce the row + +/** + * The thread group collaboratively reduces across the rows with bounds + * checking. In the end each thread holds a part of the reduction. + */ +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +METAL_FUNC void per_thread_row_reduce( + thread U totals[N_WRITES], + const device T* inputs[N_WRITES], + int blocks, + int extra, + uint lsize_x, + uint lid_x) { + Op op; + + // Set up the accumulator registers + for (int i = 0; i < N_WRITES; i++) { + totals[i] = Op::init; + } + + // Loop over the reduction size within thread group + for (int i = 0; i < blocks; i++) { + for (int j = 0; j < N_WRITES; j++) { + for (int i = 0; i < N_READS; i++) { + totals[j] = op(static_cast(inputs[j][i]), totals[j]); + } + + inputs[j] += lsize_x * N_READS; + } + } + + // Separate case for the last set as we close the reduction size + int index = lid_x * N_READS; + if (index + N_READS <= extra) { + for (int j = 0; j < N_WRITES; j++) { + for (int i = 0; i < N_READS; i++) { + totals[j] = op(static_cast(inputs[j][i]), totals[j]); + } + } + } else { + for (int j = 0; j < N_WRITES; j++) { + for (int i = 0; index + i < extra; i++) { + totals[j] = op(static_cast(inputs[j][i]), totals[j]); + } + } + } +} + +/** + * Consecutive rows in a contiguous array. + */ +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +METAL_FUNC void per_thread_row_reduce( + thread U totals[N_WRITES], + const device T* in, + const constant size_t& reduction_size, + int blocks, + int extra, + uint lsize_x, + uint lid_x) { + // Set up the input pointers + const device T* inputs[N_WRITES]; + inputs[0] = in + lid_x * N_READS; + for (int i = 1; i < N_READS; i++) { + inputs[i] = inputs[i - 1] + reduction_size; + } + + per_thread_row_reduce( + totals, inputs, blocks, extra, lsize_x, lid_x); +} + +/** + * Consecutive rows in an arbitrarily ordered array. + */ +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +METAL_FUNC void per_thread_row_reduce( + thread U totals[N_WRITES], + const device T* in, + const int64_t row_idx, + int blocks, + int extra, + const constant int* shape, + const constant int64_t* strides, + const constant int& ndim, + uint lsize_x, + uint lid_x) { + // Set up the input pointers + const device T* inputs[N_WRITES]; + in += lid_x * N_READS; + for (int i = 0; i < N_READS; i++) { + inputs[i] = in + elem_to_loc(row_idx + i, shape, strides, ndim); + } + + per_thread_row_reduce( + totals, inputs, blocks, extra, lsize_x, lid_x); +} + +/** + * Reduce within the threadgroup. + */ +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +METAL_FUNC void threadgroup_reduce( + thread U totals[N_WRITES], + threadgroup U* shared_vals, + uint3 lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + Op op; + + // Simdgroup first + for (int i = 0; i < N_WRITES; i++) { + totals[i] = op.simd_reduce(totals[i]); + } + + // Across simdgroups + if (simd_per_group > 1) { + if (simd_lane_id == 0) { + for (int i = 0; i < N_WRITES; i++) { + shared_vals[simd_group_id * N_WRITES + i] = totals[i]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + U values[N_WRITES]; + for (int i = 0; i < N_WRITES; i++) { + values[i] = (lid.x < simd_per_group) ? shared_vals[lid.x * N_WRITES + i] + : op.init; + } + + for (int i = 0; i < N_WRITES; i++) { + totals[i] = op.simd_reduce(values[i]); + } + } +} + +template +METAL_FUNC void +thread_reduce(thread U& total, const device T* row, int blocks, int extra) { + Op op; + for (int i = 0; i < blocks; i++) { + U vals[N_READS]; + for (int j = 0; j < N_READS; j++) { + vals[j] = row[j]; + } + for (int j = 0; j < N_READS; j++) { + total = op(vals[j], total); + } + row += N_READS; + } + for (int i = 0; i < extra; i++) { + total = op(*row++, total); + } +} + +// Reduction kernels +// - `row_reduce_small` depending on the non-row reductions and row size it +// either just loops over everything or a simd collaboratively reduces the +// non_row reductions. In the first case one thread is responsible for one +// output on the 2nd one simd is responsible for one output. +// - `row_reduce_simple` simple contiguous row reduction +// - `row_reduce_looped` simply loop and reduce each row for each non-row +// reduction. One threadgroup is responsible for one output. + +template < + typename T, + typename U, + typename Op, + typename IdxT, + int NDIMS, + int N_READS = REDUCE_N_READS> +[[kernel]] void row_reduce_small( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant int64_t& row_size [[buffer(2)]], + const constant int64_t& non_row_reductions [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant int64_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant int64_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 tid [[thread_position_in_grid]], + uint3 tsize [[threads_per_grid]]) { + Op op; + + U total_val = Op::init; + LoopedElemToLoc 2)> loop(reduce_ndim); + + // Precompute some row reduction numbers + const device T* row; + int blocks = IdxT(row_size) / N_READS; + int extra = IdxT(row_size) % N_READS; + + if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) { + // Simple loop over non_row_reductions and reduce the row in the thread. + IdxT out_idx = tid.x + tsize.x * IdxT(tid.y); + in += elem_to_loc(out_idx, shape, strides, ndim); + + for (uint r = 0; r < non_row_reductions; r++) { + row = in + loop.location(); + thread_reduce(total_val, row, blocks, extra); + loop.next(reduce_shape, reduce_strides); + } + + out[out_idx] = total_val; + } else { + // Collaboratively reduce over non_row_reductions in the simdgroup. Each + // thread reduces every 32nd row and then a simple simd reduce. + IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); + in += elem_to_loc(out_idx, shape, strides, ndim); + + loop.next(simd_lane_id, reduce_shape, reduce_strides); + + for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) { + row = in + loop.location(); + thread_reduce(total_val, row, blocks, extra); + loop.next(simd_size, reduce_shape, reduce_strides); + } + + total_val = op.simd_reduce(total_val); + + if (simd_lane_id == 0) { + out[out_idx] = total_val; + } + } +} + +template < + typename T, + typename U, + typename Op, + typename IdxT = int64_t, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +[[kernel]] void row_reduce_simple( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant int64_t& out_size [[buffer(3)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + threadgroup U shared_vals[simd_size * N_WRITES]; + U totals[N_WRITES]; + + // Move to the row + IdxT out_idx = N_WRITES * (gid.y + gsize.y * IdxT(gid.z)); + if (out_idx + N_WRITES > out_size) { + out_idx = out_size - N_WRITES; + } + in += out_idx * IdxT(reduction_size); + out += out_idx; + + // Each thread reduces across the row + int blocks = IdxT(reduction_size) / (lsize.x * N_READS); + int extra = reduction_size - blocks * (lsize.x * N_READS); + per_thread_row_reduce( + totals, in, reduction_size, blocks, extra, lsize.x, lid.x); + + // Reduce across the threadgroup + threadgroup_reduce( + totals, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id); + + // Write the output + if (lid.x == 0) { + for (int i = 0; i < N_WRITES; i++) { + out[i] = totals[i]; + } + } +} + +template < + typename T, + typename U, + typename Op, + typename IdxT, + int NDIMS, + int N_READS = REDUCE_N_READS> +[[kernel]] void row_reduce_looped( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant int64_t& row_size [[buffer(2)]], + const constant int64_t& non_row_reductions [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant int64_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant int64_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + Op op; + threadgroup U shared_vals[simd_size]; + U total = Op::init; + + IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); + + // lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it + // needs a small refactor. + in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS; + + LoopedElemToLoc 2)> loop(reduce_ndim); + const device T* row; + int blocks = IdxT(row_size) / (lsize.x * N_READS); + int extra = row_size - blocks * (lsize.x * N_READS); + + for (IdxT i = 0; i < non_row_reductions; i++) { + row = in + loop.location(); + + // Each thread reduces across the row + U row_total; + per_thread_row_reduce( + &row_total, &row, blocks, extra, lsize.x, lid.x); + + // Aggregate across rows + total = op(total, row_total); + + loop.next(reduce_shape, reduce_strides); + } + + // Reduce across the threadgroup + threadgroup_reduce( + &total, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id); + + // Write the output + if (lid.x == 0) { + out[out_idx] = total; + } +} diff --git a/Source/Cmlx/mlx-generated/metal/rms_norm.metal b/Source/Cxxmlx/mlx-generated/metal/rms_norm.metal similarity index 100% rename from Source/Cmlx/mlx-generated/metal/rms_norm.metal rename to Source/Cxxmlx/mlx-generated/metal/rms_norm.metal diff --git a/Source/Cmlx/mlx-generated/metal/rope.metal b/Source/Cxxmlx/mlx-generated/metal/rope.metal similarity index 100% rename from Source/Cmlx/mlx-generated/metal/rope.metal rename to Source/Cxxmlx/mlx-generated/metal/rope.metal diff --git a/Source/Cmlx/mlx-generated/metal/scaled_dot_product_attention.metal b/Source/Cxxmlx/mlx-generated/metal/scaled_dot_product_attention.metal similarity index 100% rename from Source/Cmlx/mlx-generated/metal/scaled_dot_product_attention.metal rename to Source/Cxxmlx/mlx-generated/metal/scaled_dot_product_attention.metal diff --git a/Source/Cmlx/mlx-generated/metal/scan.h b/Source/Cxxmlx/mlx-generated/metal/scan.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/scan.h rename to Source/Cxxmlx/mlx-generated/metal/scan.h diff --git a/Source/Cxxmlx/mlx-generated/metal/sdpa_vector.h b/Source/Cxxmlx/mlx-generated/metal/sdpa_vector.h new file mode 100644 index 00000000..1eec72be --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/sdpa_vector.h @@ -0,0 +1,394 @@ +// Copyright © 2024 Apple Inc. + +#include + +using namespace metal; + +constant bool has_mask [[function_constant(20)]]; +constant bool query_transposed [[function_constant(21)]]; +constant bool do_causal [[function_constant(22)]]; +constant bool bool_mask [[function_constant(23)]]; +constant bool float_mask [[function_constant(24)]]; +constant bool has_sinks [[function_constant(25)]]; +constant int blocks [[function_constant(26)]]; + +template +[[kernel]] void sdpa_vector( + const device T* queries [[buffer(0)]], + const device T* keys [[buffer(1)]], + const device T* values [[buffer(2)]], + device T* out [[buffer(3)]], + const constant int& gqa_factor [[buffer(4)]], + const constant int& N [[buffer(5)]], + const constant size_t& k_head_stride [[buffer(6)]], + const constant size_t& k_seq_stride [[buffer(7)]], + const constant size_t& v_head_stride [[buffer(8)]], + const constant size_t& v_seq_stride [[buffer(9)]], + const constant float& scale [[buffer(10)]], + const device bool* bmask [[buffer(11), function_constant(bool_mask)]], + const device T* fmask [[buffer(12), function_constant(float_mask)]], + const constant int& mask_kv_seq_stride + [[buffer(13), function_constant(has_mask)]], + const constant int& mask_q_seq_stride + [[buffer(14), function_constant(has_mask)]], + const constant int& mask_head_stride + [[buffer(15), function_constant(has_mask)]], + const device T* sinks [[buffer(16), function_constant(has_sinks)]], + const constant int& num_q_heads + [[buffer(17), function_constant(has_sinks)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 tpg [[threadgroups_per_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int BN = 32; + constexpr int BD = 32; + constexpr int qk_per_thread = D / BD; + constexpr int v_per_thread = V / BD; + int inner_k_stride = BN * int(k_seq_stride); + int inner_v_stride = BN * int(v_seq_stride); + + typedef float U; + + thread U q[qk_per_thread]; + thread U k[qk_per_thread]; + thread U o[v_per_thread]; + + threadgroup U outputs[BN * BD]; + threadgroup U max_scores[BN]; + threadgroup U sum_exp_scores[BN]; + + // Adjust positions + const int q_batch_head_idx = tid.x; + const int q_seq_idx = tid.y; + const int kv_head_idx = q_batch_head_idx / gqa_factor; + const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx; + const int q_offset = + query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset; + queries += q_offset * D + simd_lid * qk_per_thread; + keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride + + simd_lid * qk_per_thread; + values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride + + simd_lid * v_per_thread; + if (bool_mask) { + bmask += q_batch_head_idx * mask_head_stride + + simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; + } + if (float_mask) { + fmask += q_batch_head_idx * mask_head_stride + + simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; + } + + out += o_offset * V + simd_gid * v_per_thread; + + // Read the query and 0 the output accumulator + for (int i = 0; i < qk_per_thread; i++) { + q[i] = static_cast(scale) * queries[i]; + } + for (int i = 0; i < v_per_thread; i++) { + o[i] = 0; + } + + U max_score = Limits::finite_min; + U sum_exp_score = 0; + if (has_sinks && simd_gid == 0) { + max_score = static_cast(sinks[q_batch_head_idx % num_q_heads]); + sum_exp_score = 1; + } + + // For each key + for (int i = simd_gid; i < N; i += BN) { + bool use_key = true; + if (do_causal) { + use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); + } else if (bool_mask) { + use_key = bmask[0]; + } else if (float_mask) { + use_key = (fmask[0] >= Limits::finite_min); + } + if (use_key) { + // Read the key + for (int j = 0; j < qk_per_thread; j++) { + k[j] = keys[j]; + } + + // Compute the i-th score + U score = 0; + for (int j = 0; j < qk_per_thread; j++) { + score += q[j] * k[j]; + } + score = simd_sum(score); + if (float_mask) { + score += static_cast(fmask[0]); + } + + // Update the accumulators + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + + // Update the output accumulator + for (int j = 0; j < v_per_thread; j++) { + o[j] = o[j] * factor + exp_score * values[j]; + } + } + + // Move the pointers to the next kv + keys += inner_k_stride; + values += inner_v_stride; + if (bool_mask) { + bmask += BN * mask_kv_seq_stride; + } + if (float_mask) { + fmask += BN * mask_kv_seq_stride; + } + } + + // Each thread has a partial part of the output so we need to combine them. + + // First let's communicate the max and sum_exp + if (simd_lid == 0) { + max_scores[simd_gid] = max_score; + sum_exp_scores[simd_gid] = sum_exp_score; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + max_score = max_scores[simd_lid]; + U new_max = simd_max(max_score); + U factor = fast::exp(max_score - new_max); + sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor); + + // Now we need to aggregate all the outputs + for (int i = 0; i < v_per_thread; i++) { + outputs[simd_lid * BD + simd_gid] = o[i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor); + o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // And write the output + if (simd_lid == 0) { + for (int i = 0; i < v_per_thread; i++) { + out[i] = static_cast(o[i]); + } + } +} + +template +[[kernel]] void sdpa_vector_2pass_1( + const device T* queries [[buffer(0)]], + const device T* keys [[buffer(1)]], + const device T* values [[buffer(2)]], + device T* out [[buffer(3)]], + device float* sums [[buffer(4)]], + device float* maxs [[buffer(5)]], + const constant int& N [[buffer(7)]], + const constant size_t& k_head_stride [[buffer(8)]], + const constant size_t& k_seq_stride [[buffer(9)]], + const constant size_t& v_head_stride [[buffer(10)]], + const constant size_t& v_seq_stride [[buffer(11)]], + const constant float& scale [[buffer(12)]], + const device bool* bmask [[buffer(13), function_constant(bool_mask)]], + const device T* fmask [[buffer(14), function_constant(float_mask)]], + const constant int& mask_kv_seq_stride + [[buffer(15), function_constant(has_mask)]], + const constant int& mask_q_seq_stride + [[buffer(16), function_constant(has_mask)]], + const constant int& mask_head_stride + [[buffer(17), function_constant(has_mask)]], + const device T* sinks [[buffer(18), function_constant(has_sinks)]], + uint3 tptg [[threads_per_threadgroup]], + uint3 tidtg [[thread_position_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 tpg [[threadgroups_per_grid]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int BD = 32; + constexpr int qk_per_thread = D / BD; + constexpr int v_per_thread = V / BD; + + typedef float U; + + thread U q[qk_per_thread]; + thread U o[v_per_thread] = {0}; + + // Adjust positions + const int kv_head_idx = tid.x; + const int batch_idx = tid.y; + const int block_idx = tid.z; + const int gqa_factor = tptg.y; + const int q_seq_len = tptg.z; + const int q_seq_idx = tidtg.z; + const int q_head_idx = gqa_factor * kv_head_idx + tidtg.y; + const int num_kv_heads = tpg.x; + const int num_q_heads = num_kv_heads * gqa_factor; + const int q_batch_head_idx = (batch_idx * num_q_heads + q_head_idx); + const int o_offset = q_batch_head_idx * q_seq_len + q_seq_idx; + const int q_offset = + query_transposed ? num_q_heads * q_seq_idx + q_batch_head_idx : o_offset; + + queries += q_offset * D + simd_lid * qk_per_thread; + + const int kv_batch_head_idx = batch_idx * num_kv_heads + kv_head_idx; + keys += kv_batch_head_idx * k_head_stride + block_idx * k_seq_stride + + simd_lid * qk_per_thread; + values += kv_batch_head_idx * v_head_stride + block_idx * v_seq_stride + + simd_lid * v_per_thread; + out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread; + if (bool_mask) { + bmask += q_batch_head_idx * mask_head_stride + + block_idx * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; + } + if (float_mask) { + fmask += q_batch_head_idx * mask_head_stride + + block_idx * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; + } + sums += o_offset * blocks + block_idx; + maxs += o_offset * blocks + block_idx; + + // Read the query + for (int i = 0; i < qk_per_thread; i++) { + q[i] = static_cast(scale) * queries[i]; + } + + U max_score = Limits::finite_min; + U sum_exp_score = 0; + if (has_sinks && block_idx == 0) { + max_score = static_cast(sinks[q_head_idx]); + sum_exp_score = 1; + } + + // For each key + for (int i = block_idx; i < N; i += blocks) { + bool use_key = true; + if (do_causal) { + use_key = i <= (N - q_seq_len + int(q_seq_idx)); + } else if (bool_mask) { + use_key = bmask[0]; + } else if (float_mask) { + use_key = (fmask[0] >= Limits::finite_min); + } + if (use_key) { + // Compute the i-th score + U score = 0; + for (int i = 0; i < qk_per_thread; i++) { + score += q[i] * keys[i]; + } + score = simd_sum(score); + + if (float_mask) { + score += fmask[0]; + } + + // Update the accumulators + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + + // Update the output accumulator + for (int i = 0; i < v_per_thread; i++) { + o[i] = o[i] * factor + exp_score * values[i]; + } + } + + // Move the pointers to the next kv + keys += blocks * int(k_seq_stride); + values += blocks * int(v_seq_stride); + if (bool_mask) { + bmask += blocks * mask_kv_seq_stride; + } + if (float_mask) { + fmask += blocks * mask_kv_seq_stride; + } + } + + // Write the sum and max and outputs + if (simd_lid == 0) { + sums[0] = sum_exp_score; + maxs[0] = max_score; + } + + for (int i = 0; i < v_per_thread; i++) { + out[i] = static_cast(o[i]); + } +} + +template +[[kernel]] void sdpa_vector_2pass_2( + const device T* partials [[buffer(0)]], + const device float* sums [[buffer(1)]], + const device float* maxs [[buffer(2)]], + device T* out [[buffer(3)]], + const constant int& blocks [[buffer(4)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 tpg [[threadgroups_per_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int BN = 32; + constexpr int BD = 32; + constexpr int elem_per_thread = D / BD; + + typedef float U; + + thread U o[elem_per_thread] = {0}; + threadgroup U outputs[BN * BD]; + + // Adjust positions + const int head_idx = tid.x; + const int q_seq_idx = tid.y; + const int q_offset = head_idx * tpg.y + q_seq_idx; + partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread; + sums += q_offset * blocks; + maxs += q_offset * blocks; + out += q_offset * D + simd_gid * elem_per_thread; + + // Set defaults + U sum_exp_score = 0.0; + U max_score = Limits::finite_min; + + // Reduce the max + for (int b = 0; b < blocks / BN; ++b) { + max_score = max(max_score, maxs[simd_lid + BN * b]); + } + max_score = simd_max(max_score); + + // Reduce the d + for (int b = 0; b < blocks / BN; ++b) { + U factor = fast::exp(maxs[simd_lid + BN * b] - max_score); + sum_exp_score += factor * sums[simd_lid + BN * b]; + } + sum_exp_score = simd_sum(sum_exp_score); + + // Reduce the sum exp and partials + for (int b = 0; b < blocks / BN; ++b) { + U factor = fast::exp(maxs[simd_gid] - max_score); + + // Update the output accumulator + for (int i = 0; i < elem_per_thread; i++) { + o[i] += factor * static_cast(partials[i]); + } + maxs += BN; + sums += BN; + partials += BN * D; + } + + // Use shared memory to transpose and reduce the final block + for (int i = 0; i < elem_per_thread; i++) { + outputs[simd_lid * BD + simd_gid] = o[i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + o[i] = simd_sum(outputs[simd_gid * BD + simd_lid]); + o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // And write the output + if (simd_lid == 0) { + for (int i = 0; i < elem_per_thread; i++) { + out[i] = static_cast(o[i]); + } + } +} diff --git a/Source/Cxxmlx/mlx-generated/metal/softmax.h b/Source/Cxxmlx/mlx-generated/metal/softmax.h new file mode 100644 index 00000000..6ea4ac73 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/softmax.h @@ -0,0 +1,190 @@ +// Copyright © 2023-2024 Apple Inc. + +template +inline T softmax_exp(T x) { + // Softmax doesn't need high precision exponential cause x is gonna be in + // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)). + return fast::exp(x); +} + +template +[[kernel]] void softmax_single_row( + const device T* in, + device T* out, + constant int& axis_size, + uint gid [[threadgroup_position_in_grid]], + uint _lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + int lid = _lid; + + constexpr int SIMD_SIZE = 32; + + threadgroup AccT local_max[SIMD_SIZE]; + threadgroup AccT local_normalizer[SIMD_SIZE]; + + AccT ld[N_READS]; + + in += gid * size_t(axis_size) + lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + ld[i] = AccT(in[i]); + } + } else { + for (int i = 0; i < N_READS; i++) { + ld[i] = + ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits::min; + } + } + if (simd_group_id == 0) { + local_max[simd_lane_id] = Limits::min; + local_normalizer[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Get the max + AccT maxval = Limits::finite_min; + for (int i = 0; i < N_READS; i++) { + maxval = (maxval < ld[i]) ? ld[i] : maxval; + } + maxval = simd_max(maxval); + if (simd_lane_id == 0) { + local_max[simd_group_id] = maxval; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + maxval = simd_max(local_max[simd_lane_id]); + if (simd_lane_id == 0) { + local_max[0] = maxval; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + maxval = local_max[0]; + + // Compute exp(x_i - maxval) and store the partial sums in local_normalizer + AccT normalizer = 0; + for (int i = 0; i < N_READS; i++) { + AccT exp_x = softmax_exp(ld[i] - maxval); + ld[i] = exp_x; + normalizer += exp_x; + } + normalizer = simd_sum(normalizer); + if (simd_lane_id == 0) { + local_normalizer[simd_group_id] = normalizer; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + normalizer = simd_sum(local_normalizer[simd_lane_id]); + if (simd_lane_id == 0) { + local_normalizer[0] = normalizer; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + normalizer = 1 / local_normalizer[0]; + + // Normalize and write to the output + out += gid * size_t(axis_size) + lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + out[i] = T(ld[i] * normalizer); + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((lid * N_READS + i) < axis_size) { + out[i] = T(ld[i] * normalizer); + } + } + } +} + +template +[[kernel]] void softmax_looped( + const device T* in, + device T* out, + constant int& axis_size, + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + in += gid * size_t(axis_size); + + constexpr int SIMD_SIZE = 32; + + threadgroup AccT local_max[SIMD_SIZE]; + threadgroup AccT local_normalizer[SIMD_SIZE]; + + // Get the max and the normalizer in one go + AccT prevmax; + AccT maxval = Limits::finite_min; + AccT normalizer = 0; + for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); + r++) { + int offset = r * lsize * N_READS + lid * N_READS; + AccT vals[N_READS]; + if (offset + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + vals[i] = AccT(in[offset + i]); + } + } else { + for (int i = 0; i < N_READS; i++) { + vals[i] = + (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; + } + } + prevmax = maxval; + for (int i = 0; i < N_READS; i++) { + maxval = (maxval < vals[i]) ? vals[i] : maxval; + } + normalizer *= softmax_exp(prevmax - maxval); + for (int i = 0; i < N_READS; i++) { + normalizer += softmax_exp(vals[i] - maxval); + } + } + // Now we got partial normalizer of N_READS * ceildiv(axis_size, N_READS * + // lsize) parts. We need to combine them. + // 1. We start by finding the max across simd groups + // 2. We then change the partial normalizers to account for a possible + // change in max + // 3. We sum all normalizers + prevmax = maxval; + maxval = simd_max(maxval); + normalizer *= softmax_exp(prevmax - maxval); + normalizer = simd_sum(normalizer); + + // Now the normalizer and max value is correct for each simdgroup. We write + // them shared memory and combine them. + prevmax = maxval; + if (simd_lane_id == 0) { + local_max[simd_group_id] = maxval; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + maxval = simd_max(local_max[simd_lane_id]); + normalizer *= softmax_exp(prevmax - maxval); + if (simd_lane_id == 0) { + local_normalizer[simd_group_id] = normalizer; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + normalizer = simd_sum(local_normalizer[simd_lane_id]); + normalizer = 1 / normalizer; + + // Finally given the normalizer and max value we can directly write the + // softmax output + out += gid * size_t(axis_size); + for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); + r++) { + int offset = r * lsize * N_READS + lid * N_READS; + if (offset + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer); + } + } else { + for (int i = 0; i < N_READS; i++) { + if (offset + i < axis_size) { + out[offset + i] = + T(softmax_exp(in[offset + i] - maxval) * normalizer); + } + } + } + } +} diff --git a/Source/Cxxmlx/mlx-generated/metal/sort.h b/Source/Cxxmlx/mlx-generated/metal/sort.h new file mode 100644 index 00000000..0d357333 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/sort.h @@ -0,0 +1,719 @@ +// Copyright © 2023-2024 Apple Inc. + +#define MLX_MTL_CONST static constant constexpr const +#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)") + +using namespace metal; + +// Based on GPU merge sort algorithm at +// https://github.com/NVIDIA/cccl/tree/main/cub/cub + +/////////////////////////////////////////////////////////////////////////////// +// Thread-level sort +/////////////////////////////////////////////////////////////////////////////// + +template +METAL_FUNC void thread_swap(thread T& a, thread T& b) { + T w = a; + a = b; + b = w; +} + +template +struct Init { + static constexpr constant T v = Limits::max; +}; + +template +struct Init>> { + static constexpr constant T v = metal::numeric_limits::quiet_NaN(); +}; + +template +struct LessThan { + static constexpr constant T init = Init::v; + METAL_FUNC bool operator()(T a, T b) const { + if constexpr ( + metal::is_floating_point_v || metal::is_same_v) { + bool an = isnan(a); + bool bn = isnan(b); + if (an | bn) { + return (!an) & bn; + } + } + return a < b; + } +}; + +template < + typename ValT, + typename IdxT, + bool ARG_SORT, + short N_PER_THREAD, + typename CompareOp> +struct ThreadSort { + static METAL_FUNC void sort( + thread ValT (&vals)[N_PER_THREAD], + thread IdxT (&idxs)[N_PER_THREAD]) { + CompareOp op; + MLX_MTL_LOOP_UNROLL + for (short i = 0; i < N_PER_THREAD; ++i) { + MLX_MTL_LOOP_UNROLL + for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) { + if (op(vals[j + 1], vals[j])) { + thread_swap(vals[j + 1], vals[j]); + if (ARG_SORT) { + thread_swap(idxs[j + 1], idxs[j]); + } + } + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Threadgroup-level sort +/////////////////////////////////////////////////////////////////////////////// + +template < + typename ValT, + typename IdxT, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp> +struct BlockMergeSort { + using thread_sort_t = + ThreadSort; + static METAL_FUNC int merge_partition( + const threadgroup ValT* As, + const threadgroup ValT* Bs, + short A_sz, + short B_sz, + short sort_md) { + CompareOp op; + + short A_st = max(0, sort_md - B_sz); + short A_ed = min(sort_md, A_sz); + + while (A_st < A_ed) { + short md = A_st + (A_ed - A_st) / 2; + auto a = As[md]; + auto b = Bs[sort_md - 1 - md]; + + if (op(b, a)) { + A_ed = md; + } else { + A_st = md + 1; + } + } + + return A_ed; + } + + static METAL_FUNC void merge_step( + const threadgroup ValT* As, + const threadgroup ValT* Bs, + const threadgroup IdxT* As_idx, + const threadgroup IdxT* Bs_idx, + short A_sz, + short B_sz, + thread ValT (&vals)[N_PER_THREAD], + thread IdxT (&idxs)[N_PER_THREAD]) { + CompareOp op; + short a_idx = 0; + short b_idx = 0; + + for (int i = 0; i < N_PER_THREAD; ++i) { + auto a = (a_idx < A_sz) ? As[a_idx] : ValT(CompareOp::init); + auto b = (b_idx < B_sz) ? Bs[b_idx] : ValT(CompareOp::init); + bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a)); + + vals[i] = pred ? b : a; + if (ARG_SORT) { + if (pred) { + idxs[i] = Bs_idx[b_idx]; + } else { + idxs[i] = (a_idx < A_sz) ? As_idx[a_idx] : IdxT(0); + } + } + + b_idx += short(pred); + a_idx += short(!pred); + } + } + + static METAL_FUNC void sort( + threadgroup ValT* tgp_vals [[threadgroup(0)]], + threadgroup IdxT* tgp_idxs [[threadgroup(1)]], + int size_sorted_axis, + uint3 lid [[thread_position_in_threadgroup]]) { + // Get thread location + int idx = lid.x * N_PER_THREAD; + + // Load from shared memory + thread ValT thread_vals[N_PER_THREAD]; + thread IdxT thread_idxs[N_PER_THREAD]; + for (int i = 0; i < N_PER_THREAD; ++i) { + thread_vals[i] = tgp_vals[idx + i]; + if (ARG_SORT) { + thread_idxs[i] = tgp_idxs[idx + i]; + } + } + + // Per thread sort + if (idx < size_sorted_axis) { + thread_sort_t::sort(thread_vals, thread_idxs); + } + + // Do merges using threadgroup memory + for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; + merge_threads *= 2) { + // Update threadgroup memory + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_PER_THREAD; ++i) { + tgp_vals[idx + i] = thread_vals[i]; + if (ARG_SORT) { + tgp_idxs[idx + i] = thread_idxs[i]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Find location in merge step + int merge_group = lid.x / merge_threads; + int merge_lane = lid.x % merge_threads; + + int sort_sz = N_PER_THREAD * merge_threads; + int sort_st = N_PER_THREAD * merge_threads * merge_group; + + // As = tgp_vals[A_st:A_ed] is sorted + // Bs = tgp_vals[B_st:B_ed] is sorted + int A_st = sort_st; + int A_ed = sort_st + sort_sz / 2; + int B_st = sort_st + sort_sz / 2; + int B_ed = sort_st + sort_sz; + + const threadgroup ValT* As = tgp_vals + A_st; + const threadgroup ValT* Bs = tgp_vals + B_st; + int A_sz = A_ed - A_st; + int B_sz = B_ed - B_st; + + // Find a partition of merge elements + // Ci = merge(As[partition:], Bs[sort_md - partition:]) + // of size N_PER_THREAD for each merge lane i + // C = [Ci] is sorted + int sort_md = N_PER_THREAD * merge_lane; + int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md); + + As += partition; + Bs += sort_md - partition; + + A_sz -= partition; + B_sz -= sort_md - partition; + + const threadgroup IdxT* As_idx = + ARG_SORT ? tgp_idxs + A_st + partition : nullptr; + const threadgroup IdxT* Bs_idx = + ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr; + + // Merge starting at the partition and store results in thread registers + merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs); + } + + // Write out to shared memory + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_PER_THREAD; ++i) { + tgp_vals[idx + i] = thread_vals[i]; + if (ARG_SORT) { + tgp_idxs[idx + i] = thread_idxs[i]; + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Kernel sort +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename U, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp = LessThan> +struct KernelMergeSort { + using ValT = T; + using IdxT = uint; + using block_merge_sort_t = BlockMergeSort< + ValT, + IdxT, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD, + CompareOp>; + + MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; + + static METAL_FUNC void block_sort( + const device T* inp, + device U* out, + const constant int& size_sorted_axis, + const constant int& in_stride_sorted_axis, + const constant int& out_stride_sorted_axis, + const constant int& in_stride_segment_axis, + const constant int& out_stride_segment_axis, + threadgroup ValT* tgp_vals, + threadgroup IdxT* tgp_idxs, + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // tid.y tells us the segment index + inp += tid.y * in_stride_segment_axis; + out += tid.y * out_stride_segment_axis; + + // Copy into threadgroup memory + for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { + tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis] + : ValT(CompareOp::init); + if (ARG_SORT) { + tgp_idxs[i] = i; + } + } + + // Sort elements within the block + threadgroup_barrier(mem_flags::mem_threadgroup); + + block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write output + for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) { + if (ARG_SORT) { + out[i * out_stride_sorted_axis] = tgp_idxs[i]; + } else { + out[i * out_stride_sorted_axis] = tgp_vals[i]; + } + } + } +}; + +template < + typename T, + typename U, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort( + const device T* inp [[buffer(0)]], + device U* out [[buffer(1)]], + const constant int& size_sorted_axis [[buffer(2)]], + const constant int& in_stride_sorted_axis [[buffer(3)]], + const constant int& out_stride_sorted_axis [[buffer(4)]], + const constant int& in_stride_segment_axis [[buffer(5)]], + const constant int& out_stride_segment_axis [[buffer(6)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using sort_kernel = + KernelMergeSort; + using ValT = typename sort_kernel::ValT; + using IdxT = typename sort_kernel::IdxT; + + if (ARG_SORT) { + threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, + tgp_vals, + tgp_idxs, + tid, + lid); + } else { + threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, + tgp_vals, + nullptr, + tid, + lid); + } +} + +constant constexpr const int zero_helper = 0; + +template < + typename T, + typename U, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc( + const device T* inp [[buffer(0)]], + device U* out [[buffer(1)]], + const constant int& size_sorted_axis [[buffer(2)]], + const constant int& in_stride_sorted_axis [[buffer(3)]], + const constant int& out_stride_sorted_axis [[buffer(4)]], + const constant int& nc_dim [[buffer(5)]], + const constant int* nc_shape [[buffer(6)]], + const constant int64_t* in_nc_strides [[buffer(7)]], + const constant int64_t* out_nc_strides [[buffer(8)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using sort_kernel = + KernelMergeSort; + using ValT = typename sort_kernel::ValT; + using IdxT = typename sort_kernel::IdxT; + + auto in_block_idx = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim); + auto out_block_idx = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim); + inp += in_block_idx; + out += out_block_idx; + + if (ARG_SORT) { + threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + zero_helper, + zero_helper, + tgp_vals, + tgp_idxs, + tid, + lid); + } else { + threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + zero_helper, + zero_helper, + tgp_vals, + nullptr, + tid, + lid); + } +} + +template < + typename ValT, + typename IdxT, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp = LessThan> +struct KernelMultiBlockMergeSort { + using block_merge_sort_t = BlockMergeSort< + ValT, + IdxT, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD, + CompareOp>; + + MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; + + static METAL_FUNC void block_sort( + const device ValT* inp, + device ValT* out_vals, + device IdxT* out_idxs, + const constant int& size_sorted_axis, + const constant int& stride_sorted_axis, + threadgroup ValT* tgp_vals, + threadgroup IdxT* tgp_idxs, + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // tid.y tells us the segment index + int base_idx = tid.x * N_PER_BLOCK; + + // Copy into threadgroup memory + for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { + int idx = base_idx + i; + tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis] + : ValT(CompareOp::init); + tgp_idxs[i] = idx; + } + + // Sort elements within the block + threadgroup_barrier(mem_flags::mem_threadgroup); + + block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write output + for (int i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { + int idx = base_idx + i; + if (idx < size_sorted_axis) { + out_vals[idx] = tgp_vals[i]; + out_idxs[idx] = tgp_idxs[i]; + } + } + } + + static METAL_FUNC int merge_partition( + const device ValT* As, + const device ValT* Bs, + int A_sz, + int B_sz, + int sort_md) { + CompareOp op; + + int A_st = max(0, sort_md - B_sz); + int A_ed = min(sort_md, A_sz); + + while (A_st < A_ed) { + int md = A_st + (A_ed - A_st) / 2; + auto a = As[md]; + auto b = Bs[sort_md - 1 - md]; + + if (op(b, a)) { + A_ed = md; + } else { + A_st = md + 1; + } + } + + return A_ed; + } +}; + +template < + typename ValT, + typename IdxT, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort( + const device ValT* inp [[buffer(0)]], + device ValT* out_vals [[buffer(1)]], + device IdxT* out_idxs [[buffer(2)]], + const constant int& size_sorted_axis [[buffer(3)]], + const constant int& stride_sorted_axis [[buffer(4)]], + const constant int& nc_dim [[buffer(5)]], + const constant int* nc_shape [[buffer(6)]], + const constant int64_t* nc_strides [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using sort_kernel = KernelMultiBlockMergeSort< + ValT, + IdxT, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD>; + + auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim); + inp += block_idx; + out_vals += tid.y * size_sorted_axis; + out_idxs += tid.y * size_sorted_axis; + + threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK]; + + sort_kernel::block_sort( + inp, + out_vals, + out_idxs, + size_sorted_axis, + stride_sorted_axis, + tgp_vals, + tgp_idxs, + tid, + lid); +} + +template < + typename ValT, + typename IdxT, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel]] void mb_block_partition( + device IdxT* block_partitions [[buffer(0)]], + const device ValT* dev_vals [[buffer(1)]], + const device IdxT* dev_idxs [[buffer(2)]], + const constant int& size_sorted_axis [[buffer(3)]], + const constant int& merge_tiles [[buffer(4)]], + const constant int& n_blocks [[buffer(5)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 tgp_dims [[threads_per_threadgroup]]) { + using sort_kernel = KernelMultiBlockMergeSort< + ValT, + IdxT, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD>; + + block_partitions += tid.y * tgp_dims.x; + dev_vals += tid.y * size_sorted_axis; + dev_idxs += tid.y * size_sorted_axis; + + for (int i = lid.x; i <= n_blocks; i += tgp_dims.x) { + // Find location in merge step + int merge_group = i / merge_tiles; + int merge_lane = i % merge_tiles; + + int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles; + int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group; + + int A_st = min(size_sorted_axis, sort_st); + int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2); + int B_st = A_ed; + int B_ed = min(size_sorted_axis, B_st + sort_sz / 2); + + int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane); + int partition = sort_kernel::merge_partition( + dev_vals + A_st, + dev_vals + B_st, + A_ed - A_st, + B_ed - B_st, + partition_at); + + block_partitions[i] = A_st + partition; + } +} + +template < + typename ValT, + typename IdxT, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp = LessThan> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void +mb_block_merge( + const device IdxT* block_partitions [[buffer(0)]], + const device ValT* dev_vals_in [[buffer(1)]], + const device IdxT* dev_idxs_in [[buffer(2)]], + device ValT* dev_vals_out [[buffer(3)]], + device IdxT* dev_idxs_out [[buffer(4)]], + const constant int& size_sorted_axis [[buffer(5)]], + const constant int& merge_tiles [[buffer(6)]], + const constant int& num_tiles [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using sort_kernel = KernelMultiBlockMergeSort< + ValT, + IdxT, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD, + CompareOp>; + + using block_sort_t = typename sort_kernel::block_merge_sort_t; + + block_partitions += tid.y * (num_tiles + 1); + dev_vals_in += tid.y * size_sorted_axis; + dev_idxs_in += tid.y * size_sorted_axis; + dev_vals_out += tid.y * size_sorted_axis; + dev_idxs_out += tid.y * size_sorted_axis; + + int block_idx = tid.x; + int merge_group = block_idx / merge_tiles; + int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group; + int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles; + int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st; + + int A_st = block_partitions[block_idx + 0]; + int A_ed = block_partitions[block_idx + 1]; + int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st); + int B_ed = min( + size_sorted_axis, + 2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed); + + if ((block_idx % merge_tiles) == merge_tiles - 1) { + A_ed = min(size_sorted_axis, sort_st + sort_sz / 2); + B_ed = min(size_sorted_axis, sort_st + sort_sz); + } + + int A_sz = A_ed - A_st; + int B_sz = B_ed - B_st; + + // Load from global memory + thread ValT thread_vals[N_PER_THREAD]; + thread IdxT thread_idxs[N_PER_THREAD]; + for (int i = 0; i < N_PER_THREAD; i++) { + int idx = BLOCK_THREADS * i + lid.x; + if (idx < (A_sz + B_sz)) { + thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx] + : dev_vals_in[B_st + idx - A_sz]; + thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx] + : dev_idxs_in[B_st + idx - A_sz]; + } else { + thread_vals[i] = CompareOp::init; + thread_idxs[i] = 0; + } + } + + // Write to shared memory + threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK]; + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_PER_THREAD; i++) { + int idx = BLOCK_THREADS * i + lid.x; + tgp_vals[idx] = thread_vals[i]; + tgp_idxs[idx] = thread_idxs[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Merge + int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x)); + + int A_st_local = block_sort_t::merge_partition( + tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local); + int A_ed_local = A_sz; + + int B_st_local = sort_md_local - A_st_local; + int B_ed_local = B_sz; + + int A_sz_local = A_ed_local - A_st_local; + int B_sz_local = B_ed_local - B_st_local; + + // Do merge + block_sort_t::merge_step( + tgp_vals + A_st_local, + tgp_vals + A_ed_local + B_st_local, + tgp_idxs + A_st_local, + tgp_idxs + A_ed_local + B_st_local, + A_sz_local, + B_sz_local, + thread_vals, + thread_idxs); + + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_PER_THREAD; ++i) { + int idx = lid.x * N_PER_THREAD; + tgp_vals[idx + i] = thread_vals[i]; + tgp_idxs[idx + i] = thread_idxs[i]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + // Write output + int base_idx = tid.x * sort_kernel::N_PER_BLOCK; + for (int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) { + int idx = base_idx + i; + if (idx < size_sorted_axis) { + dev_vals_out[idx] = tgp_vals[i]; + dev_idxs_out[idx] = tgp_idxs[i]; + } + } +} diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/attn.h b/Source/Cxxmlx/mlx-generated/metal/steel/attn/attn.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/attn/attn.h rename to Source/Cxxmlx/mlx-generated/metal/steel/attn/attn.h diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.h b/Source/Cxxmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.h rename to Source/Cxxmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.h diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.metal b/Source/Cxxmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.metal similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.metal rename to Source/Cxxmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.metal diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention_nax.h b/Source/Cxxmlx/mlx-generated/metal/steel/attn/kernels/steel_attention_nax.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention_nax.h rename to Source/Cxxmlx/mlx-generated/metal/steel/attn/kernels/steel_attention_nax.h diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/loader.h b/Source/Cxxmlx/mlx-generated/metal/steel/attn/loader.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/attn/loader.h rename to Source/Cxxmlx/mlx-generated/metal/steel/attn/loader.h diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/mma.h b/Source/Cxxmlx/mlx-generated/metal/steel/attn/mma.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/attn/mma.h rename to Source/Cxxmlx/mlx-generated/metal/steel/attn/mma.h diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/nax.h b/Source/Cxxmlx/mlx-generated/metal/steel/attn/nax.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/attn/nax.h rename to Source/Cxxmlx/mlx-generated/metal/steel/attn/nax.h diff --git a/Source/Cxxmlx/mlx-generated/metal/steel/attn/params.h b/Source/Cxxmlx/mlx-generated/metal/steel/attn/params.h new file mode 100644 index 00000000..f1cf09fa --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/steel/attn/params.h @@ -0,0 +1,44 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +/////////////////////////////////////////////////////////////////////////////// +// Attn param classes +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +struct AttnParams { + int B; ///< Batch Size + int H; ///< Heads + int D; ///< Head Dim + + int qL; ///< Query Sequence Length + int kL; ///< Key Sequence Length + + int gqa_factor; ///< Group Query factor + float scale; ///< Attention scale + + int NQ; ///< Number of query blocks + int NK; ///< Number of key/value blocks + + int NQ_aligned; ///< Number of full query blocks + int NK_aligned; ///< Number of full key/value blocks + + int qL_rem; ///< Remainder in last query block + int kL_rem; ///< Remainder in last key/value block + int qL_off; ///< Offset in query sequence start + + int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1) + int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1) + int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1) + int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1) +}; + +struct AttnMaskParams { + int64_t M_strides[3]; ///< Mask strides (B, H, qL, kL = 1) +}; + +} // namespace steel +} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/transforms.h b/Source/Cxxmlx/mlx-generated/metal/steel/attn/transforms.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/attn/transforms.h rename to Source/Cxxmlx/mlx-generated/metal/steel/attn/transforms.h diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/conv.h b/Source/Cxxmlx/mlx-generated/metal/steel/conv/conv.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/conv/conv.h rename to Source/Cxxmlx/mlx-generated/metal/steel/conv/conv.h diff --git a/Source/Cxxmlx/mlx-generated/metal/steel/conv/kernels/steel_conv.h b/Source/Cxxmlx/mlx-generated/metal/steel/conv/kernels/steel_conv.h new file mode 100644 index 00000000..850ec15b --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/steel/conv/kernels/steel_conv.h @@ -0,0 +1,176 @@ +// Copyright © 2024 Apple Inc. + +#include + +using namespace metal; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + int N_CHANNELS = 0, + bool SMALL_FILTER = false> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void +implicit_gemm_conv_2d( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* C [[buffer(2)]], + const constant MLXConvParams<2>* params [[buffer(3)]], + const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using namespace mlx::steel; + + (void)lid; + + constexpr bool transpose_a = false; + constexpr bool transpose_b = true; + constexpr short tgp_padding_a = 16 / sizeof(T); + constexpr short tgp_padding_b = 16 / sizeof(T); + + constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a; + constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b; + constexpr short shape_a_rows = (transpose_a ? BK : BM); + constexpr short shape_b_rows = (transpose_b ? BN : BK); + constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows; + constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows; + + constexpr short tgp_size = WM * WN * 32; + + // Input loader + + using loader_a_t = typename metal::conditional_t< + // Check for small channel specialization + N_CHANNELS != 0 && N_CHANNELS <= 4, + + // Go to small channel specialization + Conv2DInputBlockLoaderSmallChannels< + T, + BM, + BN, + BK, + tgp_size, + N_CHANNELS, + tgp_padding_a>, + + // Else go to general loader + typename metal::conditional_t< + // Check if filter size is small enough + SMALL_FILTER, + + // Go to small filter specialization + Conv2DInputBlockLoaderSmallFilter< + T, + BM, + BN, + BK, + tgp_size, + tgp_padding_a>, + + // Else go to large filter generalization + Conv2DInputBlockLoaderLargeFilter< + T, + BM, + BN, + BK, + tgp_size, + tgp_padding_a>>>; + + // Weight loader + using loader_b_t = typename metal::conditional_t< + // Check for small channel specialization + N_CHANNELS != 0 && N_CHANNELS <= 4, + + // Go to small channel specialization + Conv2DWeightBlockLoaderSmallChannels< + T, + BM, + BN, + BK, + tgp_size, + N_CHANNELS, + tgp_padding_b>, + + // Else go to general loader + Conv2DWeightBlockLoader>; + + using mma_t = BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + shape_a_cols, + shape_b_cols>; + + threadgroup T As[tgp_mem_size_a]; + threadgroup T Bs[tgp_mem_size_b]; + + const int tid_y = ((tid.y) << gemm_params->swizzle_log) + + ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> gemm_params->swizzle_log; + + if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) { + return; + } + + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const int K = gemm_params->K; + const int N = gemm_params->N; + const int C_per_group = params->C / params->groups; + + // Groups + A += tid.z * C_per_group; + B += tid.z * N * K; + C += tid.z * N; + + B += c_col * K; + C += c_row * (N * params->groups) + c_col; + + const int2 offsets_a(0, c_row); + const int2 offsets_b(0, c_col); + + // Prepare threadgroup loading operations + loader_a_t loader_a( + A, As, offsets_a, params, gemm_params, simd_gid, simd_lid); + loader_b_t loader_b( + B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid); + + // Prepare threadgroup mma operation + mma_t mma_op(simd_gid, simd_lid); + + int gemm_k_iterations = gemm_params->gemm_k_iterations; + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Store results to device memory + short tgp_bm = min(BM, gemm_params->M - c_row); + short tgp_bn = min(BN, gemm_params->N - c_col); + const int ldc = N * params->groups; + mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm)); +} diff --git a/Source/Cxxmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_3d.h b/Source/Cxxmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_3d.h new file mode 100644 index 00000000..d2fbac0f --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_3d.h @@ -0,0 +1,135 @@ +// Copyright © 2024 Apple Inc. + +#include + +using namespace metal; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool SMALL_FILTER = false> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void +implicit_gemm_conv_3d( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* C [[buffer(2)]], + const constant MLXConvParams<3>* params [[buffer(3)]], + const constant ImplicitGemmConv3DParams* gemm_params [[buffer(4)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using namespace mlx::steel; + + (void)lid; + + constexpr bool transpose_a = false; + constexpr bool transpose_b = true; + constexpr short tgp_padding_a = 16 / sizeof(T); + constexpr short tgp_padding_b = 16 / sizeof(T); + + constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a; + constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b; + constexpr short shape_a_rows = (transpose_a ? BK : BM); + constexpr short shape_b_rows = (transpose_b ? BN : BK); + constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows; + constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows; + + constexpr short tgp_size = WM * WN * 32; + + // Input loader + using loader_a_t = typename metal::conditional_t< + // If the filter is small we can precompute masks for bounds checking + SMALL_FILTER, + Conv3DInputBlockLoaderSmallFilter, + Conv3DInputBlockLoaderLargeFilter< + T, + BM, + BN, + BK, + tgp_size, + tgp_padding_a>>; + + // Weight loader + using loader_b_t = + Conv3DWeightBlockLoader; + + using mma_t = BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + shape_a_cols, + shape_b_cols>; + + threadgroup T As[tgp_mem_size_a]; + threadgroup T Bs[tgp_mem_size_b]; + + const int tid_y = ((tid.y) << gemm_params->swizzle_log) + + ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> gemm_params->swizzle_log; + + if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) { + return; + } + + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const int K = gemm_params->K; + const int N = gemm_params->N; + const int C_per_group = params->C / params->groups; + + // Groups + A += tid.z * C_per_group; + B += tid.z * N * K; + C += tid.z * N; + + B += c_col * K; + C += c_row * (N * params->groups) + c_col; + + const int2 offsets_a(0, c_row); + const int2 offsets_b(0, c_col); + + // Prepare threadgroup loading operations + loader_a_t loader_a( + A, As, offsets_a, params, gemm_params, simd_gid, simd_lid); + loader_b_t loader_b( + B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid); + + // Prepare threadgroup mma operation + mma_t mma_op(simd_gid, simd_lid); + + int gemm_k_iterations = gemm_params->gemm_k_iterations; + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Store results to device memory + short tgp_bm = min(BM, gemm_params->M - c_row); + short tgp_bn = min(BN, gemm_params->N - c_col); + const int ldc = N * params->groups; + mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm)); +} diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h b/Source/Cxxmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h rename to Source/Cxxmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/loader.h b/Source/Cxxmlx/mlx-generated/metal/steel/conv/loader.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/conv/loader.h rename to Source/Cxxmlx/mlx-generated/metal/steel/conv/loader.h diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_l.h b/Source/Cxxmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_l.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_l.h rename to Source/Cxxmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_l.h diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_n.h b/Source/Cxxmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_n.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_n.h rename to Source/Cxxmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_n.h diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_general.h b/Source/Cxxmlx/mlx-generated/metal/steel/conv/loaders/loader_general.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_general.h rename to Source/Cxxmlx/mlx-generated/metal/steel/conv/loaders/loader_general.h diff --git a/Source/Cxxmlx/mlx-generated/metal/steel/conv/params.h b/Source/Cxxmlx/mlx-generated/metal/steel/conv/params.h new file mode 100644 index 00000000..67d38274 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/steel/conv/params.h @@ -0,0 +1,103 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +template +struct MLXConvParams { + int N; // Batch size + int C; // In channels + int O; // Out channels + int iS[NDIM]; // Input spatial dim + int wS[NDIM]; // Weight spatial dim + int oS[NDIM]; // Output spatial dim + int str[NDIM]; // Kernel strides + int pad[NDIM]; // Input padding + int kdil[NDIM]; // Kernel dilation + int idil[NDIM]; // Input dilation + int64_t in_strides[NDIM + 2]; // In strides + int64_t wt_strides[NDIM + 2]; // Wt strides + int64_t out_strides[NDIM + 2]; // Out strides + int groups; // Input channel groups + bool flip; + + static MLXConvParams + with_padded_channels(MLXConvParams other, int pad_out, int pad_in) { + MLXConvParams params = other; + + // Update strides + for (int i = 0; i < NDIM + 1; i++) { + params.in_strides[i] = + (params.in_strides[i] / params.C) * (params.C + pad_in); + params.wt_strides[i] = + (params.wt_strides[i] / params.C) * (params.C + pad_in); + params.out_strides[i] = + (params.out_strides[i] / params.O) * (params.O + pad_out); + } + params.in_strides[NDIM + 1] = 1; + params.wt_strides[NDIM + 1] = 1; + params.out_strides[NDIM + 1] = 1; + + // Update channels + params.C += pad_in; + params.O += pad_out; + + return params; + }; +}; + +namespace mlx { +namespace steel { + +struct ImplicitGemmConv2DParams { + const int M; + const int N; + const int K; + + const int gemm_k_iterations; + + const int inp_jump_w; + const int inp_jump_h; + const int inp_jump_c; + + const int tiles_n; + const int tiles_m; + const int swizzle_log; +}; + +struct ImplicitGemmConv3DParams { + const int M; + const int N; + const int K; + + const int gemm_k_iterations; + + const int inp_jump_w; + const int inp_jump_h; + const int inp_jump_d; + const int inp_jump_c; + + const int tiles_n; + const int tiles_m; + const int swizzle_log; +}; + +struct Conv2DGeneralJumpParams { + const int f_wgt_jump_h; + const int f_wgt_jump_w; + + const int f_out_jump_h; + const int f_out_jump_w; + + const int adj_out_h; + const int adj_out_w; + const int adj_out_hw; + const int adj_implicit_m; +}; + +struct Conv2DGeneralBaseInfo { + int weight_base; + int weight_size; +}; + +} // namespace steel +} // namespace mlx diff --git a/Source/Cxxmlx/mlx-generated/metal/steel/defines.h b/Source/Cxxmlx/mlx-generated/metal/steel/defines.h new file mode 100644 index 00000000..f5657ee3 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/steel/defines.h @@ -0,0 +1,7 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#define STEEL_CONST static constant constexpr const +#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") +#define STEEL_PRAGMA_NO_UNROLL _Pragma("clang loop unroll(disable)") diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/gemm.h b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/gemm.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/gemm/gemm.h rename to Source/Cxxmlx/mlx-generated/metal/steel/gemm/gemm.h diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/gemm_nax.h b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/gemm_nax.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/gemm/gemm_nax.h rename to Source/Cxxmlx/mlx-generated/metal/steel/gemm/gemm_nax.h diff --git a/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused.h b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused.h new file mode 100644 index 00000000..85830872 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused.h @@ -0,0 +1,346 @@ +// Copyright © 2024 Apple Inc. + +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +constant bool has_batch [[function_constant(10)]]; + +constant bool use_out_source [[function_constant(100)]]; +constant bool do_axpby [[function_constant(110)]]; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +// clang-format off +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device T* C [[buffer(2), function_constant(use_out_source)]], + device T* D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], + const constant int* batch_shape [[buffer(6), function_constant(has_batch)]], + const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on + // Pacifying compiler + (void)lid; + + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + // Find block + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + // Exit early if out of bounds + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + // Adjust for batch + if (has_batch) { + const constant auto* A_bstrides = batch_strides; + const constant auto* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + + if (use_out_source) { + const constant auto* C_bstrides = B_bstrides + params->batch_ndim; + C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim); + } + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + + if (use_out_source) { + C += addmm_params->batch_stride_c * tid.z; + } + } + + D += params->batch_stride_d * tid.z; + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + if (use_out_source) { + C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc; + } + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + // Prepare iterations + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + // Do unaligned K iterations first + if (!align_K) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + // Move loader source ahead to end + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + + const TransformAdd epilogue_op_add( + addmm_params->alpha, addmm_params->beta); + const TransformAxpby epilogue_op_axpby( + addmm_params->alpha, addmm_params->beta); + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (align_M && align_N) { + // Do gemm + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby); + } else { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result(D, params->ldd); + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + const int leftover_bk = 0; + + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + // Do gemm + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby); + } else { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result(D, params->ldd); + + } else if (align_N || tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_axpby); + } else { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + + } else if (align_M || tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_axpby); + } else { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + + } else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_axpby); + } else { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} diff --git a/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused_nax.h b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused_nax.h new file mode 100644 index 00000000..f0789548 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused_nax.h @@ -0,0 +1,219 @@ +// Copyright © 2025 Apple Inc. + +using namespace mlx::steel; + +constant bool has_batch [[function_constant(10)]]; + +constant bool use_out_source [[function_constant(100)]]; +constant bool do_axpby [[function_constant(110)]]; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +// clang-format off +template < + bool kAlignedM, + bool kAlignedN, + typename NAXTile_t, + typename T> +void gemm_epilogue( + thread NAXTile_t& Dtile, + const device T* C, + const constant GEMMParams* params, + const constant GEMMAddMMParams* addmm_params, + const short sgp_sm, + const short sgp_sn) { // clang-format on + + (void)params; + + constexpr short UM = NAXTile_t::kSubTileRows; + constexpr short UN = NAXTile_t::kSubTileCols; + using CSubTile = NAXSubTile; + + using V = typename NAXTile_t::elem_type; + + constexpr short TM = NAXTile_t::kTileRows; + constexpr short TN = NAXTile_t::kTileCols; + constexpr short kElemsPerSubTile = NAXTile_t::kElemsPerSubTile; + + STEEL_PRAGMA_UNROLL + for (short mm = 0; mm < TM; mm++) { + STEEL_PRAGMA_UNROLL + for (short nn = 0; nn < TN; nn++) { + const short m = mm * UM; + const short n = nn * UN; + + CSubTile CTile; + + if constexpr (kAlignedM && kAlignedN) { + CTile.load(C, addmm_params->ldc, addmm_params->fdc, m, n); + } else { + CTile.load_safe( + C, addmm_params->ldc, addmm_params->fdc, sgp_sm, sgp_sn, m, n); + } + + auto delems = Dtile.subtile_at(mm, nn).elems(); + auto celems = CTile.elems(); + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemsPerSubTile; i++) { + if (do_axpby) { + delems[i] = addmm_params->alpha * delems[i] + + addmm_params->beta * static_cast(celems[i]); + } else { + delems[i] += static_cast(celems[i]); + } + } + } + } +} + +// clang-format off +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device T* C [[buffer(2), function_constant(use_out_source)]], + device T* D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], + const constant int* batch_shape [[buffer(6), function_constant(has_batch)]], + const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { // clang-format on + // Find block + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + // Exit early if out of bounds + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + // Adjust for batch + if (has_batch) { + const constant auto* A_bstrides = batch_strides; + const constant auto* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + + if (use_out_source) { + const constant auto* C_bstrides = B_bstrides + params->batch_ndim; + C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim); + } + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + + if (use_out_source) { + C += addmm_params->batch_stride_c * tid.z; + } + } + + D += params->batch_stride_d * tid.z; + + // Prepare threadgroup memory + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + if (use_out_source) { + C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc; + } + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + + const short tm = SM * (simd_group_id / WN); + const short tn = SN * (simd_group_id % WN); + + const int sgp_sm_int = + align_M ? int(SM) : min(int(SM), params->M - (c_row + tm)); + const short sgp_sm = short(sgp_sm_int); + const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); + + const int sgp_sn_int = + align_N ? int(SN) : min(int(SN), params->N - (c_col + tn)); + const short sgp_sn = short(sgp_sn_int); + const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN); + + A += transpose_a ? tm : (tm * params->lda); + B += transpose_b ? (tn * params->ldb) : tn; + D += tm * params->ldd + tn; + + if (use_out_source) { + C += tm * addmm_params->ldc + tn * addmm_params->fdc; + } + + using DSubTile = NAXSubTile; + NAXTile Dtile; + + dispatch_bool(align_K, [&](auto kAlignedK) { + dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) { + Dtile = gemm_loop< + T, + SM, + SN, + SK, + BK, + transpose_a, + transpose_b, + kAlignedM.value, + kAlignedN.value, + kAlignedK.value, + UM, + UN, + UK, + AccumType>( + A, + B, + params->lda, + params->ldb, + params->K, + params->gemm_k_iterations_aligned, + sgp_sm, + sgp_sn); + if (use_out_source) { + gemm_epilogue( + Dtile, C, params, addmm_params, sgp_sm, sgp_sn); + } + if constexpr (kAlignedM && kAlignedN) { + Dtile.store(D, int(params->ldd)); + } else { + Dtile.store_safe(D, int(params->ldd), short2(sgp_sn, sgp_sm)); + } + }); + }); + }); +} diff --git a/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_gather.h b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_gather.h new file mode 100644 index 00000000..4c055e69 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_gather.h @@ -0,0 +1,459 @@ +// Copyright © 2024 Apple Inc. + +using namespace mlx::steel; + +constant bool has_batch [[function_constant(10)]]; +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gather_mm_rhs( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* rhs_indices [[buffer(2)]], + device T* C [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + if (params->tiles_n <= static_cast(tid.x) || + params->tiles_m <= static_cast(tid.y)) { + return; + } + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Find the block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = rhs_indices[c_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (rhs_indices[c_row + n] != index) { + offset_next = n; + index_next = rhs_indices[c_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b( + B + index * params->batch_stride_b, + params->ldb, + Bs, + simd_group_id, + simd_lane_id); + + // Prepare iterations + const int gemm_k_iterations = params->gemm_k_iterations_aligned; + + // Do unaligned K iterations first + if (!align_K) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + // Move loader source ahead to end + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + + // Matrix level aligned never check + if (align_M && align_N) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(C, params->ldd); + } else { + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(BN, offset_next)); + } + } else { + const short lbk = 0; + + // Tile aligned don't check + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + if (offset_next - offset == BM) { + mma_op.store_result(C, params->ldd); + } else { + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(BN, offset_next)); + } + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(BN, offset_next)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next)); + } + + // Nothing aligned so check both rows and cols + else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next)); + } + } + } +} + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gather_mm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* lhs_indices [[buffer(2)]], + const device uint32_t* rhs_indices [[buffer(3)]], + device T* C [[buffer(4)]], + const constant GEMMParams* params [[buffer(5)]], + const constant int* indices_shape [[buffer(6)]], + const constant int64_t* lhs_strides [[buffer(7)]], + const constant int64_t* rhs_strides [[buffer(8)]], + const constant int& batch_ndim_a [[buffer(9)]], + const constant int* batch_shape_a [[buffer(10)]], + const constant int64_t* batch_strides_a [[buffer(11)]], + const constant int& batch_ndim_b [[buffer(12)]], + const constant int* batch_shape_b [[buffer(13)]], + const constant int64_t* batch_strides_b [[buffer(14)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + if (params->tiles_n <= static_cast(tid.x) || + params->tiles_m <= static_cast(tid.y)) { + return; + } + + // Move A and B to the locations pointed by lhs_indices and rhs_indices. + uint32_t indx_A, indx_B; + if (has_batch) { + ulong2 indices_offsets = elem_to_loc_broadcast( + tid.z, indices_shape, lhs_strides, rhs_strides, params->batch_ndim); + indx_A = lhs_indices[indices_offsets.x]; + indx_B = rhs_indices[indices_offsets.y]; + } else { + indx_A = lhs_indices[params->batch_stride_a * tid.z]; + indx_B = rhs_indices[params->batch_stride_b * tid.z]; + } + A += elem_to_loc(indx_A, batch_shape_a, batch_strides_a, batch_ndim_a); + B += elem_to_loc(indx_B, batch_shape_b, batch_strides_b, batch_ndim_b); + C += params->batch_stride_d * tid.z; + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Just make sure everybody's finished with the indexing math above. + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + // Prepare iterations + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + // Do unaligned K iterations first + if (!align_K) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + // Move loader source ahead to end + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + + // Matrix level aligned never check + if (align_M && align_N) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + // Store results to device memory + mma_op.store_result(C, params->ldd); + } else { + const short lbk = 0; + + // Tile aligned don't check + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result(C, params->ldd); + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Nothing aligned so check both rows and cols + else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} diff --git a/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_gather_nax.h b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_gather_nax.h new file mode 100644 index 00000000..67cd7378 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_gather_nax.h @@ -0,0 +1,143 @@ +// Copyright © 2024 Apple Inc. + +using namespace mlx::steel; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void +gather_mm_rhs_nax( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* rhs_indices [[buffer(2)]], + device T* C [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + + if (params->tiles_n <= static_cast(tid.x) || + params->tiles_m <= static_cast(tid.y)) { + return; + } + + // Find the block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + rhs_indices += c_row; + + const short tm = SM * (simd_group_id / WN); + const short tn = SN * (simd_group_id % WN); + + const int sgp_sm_int = + align_M ? int(SM) : min(int(SM), params->M - (c_row + tm)); + const short sgp_sm = short(sgp_sm_int); + const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); + + const int sgp_sn_int = + align_N ? int(SN) : min(int(SN), params->N - (c_col + tn)); + const short sgp_sn = short(sgp_sn_int); + const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN); + + A += transpose_a ? tm : (tm * params->lda); + B += transpose_b ? (tn * params->ldb) : tn; + C += tm * params->ldd + tn; + rhs_indices += tm; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = rhs_indices[0]; + short offset_next = 0; + int n = 0; + while (n < sgp_sm) { + n++; + offset = offset_next; + index = index_next; + offset_next = sgp_sm; + for (; n < sgp_sm; n++) { + if (rhs_indices[n] != index) { + offset_next = n; + index_next = rhs_indices[n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + using DSubTile = NAXSubTile; + NAXTile Ctile; + + dispatch_bool(align_K, [&](auto kAlignedK) { + dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) { + auto do_gemm = gemm_loop< + T, + SM, + SN, + SK, + BK, + transpose_a, + transpose_b, + kAlignedM.value, + kAlignedN.value, + kAlignedK.value, + UM, + UN, + UK, + AccumType>; + Ctile = do_gemm( + A, + B + index * params->batch_stride_b, + params->lda, + params->ldb, + params->K, + params->gemm_k_iterations_aligned, + sgp_sm, + sgp_sn); + + if constexpr (kAlignedN.value) { + if (offset_next - offset == SM) { + Ctile.store(C, int(params->ldd)); + } else { + Ctile.store_slice( + C, + int(params->ldd), + short2(0, offset), + short2(SN, offset_next)); + } + } else { + Ctile.store_slice( + C, + int(params->ldd), + short2(0, offset), + short2(sgp_sn, offset_next)); + } + }); + }); + }); + } +} diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_masked.h b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_masked.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_masked.h rename to Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_masked.h diff --git a/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_segmented.h b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_segmented.h new file mode 100644 index 00000000..5a43e223 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_segmented.h @@ -0,0 +1,266 @@ +// Copyright © 2025 Apple Inc. + +using namespace mlx::steel; + +constant bool segments_contiguous [[function_constant(199)]]; +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void segmented_mm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* segments [[buffer(2)]], + device T* C [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + if (params->tiles_n <= static_cast(tid.x) || + params->tiles_m <= static_cast(tid.y)) { + return; + } + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Find the block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + // Move the pointers to the output tile + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + + // Move the pointers to the start of the segment + uint32_t k_start, k_end; + if (segments_contiguous) { + k_start = segments[2 * tid.z]; + k_end = segments[2 * tid.z + 1]; + } else { + // We accept either contiguous (above) or weird strides where the beginning + // of the next one is the previous one. Basically the last two strides are + // both 1! + k_start = segments[tid.z]; + k_end = segments[tid.z + 1]; + } + A += transpose_a ? k_start * params->lda : k_start; + B += transpose_b ? k_start : k_start * params->ldb; + C += tid.z * params->batch_stride_d; + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Matrix level alignment so only check K + if (align_M && align_N) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result(C, params->ldd); + } else { + // Tile aligned do the same as above + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result(C, params->ldd); + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_safe( + transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm)); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_safe( + transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Nothing aligned so check both rows and cols + else { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_safe( + transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm)); + loader_b.load_safe( + transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} diff --git a/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_splitk.h b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_splitk.h new file mode 100644 index 00000000..a372e939 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_splitk.h @@ -0,0 +1,227 @@ +// Copyright © 2024 Apple Inc. + +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device U* C [[buffer(2)]], + const constant GEMMSpiltKParams* params [[buffer(3)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + (void)lid; + + using gemm_kernel = GEMMKernel< + T, + U, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + MN_aligned, + K_aligned>; + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + const int tid_x = tid.x; + const int tid_y = tid.y; + const int tid_z = tid.z; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const int k_start = params->split_k_partition_size * tid_z; + + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + const size_t k_start_long = size_t(k_start); + + A += transpose_a ? (c_row_long + k_start_long * params->lda) + : (k_start_long + c_row_long * params->lda); + B += transpose_b ? (k_start_long + c_col_long * params->ldb) + : (c_col_long + k_start_long * params->ldb); + C += (size_t(params->split_k_partition_stride) * tid_z) + + (c_row_long * params->ldc + c_col_long); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K % BK; + + if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } else if (tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } else if (tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if ((tid_z + 1) == (params->split_k_partitions)) { + int gemm_k_iter_remaining = + (params->K - (k_start + params->split_k_partition_size)) / BK; + if (!K_aligned || gemm_k_iter_remaining > 0) + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iter_remaining, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } + + if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + mma_op.store_result(C, params->ldc); + } else { + mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Split k accumulation kernel +/////////////////////////////////////////////////////////////////////////////// + +template < + typename AccT, + typename OutT, + typename Epilogue = TransformNone> +[[kernel]] void gemm_splitk_accum( + const device AccT* C_split [[buffer(0)]], + device OutT* D [[buffer(1)]], + const constant int& k_partitions [[buffer(2)]], + const constant int& partition_stride [[buffer(3)]], + const constant int& ldd [[buffer(4)]], + uint2 gid [[thread_position_in_grid]]) { + // Ajust D and C + D += gid.x + gid.y * size_t(ldd); + C_split += gid.x + gid.y * size_t(ldd); + + size_t offset = 0; + AccT out = 0; + + for (int i = 0; i < k_partitions; i++) { + out += C_split[offset]; + offset += partition_stride; + } + + // Write output + D[0] = Epilogue::apply(out); +} + +template < + typename AccT, + typename OutT, + typename Epilogue = TransformAxpby> +[[kernel]] void gemm_splitk_accum_axpby( + const device AccT* C_split [[buffer(0)]], + device OutT* D [[buffer(1)]], + const constant int& k_partitions [[buffer(2)]], + const constant int& partition_stride [[buffer(3)]], + const constant int& ldd [[buffer(4)]], + const device OutT* C [[buffer(5)]], + const constant int& ldc [[buffer(6)]], + const constant int& fdc [[buffer(7)]], + const constant float& alpha [[buffer(8)]], + const constant float& beta [[buffer(9)]], + uint2 gid [[thread_position_in_grid]]) { + // Ajust D and C + C += gid.x * size_t(fdc) + gid.y * size_t(ldc); + D += gid.x + gid.y * size_t(ldd); + C_split += gid.x + gid.y * size_t(ldd); + + size_t offset = 0; + AccT out = 0; + + for (int i = 0; i < k_partitions; i++) { + out += C_split[offset]; + offset += partition_stride; + } + + // Write output + Epilogue op(alpha, beta); + D[0] = op.apply(out, *C); +} diff --git a/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_splitk_nax.h b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_splitk_nax.h new file mode 100644 index 00000000..1b6b8280 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_splitk_nax.h @@ -0,0 +1,152 @@ +// Copyright © 2026 Apple Inc. + +using namespace mlx::steel; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; + +/////////////////////////////////////////////////////////////////////////////// +// NAX Split-K GEMM kernel +/////////////////////////////////////////////////////////////////////////////// + +// clang-format off +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk_nax( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device AccumType* C [[buffer(2)]], + const constant GEMMSpiltKParams* params [[buffer(3)]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { // clang-format on + + const int linear_tid = tid.x; + + // Compute swizzled tile dimensions + const int tn_swizzled = params->tiles_n << params->swizzle_log; + const int tm_swizzled = + (params->tiles_m + (1 << params->swizzle_log) - 1) >> params->swizzle_log; + const int tiles_per_partition = tn_swizzled * tm_swizzled; + + const int tid_z = linear_tid / tiles_per_partition; + const int xy_flat = linear_tid % tiles_per_partition; + + // Decode 2D grid coordinates in swizzled space + const int grid_x = xy_flat % tn_swizzled; + const int grid_y = xy_flat / tn_swizzled; + + // Apply X-Y swizzle + const int tid_y = (grid_y << params->swizzle_log) + + (grid_x & ((1 << params->swizzle_log) - 1)); + const int tid_x = grid_x >> params->swizzle_log; + + // Exit early + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + // Calculate partition bounds + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const int k_start = params->split_k_partition_size * tid_z; + const int k_end = min(k_start + params->split_k_partition_size, params->K); + + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + const size_t k_start_long = size_t(k_start); + + // Adjust pointers for split-K partition + A += transpose_a ? (c_row_long + k_start_long * params->lda) + : (k_start_long + c_row_long * params->lda); + B += transpose_b ? (k_start_long + c_col_long * params->ldb) + : (c_col_long + k_start_long * params->ldb); + C += (size_t(params->split_k_partition_stride) * tid_z) + + (c_row_long * params->ldc + c_col_long); + + // NAX tile configuration + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + + // Calculate simdgroup offsets and alignment + const short tm = SM * (simd_group_id / WN); + const short tn = SN * (simd_group_id % WN); + + const int sgp_sm_int = + align_M ? int(SM) : min(int(SM), params->M - (c_row + tm)); + const short sgp_sm = short(sgp_sm_int); + const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); + + const int sgp_sn_int = + align_N ? int(SN) : min(int(SN), params->N - (c_col + tn)); + const short sgp_sn = short(sgp_sn_int); + const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN); + + A += transpose_a ? tm : (tm * params->lda); + B += transpose_b ? (tn * params->ldb) : tn; + C += tm * params->ldc + tn; + + using DSubTile = NAXSubTile; + NAXTile Dtile; + + // gemm_loop through the partition + // Check K-alignment at runtime (partition-specific) + const int partition_k_size = k_end - k_start; + const int partition_k_iters = partition_k_size / BK; + const bool partition_k_aligned = (partition_k_size % BK) == 0; + + dispatch_bool(partition_k_aligned, [&](auto kAlignedK) { + dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) { + Dtile = gemm_loop< + T, + SM, + SN, + SK, + BK, + transpose_a, + transpose_b, + kAlignedM.value, + kAlignedN.value, + kAlignedK.value, + UM, + UN, + UK, + AccumType>( + A, + B, + params->lda, + params->ldb, + partition_k_size, + partition_k_iters, + sgp_sm, + sgp_sn); + }); + }); + }); + + // Store result + dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) { + if constexpr (kAlignedM && kAlignedN) { + Dtile.store(C, int(params->ldc)); + } else { + Dtile.store_safe(C, int(params->ldc), short2(sgp_sn, sgp_sm)); + } + }); + }); +} diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/loader.h b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/loader.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/gemm/loader.h rename to Source/Cxxmlx/mlx-generated/metal/steel/gemm/loader.h diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/mma.h b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/mma.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/gemm/mma.h rename to Source/Cxxmlx/mlx-generated/metal/steel/gemm/mma.h diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/nax.h b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/nax.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/gemm/nax.h rename to Source/Cxxmlx/mlx-generated/metal/steel/gemm/nax.h diff --git a/Source/Cxxmlx/mlx-generated/metal/steel/gemm/params.h b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/params.h new file mode 100644 index 00000000..b0ba07dd --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/params.h @@ -0,0 +1,65 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +/////////////////////////////////////////////////////////////////////////////// +// GEMM param classes +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +struct GEMMParams { + const int M; + const int N; + const int K; + + const int lda; + const int ldb; + const int ldd; + + const int tiles_n; + const int tiles_m; + + const int64_t batch_stride_a; + const int64_t batch_stride_b; + const int64_t batch_stride_d; + + const int swizzle_log; + const int gemm_k_iterations_aligned; + + const int batch_ndim; +}; + +struct GEMMSpiltKParams { + const int M; + const int N; + const int K; + + const int lda; + const int ldb; + const int ldc; + + const int tiles_n; + const int tiles_m; + + const int split_k_partitions; + const int split_k_partition_stride; + const int split_k_partition_size; + + const int swizzle_log; + const int gemm_k_iterations_aligned; +}; + +struct GEMMAddMMParams { + const int ldc; + const int fdc; + + const int64_t batch_stride_c; + + const float alpha; + const float beta; +}; + +} // namespace steel +} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/transforms.h b/Source/Cxxmlx/mlx-generated/metal/steel/gemm/transforms.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/gemm/transforms.h rename to Source/Cxxmlx/mlx-generated/metal/steel/gemm/transforms.h diff --git a/Source/Cxxmlx/mlx-generated/metal/steel/utils.h b/Source/Cxxmlx/mlx-generated/metal/steel/utils.h new file mode 100644 index 00000000..55720a28 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/steel/utils.h @@ -0,0 +1,42 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include + +METAL_FUNC ulong2 elem_to_loc_broadcast( + uint elem, + constant const int* shape, + constant const int64_t* a_strides, + constant const int64_t* b_strides, + int ndim) { + ulong loc_a{0}; + ulong loc_b{0}; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + int pos_in_dim = (elem % shape[i]); + elem /= shape[i]; + loc_a += pos_in_dim * a_strides[i]; + loc_b += pos_in_dim * b_strides[i]; + } + return ulong2(loc_a, loc_b); +} + +METAL_FUNC ulong3 elem_to_loc_broadcast( + uint elem, + constant const int* shape, + constant const int64_t* a_strides, + constant const int64_t* b_strides, + constant const int64_t* c_strides, + int ndim) { + ulong loc_a{0}; + ulong loc_b{0}; + ulong loc_c{0}; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + int pos_in_dim = (elem % shape[i]); + elem /= shape[i]; + loc_a += pos_in_dim * a_strides[i]; + loc_b += pos_in_dim * b_strides[i]; + loc_c += pos_in_dim * c_strides[i]; + } + return ulong3(loc_a, loc_b, loc_c); +} diff --git a/Source/Cmlx/mlx-generated/metal/steel/utils/integral_constant.h b/Source/Cxxmlx/mlx-generated/metal/steel/utils/integral_constant.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/steel/utils/integral_constant.h rename to Source/Cxxmlx/mlx-generated/metal/steel/utils/integral_constant.h diff --git a/Source/Cxxmlx/mlx-generated/metal/steel/utils/type_traits.h b/Source/Cxxmlx/mlx-generated/metal/steel/utils/type_traits.h new file mode 100644 index 00000000..f004dc83 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/steel/utils/type_traits.h @@ -0,0 +1,55 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include + +#pragma METAL internals : enable + +namespace metal { + +template +struct is_empty : metal::bool_constant<__is_empty(T)> {}; + +#ifdef __cpp_variable_templates +template +constexpr constant bool is_empty_v = is_empty::value; +#endif + +template +struct make_void { + typedef void type; +}; + +template +using void_t = typename make_void::type; + +template +struct is_static : metal::bool_constant>::value> {}; + +template +struct pointer_element {}; + +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; + +template +using pointer_element_t = typename pointer_element>::type; + +} // namespace metal + +#pragma METAL internals : disable \ No newline at end of file diff --git a/Source/Cxxmlx/mlx-generated/metal/ternary.h b/Source/Cxxmlx/mlx-generated/metal/ternary.h new file mode 100644 index 00000000..705b73e2 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/ternary.h @@ -0,0 +1,145 @@ +// Copyright © 2024 Apple Inc. + +template < + typename T, + typename Op, + bool BSCALAR, + bool CSCALAR, + int N = WorkPerThread::n> +[[kernel]] void ternary_v( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant uint& size, + uint index [[thread_position_in_grid]]) { + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto bidx = BSCALAR ? 0 : index + i; + auto cidx = CSCALAR ? 0 : index + i; + d[index + i] = Op()(a[index + i], b[bidx], c[cidx]); + } + } else { + for (int i = 0; i < N; ++i) { + auto bidx = BSCALAR ? 0 : index + i; + auto cidx = CSCALAR ? 0 : index + i; + d[index + i] = Op()(a[index + i], b[bidx], c[cidx]); + } + } +} + +template < + typename T, + typename Op, + bool BSCALAR, + bool CSCALAR, + int N = WorkPerThread::n> +[[kernel]] void ternary_v2( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant int64_t& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto bidx = BSCALAR ? 0 : offset + i; + auto cidx = CSCALAR ? 0 : offset + i; + d[offset + i] = Op()(a[offset + i], b[bidx], c[cidx]); + } + } else { + for (int i = 0; i < N; ++i) { + auto bidx = BSCALAR ? 0 : offset + i; + auto cidx = CSCALAR ? 0 : offset + i; + d[offset + i] = Op()(a[offset + i], b[bidx], c[cidx]); + } + } +} + +template +[[kernel]] void ternary_g_nd1( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant const int64_t& a_strides, + constant const int64_t& b_strides, + constant const int64_t& c_strides, + uint index [[thread_position_in_grid]]) { + auto a_idx = elem_to_loc_1(index, a_strides); + auto b_idx = elem_to_loc_1(index, b_strides); + auto c_idx = elem_to_loc_1(index, c_strides); + d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]); +} + +template +[[kernel]] void ternary_g_nd2( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant const int64_t a_strides[2], + constant const int64_t b_strides[2], + constant const int64_t c_strides[2], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + auto c_idx = elem_to_loc_2(index, c_strides); + IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; + d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); +} + +template +[[kernel]] void ternary_g_nd3( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant const int64_t a_strides[3], + constant const int64_t b_strides[3], + constant const int64_t c_strides[3], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + auto c_idx = elem_to_loc_3(index, c_strides); + IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); + d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); +} + +template +[[kernel]] void ternary_g( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant const int* shape, + constant const int64_t* a_strides, + constant const int64_t* b_strides, + constant const int64_t* c_strides, + constant const int& ndim, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = elem_to_loc_3_nd( + {N * index.x, index.y, index.z}, + shape, + a_strides, + b_strides, + c_strides, + ndim); + auto xshape = shape[ndim - 1]; + IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); + IdxT a_xstride = a_strides[ndim - 1]; + IdxT b_xstride = b_strides[ndim - 1]; + IdxT c_xstride = c_strides[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]); + idx.x += a_xstride; + idx.y += b_xstride; + idx.z += c_xstride; + } +} diff --git a/Source/Cxxmlx/mlx-generated/metal/ternary_ops.h b/Source/Cxxmlx/mlx-generated/metal/ternary_ops.h new file mode 100644 index 00000000..e0235d9d --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/ternary_ops.h @@ -0,0 +1,10 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +struct Select { + template + T operator()(bool condition, T x, T y) { + return condition ? x : y; + } +}; diff --git a/Source/Cxxmlx/mlx-generated/metal/unary.h b/Source/Cxxmlx/mlx-generated/metal/unary.h new file mode 100644 index 00000000..db7be3d4 --- /dev/null +++ b/Source/Cxxmlx/mlx-generated/metal/unary.h @@ -0,0 +1,63 @@ +// Copyright © 2024 Apple Inc. + +template ::n> +[[kernel]] void unary_v( + device const T* in, + device U* out, + constant uint& size, + uint index [[thread_position_in_grid]]) { + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + out[index + i] = static_cast(Op()(in[index + i])); + } + } else { + for (int i = 0; i < N; ++i) { + out[index + i] = static_cast(Op()(in[index + i])); + } + } +} + +template ::n> +[[kernel]] void unary_v2( + device const T* in, + device U* out, + constant int64_t& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + out[offset + i] = static_cast(Op()(in[offset + i])); + } + } else { + for (int i = 0; i < N; ++i) { + out[offset + i] = static_cast(Op()(in[offset + i])); + } + } +} + +template < + typename T, + typename U, + typename Op, + int N = 1, + typename IdxT = int64_t> +[[kernel]] void unary_g( + device const T* in, + device U* out, + constant const int* in_shape, + constant const int64_t* in_strides, + device const int& ndim, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = elem_to_loc( + {N * index.x, index.y, index.z}, in_shape, in_strides, ndim); + auto xshape = in_shape[ndim - 1]; + IdxT xstride = in_strides[ndim - 1]; + IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + out[out_idx++] = static_cast(Op()(in[idx])); + idx += xstride; + } +} diff --git a/Source/Cmlx/mlx-generated/metal/unary_ops.h b/Source/Cxxmlx/mlx-generated/metal/unary_ops.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/unary_ops.h rename to Source/Cxxmlx/mlx-generated/metal/unary_ops.h diff --git a/Source/Cmlx/mlx-generated/metal/utils.h b/Source/Cxxmlx/mlx-generated/metal/utils.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/utils.h rename to Source/Cxxmlx/mlx-generated/metal/utils.h diff --git a/Source/Cmlx/mlx-generated/quantized.cpp b/Source/Cxxmlx/mlx-generated/quantized.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/quantized.cpp rename to Source/Cxxmlx/mlx-generated/quantized.cpp diff --git a/Source/Cmlx/mlx-generated/quantized_nax.cpp b/Source/Cxxmlx/mlx-generated/quantized_nax.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/quantized_nax.cpp rename to Source/Cxxmlx/mlx-generated/quantized_nax.cpp diff --git a/Source/Cmlx/mlx-generated/quantized_utils.cpp b/Source/Cxxmlx/mlx-generated/quantized_utils.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/quantized_utils.cpp rename to Source/Cxxmlx/mlx-generated/quantized_utils.cpp diff --git a/Source/Cmlx/mlx-generated/reduce.cpp b/Source/Cxxmlx/mlx-generated/reduce.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/reduce.cpp rename to Source/Cxxmlx/mlx-generated/reduce.cpp diff --git a/Source/Cmlx/mlx-generated/reduce_utils.cpp b/Source/Cxxmlx/mlx-generated/reduce_utils.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/reduce_utils.cpp rename to Source/Cxxmlx/mlx-generated/reduce_utils.cpp diff --git a/Source/Cmlx/mlx-generated/scan.cpp b/Source/Cxxmlx/mlx-generated/scan.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/scan.cpp rename to Source/Cxxmlx/mlx-generated/scan.cpp diff --git a/Source/Cmlx/mlx-generated/scatter.cpp b/Source/Cxxmlx/mlx-generated/scatter.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/scatter.cpp rename to Source/Cxxmlx/mlx-generated/scatter.cpp diff --git a/Source/Cmlx/mlx-generated/scatter_axis.cpp b/Source/Cxxmlx/mlx-generated/scatter_axis.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/scatter_axis.cpp rename to Source/Cxxmlx/mlx-generated/scatter_axis.cpp diff --git a/Source/Cmlx/mlx-generated/softmax.cpp b/Source/Cxxmlx/mlx-generated/softmax.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/softmax.cpp rename to Source/Cxxmlx/mlx-generated/softmax.cpp diff --git a/Source/Cmlx/mlx-generated/sort.cpp b/Source/Cxxmlx/mlx-generated/sort.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/sort.cpp rename to Source/Cxxmlx/mlx-generated/sort.cpp diff --git a/Source/Cmlx/mlx-generated/steel_attention.cpp b/Source/Cxxmlx/mlx-generated/steel_attention.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/steel_attention.cpp rename to Source/Cxxmlx/mlx-generated/steel_attention.cpp diff --git a/Source/Cmlx/mlx-generated/steel_attention_nax.cpp b/Source/Cxxmlx/mlx-generated/steel_attention_nax.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/steel_attention_nax.cpp rename to Source/Cxxmlx/mlx-generated/steel_attention_nax.cpp diff --git a/Source/Cmlx/mlx-generated/steel_conv.cpp b/Source/Cxxmlx/mlx-generated/steel_conv.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/steel_conv.cpp rename to Source/Cxxmlx/mlx-generated/steel_conv.cpp diff --git a/Source/Cmlx/mlx-generated/steel_conv_3d.cpp b/Source/Cxxmlx/mlx-generated/steel_conv_3d.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/steel_conv_3d.cpp rename to Source/Cxxmlx/mlx-generated/steel_conv_3d.cpp diff --git a/Source/Cmlx/mlx-generated/steel_conv_general.cpp b/Source/Cxxmlx/mlx-generated/steel_conv_general.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/steel_conv_general.cpp rename to Source/Cxxmlx/mlx-generated/steel_conv_general.cpp diff --git a/Source/Cmlx/mlx-generated/steel_gemm_fused.cpp b/Source/Cxxmlx/mlx-generated/steel_gemm_fused.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/steel_gemm_fused.cpp rename to Source/Cxxmlx/mlx-generated/steel_gemm_fused.cpp diff --git a/Source/Cmlx/mlx-generated/steel_gemm_fused_nax.cpp b/Source/Cxxmlx/mlx-generated/steel_gemm_fused_nax.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/steel_gemm_fused_nax.cpp rename to Source/Cxxmlx/mlx-generated/steel_gemm_fused_nax.cpp diff --git a/Source/Cmlx/mlx-generated/steel_gemm_gather.cpp b/Source/Cxxmlx/mlx-generated/steel_gemm_gather.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/steel_gemm_gather.cpp rename to Source/Cxxmlx/mlx-generated/steel_gemm_gather.cpp diff --git a/Source/Cmlx/mlx-generated/steel_gemm_gather_nax.cpp b/Source/Cxxmlx/mlx-generated/steel_gemm_gather_nax.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/steel_gemm_gather_nax.cpp rename to Source/Cxxmlx/mlx-generated/steel_gemm_gather_nax.cpp diff --git a/Source/Cmlx/mlx-generated/steel_gemm_masked.cpp b/Source/Cxxmlx/mlx-generated/steel_gemm_masked.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/steel_gemm_masked.cpp rename to Source/Cxxmlx/mlx-generated/steel_gemm_masked.cpp diff --git a/Source/Cmlx/mlx-generated/steel_gemm_segmented.cpp b/Source/Cxxmlx/mlx-generated/steel_gemm_segmented.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/steel_gemm_segmented.cpp rename to Source/Cxxmlx/mlx-generated/steel_gemm_segmented.cpp diff --git a/Source/Cmlx/mlx-generated/steel_gemm_splitk.cpp b/Source/Cxxmlx/mlx-generated/steel_gemm_splitk.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/steel_gemm_splitk.cpp rename to Source/Cxxmlx/mlx-generated/steel_gemm_splitk.cpp diff --git a/Source/Cmlx/mlx-generated/steel_gemm_splitk_nax.cpp b/Source/Cxxmlx/mlx-generated/steel_gemm_splitk_nax.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/steel_gemm_splitk_nax.cpp rename to Source/Cxxmlx/mlx-generated/steel_gemm_splitk_nax.cpp diff --git a/Source/Cmlx/mlx-generated/ternary.cpp b/Source/Cxxmlx/mlx-generated/ternary.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/ternary.cpp rename to Source/Cxxmlx/mlx-generated/ternary.cpp diff --git a/Source/Cmlx/mlx-generated/ternary_ops.cpp b/Source/Cxxmlx/mlx-generated/ternary_ops.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/ternary_ops.cpp rename to Source/Cxxmlx/mlx-generated/ternary_ops.cpp diff --git a/Source/Cmlx/mlx-generated/unary.cpp b/Source/Cxxmlx/mlx-generated/unary.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/unary.cpp rename to Source/Cxxmlx/mlx-generated/unary.cpp diff --git a/Source/Cmlx/mlx-generated/unary_ops.cpp b/Source/Cxxmlx/mlx-generated/unary_ops.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/unary_ops.cpp rename to Source/Cxxmlx/mlx-generated/unary_ops.cpp diff --git a/Source/Cmlx/mlx-generated/utils.cpp b/Source/Cxxmlx/mlx-generated/utils.cpp similarity index 100% rename from Source/Cmlx/mlx-generated/utils.cpp rename to Source/Cxxmlx/mlx-generated/utils.cpp diff --git a/Source/Cmlx/vendor-README.md b/Source/Cxxmlx/vendor-README.md similarity index 100% rename from Source/Cmlx/vendor-README.md rename to Source/Cxxmlx/vendor-README.md diff --git a/tools/fix-metal-includes.sh b/tools/fix-metal-includes.sh index 622d4311..7d6fbe5f 100755 --- a/tools/fix-metal-includes.sh +++ b/tools/fix-metal-includes.sh @@ -7,18 +7,18 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) ROOT_DIR=$(realpath "${SCRIPT_DIR}/..") # Where the files end up -OUTPUT_DIR="${ROOT_DIR}/Source/Cmlx/mlx-generated/metal" +OUTPUT_DIR="${ROOT_DIR}/Source/Cxxmlx/mlx-generated/metal" -# The Cmlx source dir -CMLX_MLX_DIR="${ROOT_DIR}/Source/Cmlx/mlx" +# The Cxxmlx source dir +CXXMLX_MLX_DIR="${ROOT_DIR}/Source/Cxxmlx/mlx" # sub-directory of Cmlx source containing the kernels KERNELS_INCLUDE_PATH="mlx/backend/metal/kernels" -KERNELS_DIR="${CMLX_MLX_DIR}/${KERNELS_INCLUDE_PATH}" +KERNELS_DIR="${CXXMLX_MLX_DIR}/${KERNELS_INCLUDE_PATH}" # list of kernels files to process -# see Source/Cmlx/mlx/mlx/backend/metal/kernels/CMakeLists.txt +# see Source/Cxxmlx/mlx/mlx/backend/metal/kernels/CMakeLists.txt KERNEL_LIST=" \ arg_reduce.metal \ conv.metal \ diff --git a/tools/update-mlx-xcodeproj.sh b/tools/update-mlx-xcodeproj.sh index a37d04d9..911d49f5 100755 --- a/tools/update-mlx-xcodeproj.sh +++ b/tools/update-mlx-xcodeproj.sh @@ -79,7 +79,7 @@ do h=mlx-`echo $x | tr / -` d=Source/Cmlx/include-framework/$h echo "#ifdef __cplusplus" > $d - cat Source/Cmlx/mlx/mlx/$x | sed -e 's:backend/:backend-:g' -e 's:cuda/:cuda-:g' -e 's:gpu/:gpu-:g' -e 's:metal/:metal-:g' -e 's:distributed/:distributed-:g' -e 's:types/:types-:' -e 's:io/:io-:' -e 's:common/:common-:' -e 's:cpu/:cpu-:' -e 's:#include "mlx/:#include :g' -e 's:Metal/Metal.hpp:Cmlx/Metal.hpp:g' >> $d + cat Source/Cxxmlx/mlx/mlx/$x | sed -e 's:backend/:backend-:g' -e 's:cuda/:cuda-:g' -e 's:gpu/:gpu-:g' -e 's:metal/:metal-:g' -e 's:distributed/:distributed-:g' -e 's:types/:types-:' -e 's:io/:io-:' -e 's:common/:common-:' -e 's:cpu/:cpu-:' -e 's:#include "mlx/:#include :g' -e 's:Metal/Metal.hpp:Cmlx/Metal.hpp:g' >> $d echo "#endif" >> $d # add to Cmlx @@ -87,7 +87,7 @@ do done # build & copy in the Metal.hpp header -(cd Source/Cmlx/metal-cpp; ./SingleHeader/MakeSingleHeader.py -o ../include-framework/Metal.hpp.in Foundation/Foundation.hpp QuartzCore/QuartzCore.hpp Metal/Metal.hpp MetalFX/MetalFX.hpp) +(cd Source/Cxxmlx/metal-cpp; ./SingleHeader/MakeSingleHeader.py -o ../../Cmlx/include-framework/Metal.hpp.in Foundation/Foundation.hpp QuartzCore/QuartzCore.hpp Metal/Metal.hpp MetalFX/MetalFX.hpp) echo "#ifdef __cplusplus" > Source/Cmlx/include-framework/Metal.hpp cat Source/Cmlx/include-framework/Metal.hpp.in >> Source/Cmlx/include-framework/Metal.hpp diff --git a/tools/update-mlx.sh b/tools/update-mlx.sh index fc1ab961..ed600b61 100755 --- a/tools/update-mlx.sh +++ b/tools/update-mlx.sh @@ -14,11 +14,31 @@ fi rm -f Source/Cmlx/include/mlx/c/* cp Source/Cmlx/mlx-c/mlx/c/*.h Source/Cmlx/include/mlx/c -# run the command to do the build-time code generation for Metal +# copy mlx C++ public headers to build area +rm -rf Source/Cxxmlx/include/mlx +mkdir -p Source/Cxxmlx/include/mlx +rsync -a \ + --include='*/' \ + --include='*.h' \ + --include='*.hpp' \ + --exclude='*' \ + Source/Cxxmlx/mlx/mlx/ \ + Source/Cxxmlx/include/mlx/ + +# copy metal-cpp headers used by public MLX Metal backend headers +for header_dir in Foundation Metal MetalFX QuartzCore +do + rm -rf "Source/Cxxmlx/include/${header_dir}" + rsync -a \ + "Source/Cxxmlx/metal-cpp/${header_dir}" \ + Source/Cxxmlx/include/ +done + +# run the command to do the build-time code generation mkdir build cd build -cmake ../Source/Cmlx/mlx -DMLX_METAL_JIT=ON -DMACOS_VERSION=14.0 +cmake ../Source/Cxxmlx/mlx -DMLX_METAL_JIT=ON -DMACOS_VERSION=14.0 # run the cmake build to generate the source files cd mlx/backend/metal @@ -75,28 +95,28 @@ make cpu_compiled_preamble # run the command to do the build-time code generation for CUDA cmake \ - -DMLX_SOURCE_ROOT="../Source/Cmlx/mlx/mlx/backend/cuda" \ + -DMLX_SOURCE_ROOT="../Source/Cxxmlx/mlx/mlx/backend/cuda" \ -DMLX_JIT_SOURCES="device/atomic_ops.cuh:device/binary_ops.cuh:device/cast_op.cuh:device/complex.cuh:device/config.h:device/fp16_math.cuh:device/gather.cuh:device/gather_axis.cuh:device/hadamard.cuh:device/indexing.cuh:device/scatter.cuh:device/scatter_axis.cuh:device/scatter_ops.cuh:device/ternary_ops.cuh:device/unary_ops.cuh:device/utils.cuh" \ - -P "../Source/Cmlx/mlx/mlx/backend/cuda/bin2h.cmake" + -P "../Source/Cxxmlx/mlx/mlx/backend/cuda/bin2h.cmake" cd .. -rm -rf Source/Cmlx/mlx-generated/metal -rm -rf Source/Cmlx/mlx-generated/cuda -rm -f Source/Cmlx/mlx-generated/* -mkdir -p Source/Cmlx/mlx-generated/cuda -cp build/mlx/backend/metal/jit/* Source/Cmlx/mlx-generated -cp build/mlx/backend/cpu/compiled_preamble.cpp Source/Cmlx/mlx-generated -cp build/gen/cuda_jit_sources.h Source/Cmlx/mlx-generated/cuda +rm -rf Source/Cxxmlx/mlx-generated/metal +rm -rf Source/Cxxmlx/mlx-generated/cuda +rm -f Source/Cxxmlx/mlx-generated/* +mkdir -p Source/Cxxmlx/mlx-generated/cuda +cp build/mlx/backend/metal/jit/* Source/Cxxmlx/mlx-generated +cp build/mlx/backend/cpu/compiled_preamble.cpp Source/Cxxmlx/mlx-generated +cp build/gen/cuda_jit_sources.h Source/Cxxmlx/mlx-generated/cuda # we don't need the cmake build directory any more rm -rf build # remove any absolute paths and make them relative to the package root -for x in Source/Cmlx/mlx-generated/*.cpp ; do \ +for x in Source/Cxxmlx/mlx-generated/*.cpp ; do \ sed -i .tmp -e "s:`pwd`/::g" $x done; -rm Source/Cmlx/mlx-generated/*.tmp +rm Source/Cxxmlx/mlx-generated/*.tmp # Update the headers ./tools/fix-metal-includes.sh diff --git a/xcode/MLX.xcodeproj/project.pbxproj b/xcode/MLX.xcodeproj/project.pbxproj index 87b93ba1..9f2f6112 100644 --- a/xcode/MLX.xcodeproj/project.pbxproj +++ b/xcode/MLX.xcodeproj/project.pbxproj @@ -1441,37 +1441,42 @@ }; C3AE90232EAAA47E000BD280 /* json */ = { isa = PBXFileSystemSynchronizedRootGroup; - path = json; - sourceTree = ""; + name = json; + path = ../Source/Cxxmlx/json; + sourceTree = SOURCE_ROOT; }; C3AE908B2EAAA47E000BD280 /* metal-cpp */ = { isa = PBXFileSystemSynchronizedRootGroup; exceptions = ( C3AE939F2EAAAA5A000BD280 /* Exceptions for "metal-cpp" folder in "Cmlx" target */, ); - path = "metal-cpp"; - sourceTree = ""; + name = "metal-cpp"; + path = "../Source/Cxxmlx/metal-cpp"; + sourceTree = SOURCE_ROOT; }; C3AE97002EAAAAAD000BD280 /* mlx */ = { isa = PBXFileSystemSynchronizedRootGroup; exceptions = ( C3AE9EA62EAAABFC000BD280 /* Exceptions for "mlx" folder in "Cmlx" target */, ); - path = mlx; - sourceTree = ""; + name = mlx; + path = ../Source/Cxxmlx/mlx; + sourceTree = SOURCE_ROOT; }; C3AE97872EAAAAAD000BD280 /* mlx-conditional */ = { isa = PBXFileSystemSynchronizedRootGroup; - path = "mlx-conditional"; - sourceTree = ""; + name = "mlx-conditional"; + path = "../Source/Cxxmlx/mlx-conditional"; + sourceTree = SOURCE_ROOT; }; C3AE98112EAAAAAD000BD280 /* mlx-generated */ = { isa = PBXFileSystemSynchronizedRootGroup; exceptions = ( C3CB32A92EB168CD0029A645 /* Exceptions for "mlx-generated" folder in "Cmlx" target */, ); - path = "mlx-generated"; - sourceTree = ""; + name = "mlx-generated"; + path = "../Source/Cxxmlx/mlx-generated"; + sourceTree = SOURCE_ROOT; }; C3AE9D992EAAAB29000BD280 /* mlx-c */ = { isa = PBXFileSystemSynchronizedRootGroup; @@ -1523,8 +1528,9 @@ exceptions = ( C3CBF3322EAC23D80029A645 /* Exceptions for "fmt" folder in "Cmlx" target */, ); - path = fmt; - sourceTree = ""; + name = fmt; + path = ../Source/Cxxmlx/fmt; + sourceTree = SOURCE_ROOT; }; C3CBF3382EAC243B0029A645 /* tools */ = { isa = PBXFileSystemSynchronizedRootGroup; diff --git a/xcode/xcconfig/Cmlx.xcconfig b/xcode/xcconfig/Cmlx.xcconfig index 37d66860..5fab0572 100644 --- a/xcode/xcconfig/Cmlx.xcconfig +++ b/xcode/xcconfig/Cmlx.xcconfig @@ -5,13 +5,12 @@ PRODUCT_BUNDLE_IDENTIFIER = com.apple.mlx.Cmlx DEFINES_MODULE = YES BUILD_LIBRARY_FOR_DISTRIBUTION = YES -HEADER_SEARCH_PATHS = $(SDKROOT)/usr/include/c++/v1 $(SDKROOT)/usr/include $(inherited) $(SRCROOT)/../Source/Cmlx/metal-cpp $(SRCROOT)/../Source/Cmlx/fmt/include $(SRCROOT)/../Source/Cmlx/json/single_include/nlohmann +HEADER_SEARCH_PATHS = $(SDKROOT)/usr/include/c++/v1 $(SDKROOT)/usr/include $(inherited) $(SRCROOT)/../Source/Cxxmlx/metal-cpp $(SRCROOT)/../Source/Cxxmlx/fmt/include $(SRCROOT)/../Source/Cxxmlx/json/single_include/nlohmann OTHER_CFLAGS = -isysroot $(SDKROOT) -USER_HEADER_SEARCH_PATHS = $(SRCROOT)/../Source/Cmlx/mlx-c $(SRCROOT)/../Source/Cmlx/mlx +USER_HEADER_SEARCH_PATHS = $(SRCROOT)/../Source/Cmlx/mlx-c $(SRCROOT)/../Source/Cxxmlx/mlx MTL_COMPILER_FLAGS = -Wall -Wextra -fno-fast-math -Wno-c++17-extensions -Wno-c++20-extensions GCC_PREPROCESSOR_DEFINITIONS = _METAL_=1 SWIFTPM_BUNDLE=\"com.apple.mlx.Cmlx\" MLX_VERSION=\"0.31.1\" MLX_USE_ACCELERATE=1 METAL_PATH=\"default.metallib\" ACCELERATE_NEW_LAPACK=1 -