Skip to content

Commit 820a245

Browse files
unamedkrclaude
andcommitted
Metal MoE: memoryBarrier between phases (MLX pattern)
Applied MLX's Metal pattern: - Single command buffer for all 3 phases (was 3 separate) - memoryBarrierWithScope:MTLBarrierScopeBuffers between phases - GPU-side sync instead of CPU waitUntilCompleted per phase - Removed all NSLog debug overhead 35B no longer hangs with single cmdBuf. But per-layer dispatch still slower than CPU fused IQ2 dot for MoE experts. 0.8B: 7.5 tok/s (Metal build). 35B: ~1 tok/s (loading overhead). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 372afb8 commit 820a245

4 files changed

Lines changed: 110 additions & 102 deletions

File tree

src/backend/metal/tq_matmul.metal

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,3 +516,67 @@ kernel void matmul_q4_k(
516516
output[row] = total;
517517
}
518518
}
519+
520+
521+
/* ============================================================
522+
* TurboQuant self Q4 matmul: block_size=32, 16 packed bytes + 1 float scale
523+
* dequant: (nibble - 8) * scale
524+
* Optimized: 4-byte unroll, SIMD reduce
525+
* ============================================================ */
526+
kernel void matmul_tq_q4(
527+
device const float* input [[buffer(0)]],
528+
device float* output [[buffer(1)]],
529+
device const uint8_t* weight_qs [[buffer(2)]],
530+
device const float* weight_sc [[buffer(3)]],
531+
constant uint& in_dim_u [[buffer(4)]],
532+
constant uint& out_dim_u [[buffer(5)]],
533+
uint row [[threadgroup_position_in_grid]],
534+
uint tid [[thread_index_in_threadgroup]],
535+
uint tg_size [[threads_per_threadgroup]])
536+
{
537+
if (row >= out_dim_u) return;
538+
539+
const uint in_dim = in_dim_u;
540+
const uint n_blocks = in_dim / 32;
541+
const uint blocks_per_thread = (n_blocks + tg_size - 1) / tg_size;
542+
const uint block_start = tid * blocks_per_thread;
543+
const uint block_end = min(block_start + blocks_per_thread, n_blocks);
544+
545+
const uint qs_row = row * n_blocks * 16;
546+
const uint sc_row = row * n_blocks;
547+
float sum = 0.0f;
548+
549+
for (uint b = block_start; b < block_end; b++) {
550+
const float sc = weight_sc[sc_row + b];
551+
device const uint8_t* qs = weight_qs + qs_row + b * 16;
552+
const uint base = b * 32;
553+
for (uint k = 0; k < 16; k += 4) {
554+
uint8_t p0 = qs[k], p1 = qs[k+1], p2 = qs[k+2], p3 = qs[k+3];
555+
sum += (float(int(p0 & 0xF) - 8) * input[base + k]
556+
+ float(int(p0 >> 4) - 8) * input[base + k + 16]
557+
+ float(int(p1 & 0xF) - 8) * input[base + k + 1]
558+
+ float(int(p1 >> 4) - 8) * input[base + k + 17]
559+
+ float(int(p2 & 0xF) - 8) * input[base + k + 2]
560+
+ float(int(p2 >> 4) - 8) * input[base + k + 18]
561+
+ float(int(p3 & 0xF) - 8) * input[base + k + 3]
562+
+ float(int(p3 >> 4) - 8) * input[base + k + 19]) * sc;
563+
}
564+
}
565+
566+
sum += simd_shuffle_down(sum, 16);
567+
sum += simd_shuffle_down(sum, 8);
568+
sum += simd_shuffle_down(sum, 4);
569+
sum += simd_shuffle_down(sum, 2);
570+
sum += simd_shuffle_down(sum, 1);
571+
572+
threadgroup float simd_sums[8];
573+
if (tid % 32 == 0) simd_sums[tid / 32] = sum;
574+
threadgroup_barrier(mem_flags::mem_threadgroup);
575+
576+
if (tid == 0) {
577+
uint n_simd = (tg_size + 31) / 32;
578+
float total = 0.0f;
579+
for (uint s = 0; s < n_simd; s++) total += simd_sums[s];
580+
output[row] = total;
581+
}
582+
}

