Question about metal gemm #3380
Replies: 2 comments
For reference, I was testing mlx's sdpa like this:

import sys
import time

import mlx.core as mx

# Usage: script.py [seq] [d] [iters]
seq = int(sys.argv[1]) if len(sys.argv) > 1 else 2048
d = int(sys.argv[2]) if len(sys.argv) > 2 else 128
iters = int(sys.argv[3]) if len(sys.argv) > 3 else 20

Q = mx.random.normal((1, 1, seq, d)).astype(mx.float16)
K = mx.random.normal((1, 1, seq, d)).astype(mx.float16)
V = mx.random.normal((1, 1, seq, d)).astype(mx.float16)
mx.eval(Q, K, V)

# Warmup
for _ in range(5):
    O = mx.fast.scaled_dot_product_attention(Q, K, V, scale=1.0/d**0.5)
    mx.eval(O)

# Timed runs, in microseconds
times = []
for _ in range(iters):
    t0 = time.perf_counter()
    O = mx.fast.scaled_dot_product_attention(Q, K, V, scale=1.0/d**0.5)
    mx.eval(O)
    times.append((time.perf_counter() - t0) * 1e6)

times.sort()
t = times[len(times) // 2]  # median
flops = 4 * seq * seq * d + 5 * seq * seq
# Prints: median time in us, GFLOP/s
print(f"{t:.0f},{flops / (t * 1e3):.1f}")
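As a worked example of the FLOP formula in the script above, the sketch below evaluates it for the default shape (seq=2048, d=128) and converts a time in microseconds to GFLOP/s. The helper names `sdpa_flops` and `gflops` are made up for illustration. Reading the `4 * seq * seq * d` term as two matmuls at `2 * seq * seq * d` each, and the `5 * seq * seq` term as a per-score softmax cost, is an assumption about the intent. The 1000 µs timing is a placeholder, not a measurement.

```python
def sdpa_flops(seq: int, d: int) -> int:
    """FLOP count as defined in the benchmark script (hypothetical helper)."""
    matmul = 4 * seq * seq * d   # assumed: QK^T and PV, 2 * seq * seq * d each
    softmax = 5 * seq * seq      # assumed: roughly 5 ops per score element
    return matmul + softmax


def gflops(flops: int, t_us: float) -> float:
    """Convert a FLOP count and a time in microseconds to GFLOP/s."""
    # flops / (t_us * 1e-6 s) / 1e9 simplifies to flops / (t_us * 1e3)
    return flops / (t_us * 1e3)


total = sdpa_flops(2048, 128)
print(total)  # 2168455168

# Placeholder timing of 1000 us, purely to show the unit conversion.
print(f"{gflops(total, 1000.0):.1f}")
```

The softmax term is about 1% of the total at this shape, so the reported GFLOP/s is dominated by the two matmuls.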
A few likely culprits worth investigating, based on a quick read of your kernel vs mlx's:
1. Your 2×2 simdgroup split is probably the biggest single cost. Each row of scores is owned by 2 simdgroups, so the softmax row-reductions need a partner-SIMD exchange via
2. Q reload per tile. Because V trashes
3. fp16 accumulators for softmax stats. Your
4. Score writeback for PV. You write scores → KVs → read them back for PV. That's an extra barrier + threadgroup bandwidth per tile. mlx keeps the simdgroup matrix accumulators live through softmax into PV via the
5. Threadgroup size 128. Caps occupancy. mlx uses
My guess is (1) + (3) together explain most of the 2×. Your
Source for reference:
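The explanation for point 3 is cut off above, so the following is only a generic illustration of why the precision of softmax running statistics can matter, not a reconstruction of what the reply said. It is a CPU-side sketch in plain Python that emulates fp16 and fp32 accumulators by rounding through `struct` after every add. The score distribution (2048 standard-normal values) is an assumption chosen to match the row length in the thread.

```python
import math
import random
import struct


def round_to(fmt: str, x: float) -> float:
    """Round a Python float to a narrower format ('e' = fp16, 'f' = fp32)."""
    return struct.unpack(fmt, struct.pack(fmt, x))[0]


def running_sum(values, fmt: str) -> float:
    """Sum values one at a time, rounding the accumulator after every add."""
    acc = 0.0
    for v in values:
        acc = round_to(fmt, acc + round_to(fmt, v))
    return acc


random.seed(0)
scores = [random.gauss(0.0, 1.0) for _ in range(2048)]  # assumed score row
m = max(scores)
terms = [math.exp(s - m) for s in scores]  # softmax numerators, all <= 1

exact = math.fsum(terms)
err16 = abs(running_sum(terms, "e") - exact) / exact
err32 = abs(running_sum(terms, "f") - exact) / exact
print(f"fp16 rel err: {err16:.2e}, fp32 rel err: {err32:.2e}")
```

This only speaks to accuracy of the denominator. Whether fp16 statistics also cost time in the kernel (for example through conversions) is a separate question the sketch does not address.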
I've written a fused attention kernel targeting M2 and I'm benchmarking against MLX's
scaled_dot_product_attention. After several rounds of optimization I'm still ~2x slower and I'm trying to understand what architectural choices explain the gap.
The Kernel
Benchmarking setup
threadgroups = (1, seq/16, 1), threads_per_tg = (128, 1, 1)
cmd->GPUEndTime() - cmd->GPUStartTime(), median of 10 runs after 3 warmup
Times
Key design choices:
BlockMMA<BN> struct
This is specifically standard attention: tile Q into 16-row blocks, stream K/V in 128-column tiles, apply online softmax with a running max/sum, and accumulate O in registers.
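For readers less familiar with the scheme described above, here is a minimal CPU-side sketch of online softmax with a running max and sum, streaming K/V in tiles. It is plain Python for a single query row, not the Metal kernel from this post. The function names and the small test shape are made up for illustration; only the 128-column tile width comes from the description.

```python
import math
import random


def attention_reference(q, K, V, scale):
    """Plain softmax(scale * q . K^T) @ V for a single query row."""
    s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    l = sum(p)
    return [sum(pj * vj[c] for pj, vj in zip(p, V)) / l
            for c in range(len(V[0]))]


def attention_online(q, K, V, scale, tile=128):
    """Same result, streaming K/V in tiles with a running max and sum."""
    m = -math.inf            # running row max
    l = 0.0                  # running softmax denominator
    o = [0.0] * len(V[0])    # unnormalised output accumulator
    for start in range(0, len(K), tile):
        k_tile = K[start:start + tile]
        v_tile = V[start:start + tile]
        s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_tile]
        m_new = max(m, max(s))
        alpha = math.exp(m - m_new)   # rescales stats from earlier tiles
        p = [math.exp(x - m_new) for x in s]
        l = l * alpha + sum(p)
        o = [oc * alpha + sum(pj * vj[c] for pj, vj in zip(p, v_tile))
             for c, oc in enumerate(o)]
        m = m_new
    return [oc / l for oc in o]


# Small made-up shape so the pure-Python loops finish quickly.
random.seed(0)
seq, d = 512, 16
q = [random.gauss(0, 1) for _ in range(d)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(seq)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(seq)]

ref = attention_reference(q, K, V, scale=1.0 / d**0.5)
out = attention_online(q, K, V, scale=1.0 / d**0.5)
max_diff = max(abs(a - b) for a, b in zip(ref, out))
print(f"max abs diff vs reference: {max_diff:.2e}")
```

The rescale by `alpha` on every tile is the per-row work that each Q block repeats once per K/V tile, which is why the cost of the row max and row sum reductions shows up in the inner loop.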
Note: my kernel is specialized for seq=2048, d=128 with compile-time constants (loop bounds, strides, tile counts are all constexpr). MLX's SDPA is fully general. Despite this advantage, I'm still 2.1x slower — which suggests the gap is
architectural, not from runtime overhead.
Any thoughts or advice would be much appreciated!