Skip to content

Add JAX FFI Host support#1446

Open
loney7 wants to merge 1 commit into
NVIDIA:mainfrom
loney7:loney7/ffi-host-support
Open

Add JAX FFI Host support#1446
loney7 wants to merge 1 commit into
NVIDIA:mainfrom
loney7:loney7/ffi-host-support

Conversation

@loney7

@loney7 loney7 commented May 8, 2026

Copy link
Copy Markdown

Description

This PR adds support for running JAX FFI callbacks on the CPU (Host) in addition to CUDA.

Changes:

  • Registered FFI targets for both "CUDA" and "Host" platforms in register_ffi_callback.
  • Handled the "Host" platform case in ffi_callback by using the CPU device and bypassing CUDA-specific features like streams and graphs.
  • Updated FfiCallable to reconstruct arguments and execute the function on the CPU when running on the Host platform.

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

Test plan

You can verify these changes by running the JAX interop tests which include FFI tests. Ensure they pass on both CPU and GPU (if available).

uv run warp/tests/interop/test_jax.py


<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

* **New Features**
  * Separate CUDA and CPU execution paths for FFI callbacks with device-aware execution scoping.
  * CUDA graph compatibility enabled only for CUDA execution.
  * CPU/host path can take a direct Python execution route for faster host calls.
  * Separate CUDA and Host callback registrations for JAX interop.

* **Tests**
  * Added CPU/host FFI tests covering kernels and callables (add, sincos, in/out args, scale).
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

@copy-pr-bot

copy-pr-bot Bot commented May 8, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented May 8, 2026

Copy link
Copy Markdown

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

This pull request extends JAX experimental FFI callbacks to support both CUDA and Host (CPU) execution paths. Callbacks now accept a platform parameter, register separate FFI targets for each platform, and conditionally dispatch to platform-appropriate device and kernel launch logic. CUDA-specific traits and graph capture modes are guarded by platform checks.

Changes

CUDA and Host FFI Execution

Layer / File(s) Summary
Callback Protocol & Platform Parameter
warp/_src/jax_experimental/ffi.py
FfiKernel.ffi_callback(), FfiCallable.ffi_callback(), and register_ffi_callback() callbacks gain platform="CUDA" parameter. CUDA graph compatibility traits are now enabled only when platform=="CUDA".
Dual-Platform FFI Registration
warp/_src/jax_experimental/ffi.py
FfiKernel and FfiCallable register separate FFI capsules for both CUDA and Host platforms, each passing the appropriate platform argument. register_ffi_callback() stores capsules under distinct registry keys (_cuda and _host suffixes).
FfiKernel Platform-Conditional Launch
warp/_src/jax_experimental/ffi.py
FfiKernel execution branches by platform: CUDA selects CUDA device and stream, calls wp_cuda_launch_kernel; Host uses CPU device with no stream, calls wp_cpu_launch_kernel.
FfiCallable Host Execution Short-Circuit
warp/_src/jax_experimental/ffi.py
When platform=="Host", FfiCallable reconstructs Warp arrays on the CPU device from the call frame and directly invokes the wrapped Python function, bypassing all CUDA graph logic.
FfiCallable CUDA Execution & Graph Capture
warp/_src/jax_experimental/ffi.py
CUDA graph modes are guarded by device.is_cuda. Execution scopes conditionally use ScopedStream (when stream present) or ScopedDevice, restricting graph capture and replay to CUDA devices only.
CPU Host Tests
warp/tests/interop/test_jax.py
New CPU-only jax_kernel and jax_callable tests added and registered in TestJax with devices=None.

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 5.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'Add JAX FFI Host support' directly and clearly summarizes the main change: adding Host/CPU platform support to JAX FFI callbacks alongside existing CUDA support.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands.

@greptile-apps

greptile-apps Bot commented May 8, 2026

Copy link
Copy Markdown

Greptile Summary