src/backend/metal/tq_metal_dispatch.m

Lines changed: 37 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -903,7 +903,20 @@ int tq_metal_moe_forward(
903903
options:MTLResourceStorageModeShared];
904904
if (!params_buf) return -1;
905905

906-
/* --- Create command buffer and encoder --- */
906+
/* --- Create output buffer for Phase 3 (allocated once with other buffers) --- */
907+
size_t output_bytes = (size_t)hidden_dim * sizeof(float);
908+
id<MTLBuffer> output_buf = [tq_mtl_device newBufferWithLength:output_bytes
909+
options:MTLResourceStorageModeShared];
910+
if (!output_buf) {
911+
/* Fallback to hybrid if buffer creation fails */
912+
memcpy(hb_output, [gate_buf contents], inter_bytes);
913+
return 1;
914+
}
915+
916+
/* --- Single command buffer for all 3 phases (MLX pattern) ---
917+
* Metal guarantees sequential execution of compute encoders within
918+
* one command buffer. memoryBarrierWithScope ensures buffer writes
919+
* from one encoder are visible to the next. */
907920
id<MTLCommandBuffer> cmdBuf = [tq_mtl_queue commandBuffer];
908921
if (!cmdBuf) return -1;
909922

@@ -933,56 +946,13 @@ int tq_metal_moe_forward(
933946
MTLSize gridSize = MTLSizeMake(n_tgs, 1, 1);
934947
MTLSize tgSize = MTLSizeMake(TQ_MATMUL_TG_SIZE, 1, 1);
935948
[enc dispatchThreadgroups:gridSize threadsPerThreadgroup:tgSize];
936-
[enc endEncoding];
937-
}
938-
939-
/* --- Phase 1: commit and wait to isolate hang --- */
940-
[cmdBuf commit];
941-
[cmdBuf waitUntilCompleted];
942-
943-
if (cmdBuf.status == MTLCommandBufferStatusError) {
944-
NSLog(@"TurboQuant MoE: Phase 1 (gate+up) FAILED: %@", cmdBuf.error);
945-
return -1;
946-
}
947-
NSLog(@"TurboQuant MoE: Phase 1 (gate+up) completed OK");
948949

949-
#ifdef TQ_MOE_DEBUG_VALIDATE
950-
/* === Debug: compare GPU gate output for expert 0 vs CPU tq_matmul_gguf === */
951-
{
952-
/* tq_matmul_gguf declared in tq_gguf.h (already included) */
953-
float* gpu_gate = (float*)[gate_buf contents];
954-
float* cpu_gate = (float*)malloc((size_t)expert_dim * sizeof(float));
955-
if (cpu_gate) {
956-
/* CPU matmul for expert 0's gate weights */
957-
const uint8_t* gate_w = (const uint8_t*)weight_base + gate_offsets[0];
958-
tq_ggml_dtype gt0 = gate_types_in ? (tq_ggml_dtype)gate_types_in[0]
959-
: (tq_ggml_dtype)weight_type;
960-
tq_matmul_gguf(cpu_gate, input, gate_w, gt0, expert_dim, hidden_dim);
961-
962-
/* Compare first 8 and last 8 values */
963-
NSLog(@"TurboQuant MoE DEBUG: gate expert 0 comparison (first 8):");
964-
float max_err = 0.0f;
965-
for (int i = 0; i < expert_dim; i++) {
966-
float err = fabsf(gpu_gate[i] - cpu_gate[i]);
967-
if (err > max_err) max_err = err;
968-
if (i < 8 || i >= expert_dim - 4) {
969-
NSLog(@" [%d] GPU=%.6f CPU=%.6f err=%.6f", i, gpu_gate[i], cpu_gate[i], err);
970-
}
971-
}
972-
NSLog(@"TurboQuant MoE DEBUG: gate max_err=%.6f across %d elements", max_err, expert_dim);
973-
if (max_err > 0.01f) {
974-
NSLog(@"TurboQuant MoE DEBUG: *** MISMATCH DETECTED *** — weight offset or decoding bug");
975-
}
976-
free(cpu_gate);
977-
}
950+
/* Memory barrier: ensure gate_buf/up_buf writes visible to Phase 2 */
951+
[enc memoryBarrierWithScope:MTLBarrierScopeBuffers];
952+
[enc endEncoding];
978953
}
979-
#endif /* TQ_MOE_DEBUG_VALIDATE */
980954

