Add JAX FFI Host support#1446
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughThis pull request extends JAX experimental FFI callbacks to support both CUDA and Host (CPU) execution paths. Callbacks now accept a ChangesCUDA and Host FFI Execution
🎯 4 (Complex) | ⏱️ ~45 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
Greptile SummaryThis PR extends the JAX FFI integration to support the CPU ("Host") platform in addition to CUDA, enabling Warp kernels and user-defined Python callables to be invoked through the JAX FFI on CPU devices.
Confidence Score: 5/5Safe to merge; the Host execution paths correctly mirror the existing CUDA paths, and the previously identified critical bugs have all been addressed in this version. The changes are a clean extension of existing, well-tested patterns: the Host callback registration, stream guarding, and argument reconstruction all follow the same structure as the working CUDA path. The only findings are non-blocking quality suggestions. warp/_src/jax/ffi.py — specifically the CPU launch path in FfiKernel.ffi_callback where the ArgsStruct is created fresh on every invocation. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["JAX ffi_call(name, ...)"] --> B{Platform?}
B -- CUDA --> C["FfiKernel / FfiCallable\nffi_callback(call_frame, platform='CUDA')"]
B -- Host --> D["FfiKernel / FfiCallable\nffi_callback(call_frame, platform='Host')"]
C --> E["get_cuda_device(ordinal)\nget_stream_from_callframe()"]
E --> F["wp_cuda_launch_kernel(...)"]
D --> G["wp.get_device('cpu')\nstream = None"]
G --> H{Class?}
H -- FfiKernel --> I["Build ArgsStruct\nwp_cpu_launch_kernel(hooks.forward, bounds, args)"]
H -- FfiCallable --> J["_reconstruct_args(inputs, outputs, call_desc, cpu)\nwp.ScopedDevice(cpu)\nself.func(*arg_list)"]
H -- register_ffi_callback --> K["FfiBuffer(platform='Host')\n__array_interface__\nExecutionContext(stream=None)\nfunc(inputs, outputs, attrs, ctx)"]
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
A["JAX ffi_call(name, ...)"] --> B{Platform?}
B -- CUDA --> C["FfiKernel / FfiCallable\nffi_callback(call_frame, platform='CUDA')"]
B -- Host --> D["FfiKernel / FfiCallable\nffi_callback(call_frame, platform='Host')"]
C --> E["get_cuda_device(ordinal)\nget_stream_from_callframe()"]
E --> F["wp_cuda_launch_kernel(...)"]
D --> G["wp.get_device('cpu')\nstream = None"]
G --> H{Class?}
H -- FfiKernel --> I["Build ArgsStruct\nwp_cpu_launch_kernel(hooks.forward, bounds, args)"]
H -- FfiCallable --> J["_reconstruct_args(inputs, outputs, call_desc, cpu)\nwp.ScopedDevice(cpu)\nself.func(*arg_list)"]
H -- register_ffi_callback --> K["FfiBuffer(platform='Host')\n__array_interface__\nExecutionContext(stream=None)\nfunc(inputs, outputs, attrs, ctx)"]
Reviews (15): Last reviewed commit: "Add JAX FFI Host (CPU) support" | Re-trigger Greptile |
There was a problem hiding this comment.
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@warp/_src/jax_experimental/ffi.py`:
- Around line 1772-1777: In register_ffi_callback, the Host ffi capsule is
constructed from the wrong variable (ffi_capsule_host) causing a NameError;
change the construction to use the host ccall address value by calling
jax.ffi.pycapsule(ffi_ccall_address_host.value) and then register that capsule
with jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host") so the
Host path mirrors the CUDA path (refer to ffi_ccall_address_host,
ffi_capsule_host, register_ffi_target).
- Around line 629-632: The code assigns ffi_ccall_address_host then creates
ffi_capsule_host but mistakenly uses ffi_capsule_host.value (self-referential
NameError); change the capsule creation to use the previously computed address
value (ffi_ccall_address_host.value) so the lines around callback_func_host,
ffi_ccall_address_host, ffi_capsule_host and the
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host") call
use ffi_ccall_address_host.value when constructing the pycapsule for the Host
callback that wraps FFI_CCALLFUNC and calls self.ffi_callback.
- Around line 226-229: The Host FFI registration references an undefined
variable: replace the erroneous creation of ffi_capsule_host (currently using
ffi_capsule_host.value) with a capsule built from the c_void_p address you just
made; specifically, after creating callback_func_host and
ffi_ccall_address_host, set ffi_capsule_host =
jax.ffi.pycapsule(ffi_ccall_address_host.value) and then call
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host") so the
capsule is created from the correct address (symbols: callback_func_host,
FFI_CCALLFUNC, ffi_ccall_address_host, ffi_capsule_host,
jax.ffi.register_ffi_target, self.name, self.ffi_callback).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Enterprise
Run ID: d2c93d0f-cbef-4891-b361-773ad8140f6c
📒 Files selected for processing (1)
warp/_src/jax_experimental/ffi.py
|
pre-commit.ci autofix |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@warp/tests/interop/test_jax.py`:
- Around line 2422-2450: The host-only test registrations for TestJax (the
add_function_test calls registering test_ffi_jax_kernel_host_add,
test_ffi_jax_kernel_host_sincos, test_ffi_jax_kernel_host_in_out,
test_ffi_jax_kernel_host_scale_vec_constant,
test_ffi_jax_callable_host_scale_constant, and
test_ffi_jax_callable_host_in_out) are currently inside the
jax_compatible_cuda_devices conditional; move these specific add_function_test
calls out of that CUDA-only if block so they are always registered on CPU-only
setups, keeping the existing device=None argument and leaving CUDA/GPU-specific
registrations inside the original conditional.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Enterprise
Run ID: fcfb2e10-e509-41b9-b675-c8960056b6be
📒 Files selected for processing (1)
warp/tests/interop/test_jax.py
|
pre-commit.ci autofix |
|
Please add a |
|
Please squash this PR down to a single coherent commit before merge. |
|
Please rebase onto current |
nvlukasz
left a comment
There was a problem hiding this comment.
Thank you for the contribution. Please address outstanding comments or let us know if you are unable to do so.
@nvlukasz , please allow till the end of this week to address the outstanding comments. I apologise for the delay. |
766a819 to
cc5aa99
Compare
|
pre-commit.ci autofix |
cc5aa99 to
c06e8d4
Compare
|
pre-commit.ci autofix |
420290b to
38f4e20
Compare
|
pre-commit.ci autofix |
|
Hi @nvlukasz wondering if you could take another pass :) |
df57595 to
f752ab9
Compare
|
pre-commit.ci autofix |
1 similar comment
|
pre-commit.ci autofix |
|
Want your agent to iterate on Greptile's feedback? Try greploops. |
|
/ok to test 12f792c |
|
Hey @btaba, did you check that the tests pass before asking for a re-review? The newly added Host FFI tests fail both locally and in GitHub CI. The RTX PRO 6000 JAX job ( The branch is also still conflicting with |
|
pre-commit.ci autofix |
0c182df to
0644b33
Compare
|
/ok to test 0644b33 |
5d52c5c to
865d5ef
Compare
|
/ok to test 865d5ef |
d6973f2 to
62590b8
Compare
|
Could you also fix this Host Minimal repro: import jax
import jax.numpy as jnp
import numpy as np
import warp as wp
@wp.kernel
def add_kernel(
a: wp.array[wp.float32],
b: wp.array[wp.float32],
out: wp.array[wp.float32],
):
tid = wp.tid()
out[tid] = a[tid] + b[tid]
jax_add = wp.jax_kernel(add_kernel)
@jax.jit
def f(a, b):
(out,) = jax_add(a, b)
return out
cpu = jax.devices("cpu")[0]
# Leave the effective JAX default device as CUDA, but place inputs on CPU.
a = jax.device_put(jnp.arange(16, dtype=jnp.float32), cpu)
b = jax.device_put(jnp.ones(16, dtype=jnp.float32), cpu)
out = f(a, b)
jax.block_until_ready(out)
np.testing.assert_allclose(np.asarray(out), np.arange(1, 17, dtype=np.float32))On this branch, The Host callback should ensure the kernel module is loaded for CPU before calling |
|
Could you also fix the generic Host FFI callback buffer path? The Host target is registered, but the buffers passed to the user callback are still Minimal repro: import warnings
import jax
import jax.numpy as jnp
import warp as wp
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
from warp.jax_experimental import register_ffi_callback
@wp.kernel
def double_kernel(inp: wp.array[wp.float32], out: wp.array[wp.float32]):
tid = wp.tid()
out[tid] = 2.0 * inp[tid]
def warp_callback(inputs, outputs, attrs, ctx):
inp = inputs[0]
out = outputs[0]
assert ctx.stream is None
assert hasattr(inp, "__cuda_array_interface__")
assert not hasattr(inp, "__array_interface__")
wp.launch(double_kernel, dim=inp.shape[0], inputs=[inp], outputs=[out], device="cpu")
register_ffi_callback("warp_generic_host_buffer_repro", warp_callback)
@jax.jit
def f(x):
out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
call = jax.ffi.ffi_call("warp_generic_host_buffer_repro", out_type)
return call(x)
with jax.default_device(jax.devices("cpu")[0]):
x = jnp.arange(16, dtype=jnp.float32)
out = f(x)
jax.block_until_ready(out)On this branch, the Host callback is reached, but the CPU launch fails while packing the Can the Host generic callback path expose CPU-compatible buffers, e.g. via |
Add support for running JAX FFI callbacks on the CPU (Host) platform in addition to CUDA. Changes: - Register FFI targets for both "CUDA" and "Host" platforms in FfiKernel, FfiCallable, and register_ffi_callback. - Handle the "Host" platform case in ffi_callback by using the CPU device and bypassing CUDA-specific features like streams and graphs. - Update FfiCallable to reconstruct arguments and execute the function on the CPU when running on the Host platform. - Make FfiBuffer platform-aware: expose __array_interface__ for Host and __cuda_array_interface__ for CUDA. - Ensure kernel module is loaded on the target device before getting hooks, fixing the case where JAX default device is CUDA but the callback runs on Host. - Guard COMMAND_BUFFER_COMPATIBLE traits and graph_mode setup behind platform == "CUDA" checks. - Add CPU/Host FFI tests covering kernels, callables, and callbacks. (NVIDIAGH-1446) Signed-off-by: Ankit Jain <kitsrish@google.com>
62590b8 to
c9d6224
Compare
Description
This PR adds support for running JAX FFI callbacks on the CPU (Host) in addition to CUDA.
Changes:
"CUDA"and"Host"platforms inregister_ffi_callback."Host"platform case inffi_callbackby using the CPU device and bypassing CUDA-specific features like streams and graphs.FfiCallableto reconstruct arguments and execute the function on the CPU when running on the Host platform.Checklist
Test plan
You can verify these changes by running the JAX interop tests which include FFI tests. Ensure they pass on both CPU and GPU (if available).