This PR extends the JAX FFI integration to support the CPU ("Host") platform in addition to CUDA, enabling Warp kernels and user-defined Python callables to be invoked through the JAX FFI on CPU devices.

  • Registers a separate Host-platform FFI target (via a lambda-wrapped callback) alongside the existing CUDA target in FfiKernel, FfiCallable, and register_ffi_callback; guards CUDA-graph traits and stream access on platform == "CUDA".
  • Adds a CPU launch path in FfiKernel.ffi_callback that builds a ctypes.Structure from kernel args and calls wp_cpu_launch_kernel, and a Host fast path in FfiCallable.ffi_callback via a new _reconstruct_args helper that creates Warp arrays from raw FFI buffer pointers.
  • Updates ExecutionContext and FfiBuffer in xla_ffi.py to be platform-aware: stream assignment is skipped for Host, and __array_interface__ / __cuda_array_interface__ properties guard their respective platforms.

Confidence Score: 5/5

Safe to merge; the Host execution paths correctly mirror the existing CUDA paths, and the previously identified critical bugs have all been addressed in this version.

The changes are a clean extension of existing, well-tested patterns: the Host callback registration, stream guarding, and argument reconstruction all follow the same structure as the working CUDA path. The only findings are non-blocking quality suggestions.

warp/_src/jax/ffi.py — specifically the CPU launch path in FfiKernel.ffi_callback where the ArgsStruct is created fresh on every invocation.

Important Files Changed

Filename Overview
warp/_src/jax/ffi.py Core change: registers Host FFI targets, adds CPU execution paths for FfiKernel (ArgsStruct-based wp_cpu_launch_kernel call) and FfiCallable (_reconstruct_args + ScopedDevice), and guards CUDA-only logic; the ArgsStruct is rebuilt on every invocation without caching.
warp/_src/jax/xla_ffi.py ExecutionContext now conditionally calls get_stream_from_callframe only for CUDA; FfiBuffer gains a platform field with array_interface for Host and guards cuda_array_interface to CUDA only.
warp/tests/interop/test_jax.py Adds seven Host-platform FFI tests covering add, sincos, in_out, scale_vec (FfiKernel), scale_constant, in_out, and callback (FfiCallable/register_ffi_callback); all gated on jax >= 0.5.0 and CPU availability.
CHANGELOG.md Changelog entry added for Host FFI platform support.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["JAX ffi_call(name, ...)"] --> B{Platform?}
    B -- CUDA --> C["FfiKernel / FfiCallable\nffi_callback(call_frame, platform='CUDA')"]
    B -- Host --> D["FfiKernel / FfiCallable\nffi_callback(call_frame, platform='Host')"]

    C --> E["get_cuda_device(ordinal)\nget_stream_from_callframe()"]
    E --> F["wp_cuda_launch_kernel(...)"]

    D --> G["wp.get_device('cpu')\nstream = None"]
    G --> H{Class?}
    H -- FfiKernel --> I["Build ArgsStruct\nwp_cpu_launch_kernel(hooks.forward, bounds, args)"]
    H -- FfiCallable --> J["_reconstruct_args(inputs, outputs, call_desc, cpu)\nwp.ScopedDevice(cpu)\nself.func(*arg_list)"]
    H -- register_ffi_callback --> K["FfiBuffer(platform='Host')\n__array_interface__\nExecutionContext(stream=None)\nfunc(inputs, outputs, attrs, ctx)"]
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
    A["JAX ffi_call(name, ...)"] --> B{Platform?}
    B -- CUDA --> C["FfiKernel / FfiCallable\nffi_callback(call_frame, platform='CUDA')"]
    B -- Host --> D["FfiKernel / FfiCallable\nffi_callback(call_frame, platform='Host')"]

    C --> E["get_cuda_device(ordinal)\nget_stream_from_callframe()"]
    E --> F["wp_cuda_launch_kernel(...)"]

    D --> G["wp.get_device('cpu')\nstream = None"]
    G --> H{Class?}
    H -- FfiKernel --> I["Build ArgsStruct\nwp_cpu_launch_kernel(hooks.forward, bounds, args)"]
    H -- FfiCallable --> J["_reconstruct_args(inputs, outputs, call_desc, cpu)\nwp.ScopedDevice(cpu)\nself.func(*arg_list)"]
    H -- register_ffi_callback --> K["FfiBuffer(platform='Host')\n__array_interface__\nExecutionContext(stream=None)\nfunc(inputs, outputs, attrs, ctx)"]
Loading

Reviews (15): Last reviewed commit: "Add JAX FFI Host (CPU) support" | Re-trigger Greptile