981-
/* --- New command buffer for Phase 2 --- */
982-
cmdBuf = [tq_mtl_queue commandBuffer];
983-
if (!cmdBuf) return -1;
984-
985-
/* ======== Phase 2: SwiGLU ======== */
955+
/* ======== Phase 2: SwiGLU (reads gate_buf/up_buf from Phase 1) ======== */
986956
{
987957
id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
988958
if (!enc) return -1;
@@ -998,40 +968,15 @@ int tq_metal_moe_forward(
998968
MTLSize gridSize = MTLSizeMake(n_tgs, 1, 1);
999969
MTLSize tgSize = MTLSizeMake(tg, 1, 1);
1000970
[enc dispatchThreadgroups:gridSize threadsPerThreadgroup:tgSize];
1001-
[enc endEncoding];
1002-
}
1003-
1004-
/* --- Phase 2: commit and wait to isolate hang --- */
1005-
[cmdBuf commit];
1006-
[cmdBuf waitUntilCompleted];
1007971

1008-
if (cmdBuf.status == MTLCommandBufferStatusError) {
1009-
NSLog(@"TurboQuant MoE: Phase 2 (SwiGLU) FAILED: %@", cmdBuf.error);
1010-
return -1;
972+
/* Memory barrier: ensure gate_buf writes visible to Phase 3 */
973+
[enc memoryBarrierWithScope:MTLBarrierScopeBuffers];
974+
[enc endEncoding];
1011975
}
1012-
NSLog(@"TurboQuant MoE: Phase 2 (SwiGLU) completed OK");
1013976

1014977
/* ======== Phase 3: down projection + weighted accumulate (GPU) ========
1015-
* Previously skipped due to IQ2_S shader hanging with constant array.
1016-
* Now fixed: IQ2_S codebook passed as device buffer (buffer 4). */
978+
* IQ2_S codebook passed as device buffer (buffer 4). */
1017979
{
1018-
/* Create output buffer for hidden_dim results */
1019-
size_t output_bytes = (size_t)hidden_dim * sizeof(float);
1020-
id<MTLBuffer> output_buf = [tq_mtl_device newBufferWithLength:output_bytes
1021-
options:MTLResourceStorageModeShared];
1022-
if (!output_buf) {
1023-
/* Fallback to hybrid if buffer creation fails */
1024-
memcpy(hb_output, [gate_buf contents], inter_bytes);
1025-
return 1;
1026-
}
1027-
1028-
/* New command buffer for Phase 3 */
1029-
cmdBuf = [tq_mtl_queue commandBuffer];
1030-
if (!cmdBuf) {
1031-
memcpy(hb_output, [gate_buf contents], inter_bytes);
1032-
return 1;
1033-
}
1034-
1035980
id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
1036981
if (!enc) {
1037982
memcpy(hb_output, [gate_buf contents], inter_bytes);
@@ -1057,26 +1002,26 @@ int tq_metal_moe_forward(
10571002
MTLSize tgSize3 = MTLSizeMake(TQ_MATMUL_TG_SIZE, 1, 1);
10581003
[enc dispatchThreadgroups:gridSize3 threadsPerThreadgroup:tgSize3];
10591004
[enc endEncoding];
1005+
}
10601006

1061-
[cmdBuf commit];
1062-
[cmdBuf waitUntilCompleted];
1007+
/* ONE commit + wait for all 3 phases */
1008+
[cmdBuf commit];
1009+
[cmdBuf waitUntilCompleted];
10631010

1064-
if (cmdBuf.status == MTLCommandBufferStatusError) {
1065-
NSLog(@"TurboQuant MoE: Phase 3 (down+accum) FAILED: %@", cmdBuf.error);
1066-
/* Fallback to hybrid on failure */
1067-
memcpy(hb_output, [gate_buf contents], inter_bytes);
1068-
return 1;
1069-
}
1070-
NSLog(@"TurboQuant MoE: Phase 3 (down+accum) completed OK");
1011+
if (cmdBuf.status == MTLCommandBufferStatusError) {
1012+
NSLog(@"TurboQuant MoE: GPU dispatch FAILED: %@", cmdBuf.error);
1013+
/* Fallback to hybrid on failure */
1014+
memcpy(hb_output, [gate_buf contents], inter_bytes);
1015+
return 1;
1016+
}
10711017

1072-
/* Copy result to output */
1073-
memcpy(output, [output_buf contents], output_bytes);
1018+
/* Copy result to output */
1019+
memcpy(output, [output_buf contents], output_bytes);
10741020

1075-
/* Also copy hb for potential caller use */
1076-
memcpy(hb_output, [gate_buf contents], inter_bytes);
1021+
/* Also copy hb for potential caller use */
1022+
memcpy(hb_output, [gate_buf contents], inter_bytes);
10771023

1078-
return 0; /* Full GPU success */
1079-
}
1024+
return 0; /* Full GPU success */
10801025
}
10811026
}
10821027

src/engine/tq_moe.c

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -669,10 +669,9 @@ void tq_moe_forward(const tq_moe_layer_t* layer,
669669
int num_active, int expert_dim, int hidden_dim, int num_experts_total, int weight_type,
670670
const int* gate_types, const int* up_types, const int* down_types);
671671

672-
/* Metal MoE: IQ2_S hang fixed! But per-phase waitUntilCompleted
673-
* makes it slow. Need single command buffer (was 9.5 tok/s).
674-
* Re-enable after merging single-cmdBuf dispatch. */
675-
if (0 && tq_metal_moe_available() && num_active > 0) {
672+
/* Metal MoE: single command buffer with memoryBarrier between phases.
673+
* Eliminates per-phase waitUntilCompleted overhead. */
674+
if (tq_metal_moe_available() && num_active > 0) {
676675
/* Check that all active experts use IQ2_XXS and have valid weights */
677676
int can_fuse = 1;
678677
const void* base_ptr = NULL;

src/engine/tq_transformer.c

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1576,12 +1576,12 @@ float* tq_forward(tq_model_t* model, tq_state_t* s, int token, int pos) {
15761576
tq_matmul_q2_preq(s->hb2, layer->w_up_q2, layer->w_up_q2s,
15771577
s->xb_q8, s->xb_q8s, c->intermediate_dim, dim);
15781578
} else if (layer->w_gate_q4) {
1579-
tq_quantize_row_q8(s->xb, s->xb_q8, s->xb_q8s, dim);
1580-
1581-
tq_matmul_q4_preq(s->hb, layer->w_gate_q4, layer->w_gate_q4s,
1582-
s->xb_q8, s->xb_q8s, c->intermediate_dim, dim);
1583-
tq_matmul_q4_preq(s->hb2, layer->w_up_q4, layer->w_up_q4s,
1584-
s->xb_q8, s->xb_q8s, c->intermediate_dim, dim);
1579+
/* FFN gate+up: batch 2 matmuls on GPU if Metal available,
1580+
* otherwise use Q4×Q8 preq fast path on CPU */
1581+
tq_metal_batch_begin_if_available();
1582+
tq_matmul_q4(s->hb, s->xb, layer->w_gate_q4, layer->w_gate_q4s, c->intermediate_dim, dim);
1583+
tq_matmul_q4(s->hb2, s->xb, layer->w_up_q4, layer->w_up_q4s, c->intermediate_dim, dim);
1584+
tq_metal_batch_flush_if_available();
15851585
} else if (layer->gguf_w_gate) {
15861586
/* Batch gate+up into one GPU command buffer (2 matmuls, 1 dispatch) */
15871587
tq_metal_batch_begin_if_available();

0 commit comments

Comments
 (0)