Comment thread warp/_src/jax/ffi.py
Comment thread warp/_src/jax/ffi.py

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@warp/_src/jax_experimental/ffi.py`:
- Around line 1772-1777: In register_ffi_callback, the Host ffi capsule is
constructed from the wrong variable (ffi_capsule_host) causing a NameError;
change the construction to use the host ccall address value by calling
jax.ffi.pycapsule(ffi_ccall_address_host.value) and then register that capsule
with jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host") so the
Host path mirrors the CUDA path (refer to ffi_ccall_address_host,
ffi_capsule_host, register_ffi_target).
- Around line 629-632: The code assigns ffi_ccall_address_host then creates
ffi_capsule_host but mistakenly uses ffi_capsule_host.value (self-referential
NameError); change the capsule creation to use the previously computed address
value (ffi_ccall_address_host.value) so the lines around callback_func_host,
ffi_ccall_address_host, ffi_capsule_host and the
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host") call
use ffi_ccall_address_host.value when constructing the pycapsule for the Host
callback that wraps FFI_CCALLFUNC and calls self.ffi_callback.
- Around line 226-229: The Host FFI registration references an undefined
variable: replace the erroneous creation of ffi_capsule_host (currently using
ffi_capsule_host.value) with a capsule built from the c_void_p address you just
made; specifically, after creating callback_func_host and
ffi_ccall_address_host, set ffi_capsule_host =
jax.ffi.pycapsule(ffi_ccall_address_host.value) and then call
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host") so the
capsule is created from the correct address (symbols: callback_func_host,
FFI_CCALLFUNC, ffi_ccall_address_host, ffi_capsule_host,
jax.ffi.register_ffi_target, self.name, self.ffi_callback).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Enterprise

Run ID: d2c93d0f-cbef-4891-b361-773ad8140f6c

📥 Commits

Reviewing files that changed from the base of the PR and between 54327e3 and 09569ba.

📒 Files selected for processing (1)
  • warp/_src/jax_experimental/ffi.py

Comment thread warp/_src/jax/ffi.py
Comment thread warp/_src/jax/ffi.py
Comment thread warp/_src/jax/ffi.py
Comment thread warp/_src/jax/ffi.py
@loney7

loney7 commented May 8, 2026

Copy link
Copy Markdown
Author

pre-commit.ci autofix

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@warp/tests/interop/test_jax.py`:
- Around line 2422-2450: The host-only test registrations for TestJax (the
add_function_test calls registering test_ffi_jax_kernel_host_add,
test_ffi_jax_kernel_host_sincos, test_ffi_jax_kernel_host_in_out,
test_ffi_jax_kernel_host_scale_vec_constant,
test_ffi_jax_callable_host_scale_constant, and
test_ffi_jax_callable_host_in_out) are currently inside the
jax_compatible_cuda_devices conditional; move these specific add_function_test
calls out of that CUDA-only if block so they are always registered on CPU-only
setups, keeping the existing device=None argument and leaving CUDA/GPU-specific
registrations inside the original conditional.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Enterprise

Run ID: fcfb2e10-e509-41b9-b675-c8960056b6be

📥 Commits

Reviewing files that changed from the base of the PR and between d39169c and f6048f4.

📒 Files selected for processing (1)
  • warp/tests/interop/test_jax.py

Comment thread warp/tests/interop/test_jax.py Outdated
@loney7

loney7 commented May 9, 2026

Copy link
Copy Markdown
Author

pre-commit.ci autofix

@shi-eric shi-eric requested a review from nvlukasz May 12, 2026 16:45
Comment thread warp/tests/interop/test_jax.py Outdated
Comment thread warp/_src/jax/ffi.py
@shi-eric

Copy link
Copy Markdown
Contributor

Please add a CHANGELOG.md entry under Unreleased for this JAX Host FFI behavior change.

@shi-eric

Copy link
Copy Markdown
Contributor

Please squash this PR down to a single coherent commit before merge.

@shi-eric

Copy link
Copy Markdown
Contributor

Please rebase onto current main and move the fix/tests to the promoted JAX code paths. Commit 604a8961df6d40ea64ff1e740b23581e4c72c96f promoted the JAX code from jax_experimental to jax after this PR was opened, so the final diff should target the current locations.

@nvlukasz nvlukasz left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the contribution. Please address outstanding comments or let us know if you are unable to do so.

Comment thread warp/_src/jax/ffi.py
Comment thread warp/_src/jax_experimental/ffi.py Outdated
Comment thread warp/_src/jax_experimental/ffi.py Outdated
@loney7

loney7 commented May 19, 2026

Copy link
Copy Markdown
Author

Thank you for the contribution. Please address outstanding comments or let us know if you are unable to do so.

@nvlukasz , please allow till the end of this week to address the outstanding comments. I apologise for the delay.

@loney7 loney7 force-pushed the loney7/ffi-host-support branch 2 times, most recently from 766a819 to cc5aa99 Compare June 9, 2026 05:46
@loney7

loney7 commented Jun 9, 2026

Copy link
Copy Markdown
Author

pre-commit.ci autofix

@loney7 loney7 force-pushed the loney7/ffi-host-support branch from cc5aa99 to c06e8d4 Compare June 9, 2026 05:48
@loney7

loney7 commented Jun 9, 2026

Copy link
Copy Markdown
Author

pre-commit.ci autofix

@loney7 loney7 marked this pull request as ready for review June 9, 2026 07:27
@loney7 loney7 force-pushed the loney7/ffi-host-support branch from 420290b to 38f4e20 Compare June 9, 2026 07:37
@loney7

loney7 commented Jun 9, 2026

Copy link
Copy Markdown
Author

pre-commit.ci autofix

@btaba

btaba commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

Hi @nvlukasz wondering if you could take another pass :)

@loney7 loney7 force-pushed the loney7/ffi-host-support branch from df57595 to f752ab9 Compare June 9, 2026 20:17
@loney7

loney7 commented Jun 9, 2026

Copy link
Copy Markdown
Author

pre-commit.ci autofix

1 similar comment
@loney7

loney7 commented Jun 9, 2026

Copy link
Copy Markdown
Author

pre-commit.ci autofix

@greptile-apps

greptile-apps Bot commented Jun 9, 2026

Copy link
Copy Markdown

Want your agent to iterate on Greptile's feedback? Try greploops.

@btaba

btaba commented Jun 12, 2026

Copy link
Copy Markdown
Contributor

friendly ping to @nvlukasz or @shi-eric (sorry for the trouble)

@btaba

btaba commented Jun 23, 2026

Copy link
Copy Markdown
Contributor

friendly ping again @nvlukasz or @shi-eric

@shi-eric

Copy link
Copy Markdown
Contributor

/ok to test 12f792c

@shi-eric

Copy link
Copy Markdown
Contributor

Hey @btaba, did you check that the tests pass before asking for a re-review?

The newly added Host FFI tests fail both locally and in GitHub CI. The RTX PRO 6000 JAX job (test-warp-gpu-linux, jax[cuda13]) fails with the same two Host FFI failures I see locally: an exact-equality failure in test_ffi_jax_kernel_host_sincos, and a Host callable vec2 output shape mismatch in test_ffi_jax_callable_host_scale_constant.

The branch is also still conflicting with main, and the requested squash is pending. Please fix the failing tests, rebase onto current main, squash the PR, and then ping us again for review.

@loney7

loney7 commented Jun 24, 2026

Copy link
Copy Markdown
Author

pre-commit.ci autofix

@loney7 loney7 force-pushed the loney7/ffi-host-support branch from 0c182df to 0644b33 Compare June 24, 2026 07:01
@loney7

loney7 commented Jun 24, 2026

Copy link
Copy Markdown
Author

/ok to test 0644b33

@loney7 loney7 force-pushed the loney7/ffi-host-support branch 2 times, most recently from 5d52c5c to 865d5ef Compare June 24, 2026 07:21
@loney7

loney7 commented Jun 24, 2026

Copy link
Copy Markdown
Author

/ok to test 865d5ef

@loney7 loney7 force-pushed the loney7/ffi-host-support branch from d6973f2 to 62590b8 Compare June 24, 2026 09:46

Copy link
Copy Markdown
Contributor

Could you also fix this Host jax_kernel placement case? It fails when jax_kernel is traced with CUDA as the effective JAX default device, but the actual call executes on CPU.

Minimal repro:

import jax
import jax.numpy as jnp
import numpy as np
import warp as wp


@wp.kernel
def add_kernel(
    a: wp.array[wp.float32],
    b: wp.array[wp.float32],
    out: wp.array[wp.float32],
):
    tid = wp.tid()
    out[tid] = a[tid] + b[tid]


jax_add = wp.jax_kernel(add_kernel)


@jax.jit
def f(a, b):
    (out,) = jax_add(a, b)
    return out


cpu = jax.devices("cpu")[0]

# Leave the effective JAX default device as CUDA, but place inputs on CPU.
a = jax.device_put(jnp.arange(16, dtype=jnp.float32), cpu)
b = jax.device_put(jnp.ones(16, dtype=jnp.float32), cpu)

out = f(a, b)
jax.block_until_ready(out)

np.testing.assert_allclose(np.asarray(out), np.arange(1, 17, dtype=np.float32))

On this branch, __call__ preloads the module for the default CUDA device. Execution then enters the Host callback, selects wp.get_device("cpu"), and fails when fetching CPU hooks:

RuntimeError: Module is not loaded on device cpu

The Host callback should ensure the kernel module is loaded for CPU before calling get_kernel_hooks(..., device). This would be good to cover with a regression test that leaves the effective JAX default device on CUDA while passing CPU-placed inputs.

Copy link
Copy Markdown
Contributor

Could you also fix the generic Host FFI callback buffer path? The Host target is registered, but the buffers passed to the user callback are still FfiBuffer objects that only expose __cuda_array_interface__. On CPU, passing those buffers directly to wp.launch(..., device="cpu") fails because Warp looks for __array_interface__.

Minimal repro:

import warnings

import jax
import jax.numpy as jnp
import warp as wp

with warnings.catch_warnings():
    warnings.simplefilter("ignore", DeprecationWarning)
    from warp.jax_experimental import register_ffi_callback


@wp.kernel
def double_kernel(inp: wp.array[wp.float32], out: wp.array[wp.float32]):
    tid = wp.tid()
    out[tid] = 2.0 * inp[tid]


def warp_callback(inputs, outputs, attrs, ctx):
    inp = inputs[0]
    out = outputs[0]

    assert ctx.stream is None
    assert hasattr(inp, "__cuda_array_interface__")
    assert not hasattr(inp, "__array_interface__")

    wp.launch(double_kernel, dim=inp.shape[0], inputs=[inp], outputs=[out], device="cpu")


register_ffi_callback("warp_generic_host_buffer_repro", warp_callback)


@jax.jit
def f(x):
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    call = jax.ffi.ffi_call("warp_generic_host_buffer_repro", out_type)
    return call(x)


with jax.default_device(jax.devices("cpu")[0]):
    x = jnp.arange(16, dtype=jnp.float32)
    out = f(x)

jax.block_until_ready(out)

On this branch, the Host callback is reached, but the CPU launch fails while packing the FfiBuffer:

RuntimeError: Error launching kernel 'double_kernel', argument 'inp' expects an array of type array(ndim=1, dtype=float32), but passed value has type FfiBuffer.

Can the Host generic callback path expose CPU-compatible buffers, e.g. via __array_interface__, so the same callback pattern supported on CUDA also works with wp.launch(..., device="cpu")?

Add support for running JAX FFI callbacks on the CPU (Host) platform
in addition to CUDA.

Changes:
- Register FFI targets for both "CUDA" and "Host" platforms in
  FfiKernel, FfiCallable, and register_ffi_callback.
- Handle the "Host" platform case in ffi_callback by using the CPU
  device and bypassing CUDA-specific features like streams and graphs.
- Update FfiCallable to reconstruct arguments and execute the function
  on the CPU when running on the Host platform.
- Make FfiBuffer platform-aware: expose __array_interface__ for Host
  and __cuda_array_interface__ for CUDA.
- Ensure kernel module is loaded on the target device before getting
  hooks, fixing the case where JAX default device is CUDA but the
  callback runs on Host.
- Guard COMMAND_BUFFER_COMPATIBLE traits and graph_mode setup behind
  platform == "CUDA" checks.
- Add CPU/Host FFI tests covering kernels, callables, and callbacks.

(NVIDIAGH-1446)

Signed-off-by: Ankit Jain <kitsrish@google.com>
@loney7 loney7 force-pushed the loney7/ffi-host-support branch from 62590b8 to c9d6224 Compare June 25, 2026 14:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants