arusli1/EfficientQwen

EfficientQwen

Latency + throughput optimization of Qwen3.5-4B inference on a single NVIDIA A10G (24 GB) via stacked quantization, speculative decoding, and vLLM serving tweaks.

Headline: ~3.4× throughput vs. the BF16 baseline at parity quality on MMLU-Pro / IFEval / GPQA-Diamond, with a one-line make serve.


What's inside

  Qwen3.5-4B (BF16, 4B params, ~7.5 GB on disk)
      │
      ▼
  ┌──────────────────────────────────────────────┐
  │  AWQ-4bit MLP weights                        │  llmcompressor calibration
  │  ↳ compressed-tensors / Marlin kernel         │  scripts/llmcompressor_calibrate.py
  └──────────────────────────────────────────────┘
      │
      ▼
  ┌──────────────────────────────────────────────┐
  │  Vocab pruning (optional)  248k → 64k        │  scripts/prune_vocab_v2.py
  │  ↳ tied embed slice + token-ID remap at serve │  scripts/analyze_vocab.py
  └──────────────────────────────────────────────┘
      │
      ▼
  ┌──────────────────────────────────────────────┐
  │  MTP head fine-tune  (K=4 speculative)       │  scripts/redistill_mtp.py
  │  ↳ frozen quant projections, BF16 norms+FC    │  ~13M trainable, ~30 min
  └──────────────────────────────────────────────┘
      │
      ▼
  ┌──────────────────────────────────────────────┐
  │  vLLM serving runtime                         │  scripts/serve.py
  │   • FP8 KV-cache                              │
  │   • Chunked prefill (2k tokens)               │
  │   • Prefix caching                            │
  │   • max_num_seqs=8, batched cudagraphs        │
  │   • qk-norm + RoPE fusion                     │
  │   • cuDNN prefill kernel                      │
  └──────────────────────────────────────────────┘
      │
      ▼
  ┌──────────────────────────────────────────────┐
  │  Pre-baked torch._inductor cache              │  scripts/bake_cache.py
  │  ↳ A40-build → A10G-serve device-name shim    │  scripts/_cache_patch.py
  │  ↳ cold-start 697s → 156s (~78% off)          │
  └──────────────────────────────────────────────┘
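The serving-runtime box above corresponds roughly to a handful of stock vLLM command-line flags. The sketch below assembles such a command line; it is not the contents of scripts/serve.py. Mapping "Chunked prefill (2k tokens)" to `--max-num-batched-tokens 2048` is an assumption, and the speculative-decoding and compilation settings are left out because their exact values are not stated here.

```python
import shlex


def build_vllm_serve_argv(model: str) -> list[str]:
    """Assemble a `vllm serve` command line for the runtime settings listed above."""
    return [
        "vllm", "serve", model,
        "--kv-cache-dtype", "fp8",           # FP8 KV-cache
        "--block-size", "16",                # KV block size from the summary table
        "--enable-chunked-prefill",
        "--max-num-batched-tokens", "2048",  # assumption: the "2k tokens" chunk budget
        "--enable-prefix-caching",
        "--max-num-seqs", "8",
    ]


argv = build_vllm_serve_argv("cyankiwi/Qwen3.5-4B-AWQ-4bit")
print(shlex.join(argv))
```

Building the argument list in Python rather than in a shell string keeps each flag next to a comment explaining it, and makes per-variant overrides easy to layer on.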

Optimizations, summarized

| Lever | Technique | Source |
|---|---|---|
| Weight quant | AWQ-4bit MLP, compressed-tensors → Marlin kernel | scripts/llmcompressor_calibrate.py |
| Speculative decoding | Multi-Token Predictor head (MTP), K=4 candidates | scripts/redistill_mtp.py |
| KV cache | FP8 KV, block size 16, chunked prefill, prefix caching | scripts/serve.py (vLLM flags) |
| Vocab | Prune embedding 248k → 64k, sidecar token-ID remap | scripts/prune_vocab_v2.py |
| Batching | max_num_seqs=8, multi-size cudagraphs {8,16,…,64} | experiments/cyankiwi-seq8/config.env |
| Kernel fusion | qk-norm + RoPE fusion (VLLM_COMPILATION_CONFIG), cuDNN prefill | scripts/serve.py |
| Cold start | Pre-bake torch.compile cache + device-name shim | scripts/bake_cache.py, scripts/_cache_patch.py |
| Chat template | Truncate thinking-mode reasoning budget for short-form tasks | experiments/cyankiwi-seq8-shortcot/ |
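The vocab row combines two steps: slicing the tied embedding matrix down to the kept token IDs, and translating IDs between the original and pruned vocabularies at serve time. Below is a toy sketch of that idea in plain Python. The function names are hypothetical, the 8-row matrix stands in for the real 248k-row embedding, and the actual logic in scripts/prune_vocab_v2.py may differ.

```python
def build_remap(keep_ids):
    """Return (old_to_new, new_to_old) for the token IDs that survive pruning."""
    new_to_old = sorted(set(keep_ids))
    old_to_new = {old: new for new, old in enumerate(new_to_old)}
    return old_to_new, new_to_old


def slice_rows(embedding, new_to_old):
    """Keep only the embedding rows of surviving tokens, in new-ID order."""
    return [embedding[old] for old in new_to_old]


# Toy 8-token vocabulary with 2-dimensional embeddings.
embedding = [[float(i), float(i) * 10.0] for i in range(8)]
old_to_new, new_to_old = build_remap([5, 0, 2, 5])
pruned = slice_rows(embedding, new_to_old)

# With tied embeddings the same sliced matrix serves as input table and lm_head.
prompt_old = [0, 5, 2]                             # IDs from the original tokenizer
prompt_new = [old_to_new[t] for t in prompt_old]   # IDs the pruned model sees
sampled_old = new_to_old[2]                        # map a sampled new ID back out
```

A prompt token that was pruned raises `KeyError` in this sketch; a real serving path needs a fallback policy for that case.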

Measured results

Tested on AWS g5.xlarge (1× A10G, 24 GB). Baseline = BF16 Qwen3.5-4B under stock vLLM. Speedup measured on a mixed prompt set (short / medium / long-form generation).

| Variant | Avg latency speedup | MMLU-Pro | IFEval | GPQA-D |
|---|---|---|---|---|
| cyankiwi (AWQ-4bit baseline) | 1.0× (ref) | 0.65 | 0.83 | 0.59 |
| cyankiwi-seq8 (+ batched cudagraphs) | 3.45× | 0.65 | 0.83 | 0.59 |
| cyankiwi-seq8-best (+ MTP K=4 tuning + rep-penalty fix) | 3.73× | 0.64 | 0.86 | 0.59* |

* GPQA-D measured at 44/99 (≈0.44) in fast-eval; the full eval is still pending.
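The "+ MTP K=4 tuning" row reflects a trade-off in speculative decoding: each extra draft token costs draft compute but is only useful if every earlier draft token was accepted. A common back-of-envelope model assumes each draft token is accepted independently with probability α, giving an expected (1 − α^(K+1)) / (1 − α) tokens per verification step. The sketch below evaluates that formula; the α value is illustrative, not a measurement from this repo.

```python
def expected_tokens_per_step(alpha: float, k: int) -> float:
    """Expected tokens emitted per target-model step with k draft tokens.

    Assumes each draft token is accepted independently with probability alpha,
    and that the target model always contributes one token of its own.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    if k < 0:
        raise ValueError("k must be non-negative")
    if alpha == 1.0:
        return float(k + 1)
    return (1.0 - alpha ** (k + 1)) / (1.0 - alpha)


alpha = 0.7  # illustrative acceptance rate, not measured here
for k in (1, 2, 4, 7):
    print(f"K={k}: {expected_tokens_per_step(alpha, k):.2f} tokens/step")
```

Under this model the gain flattens quickly as K grows. If acceptance also falls for later draft positions, as the approach notes report for long-form generation, the extra candidates beyond K=4 return even less.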


Quick start

make install                 # creates .venv, installs host deps
make install-eval            # adds lm-eval + datasets (eval box)
make download                # baseline weights (~3.8 GB)
make datasets                # MMLU-Pro / IFEval / GPQA-Diamond
make test                    # pytest (~10s, no GPU needed)

Run + eval (GPU box)

make serve                              # vLLM, default variant
make eval-quality                       # 10% lm-eval sample (~20 min)
make eval-quality-full                  # full lm-eval run (~60 min)
make eval-latency                       # latency probe on diverse prompts

# Try a different variant:
make serve         VARIANT=cyankiwi-seq8-best
make eval-quality  VARIANT=cyankiwi-seq8-best
make eval-latency  VARIANT=cyankiwi-seq8-best

Outputs land in experiments/<variant>/{quality,latency}_<date>.json.
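For scripting against those outputs, a small helper can pick up the newest file of a given kind. This is a sketch only: `latest_result` is a hypothetical name, it selects by file modification time because the exact `<date>` format is not specified here, and it makes no assumption about the JSON schema.

```python
import json
from pathlib import Path


def latest_result(root, variant, kind):
    """Return (path, parsed JSON) for the newest <kind>_*.json of a variant, or None."""
    if kind not in ("quality", "latency"):
        raise ValueError("kind must be 'quality' or 'latency'")
    candidates = sorted(
        (Path(root) / variant).glob(f"{kind}_*.json"),
        key=lambda p: p.stat().st_mtime,  # newest last, independent of date format
    )
    if not candidates:
        return None
    newest = candidates[-1]
    with newest.open() as fh:
        return newest, json.load(fh)
```

For example, `latest_result("experiments", "cyankiwi-seq8-best", "latency")` would return the most recently written latency record for that variant, or `None` if no run has been recorded yet.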

Container build (optional)

make build                              # docker build with native cache bake (GPU host)
make build-import                       # use a pre-built cache_import.tar.gz (no GPU needed)
make verify-image VARIANT=cyankiwi      # inspect resulting image

Repo layout

experiments/                  research artifacts — one subdir per variant
  cyankiwi/                     AWQ-4bit reference (hf: cyankiwi/Qwen3.5-4B-AWQ-4bit)
  cyankiwi-seq8/                + max_num_seqs=8 + multi-size cudagraphs
  cyankiwi-seq8-best/           + MTP K=4 + matched cudagraph capture + rep_penalty fix
  cyankiwi-seq8-shortcot/       + chat-template thinking-budget cap
  cyankiwi-seq8-v64k/           + 248k → 64k vocab prune
  cyankiwi-seq8-w4lm/           + INT4 lm_head via AWQ-Marlin
  cyankiwi-seq8-shortcot-w4lm-v64k-mtp4/    stacked combo
  README.md                     variant catalog + naming convention
weights/                      raw weight files (large; gitignored beyond baseline)
  cyankiwi/                     baseline upstream
scripts/
  serve.py, _cache_patch.py, bake_cache.py     # runtime
  sitecustomize.py                              # worker-subprocess cache shim
  bench_latency.py, bench_chat_latency.py,      # latency probes
  profile_model.py, profile_matrix.sh
  eval_smoke.py, eval_fast.py, eval_full.py,
  eval_common.py
  prune_vocab_v2.py, analyze_vocab.py           # vocab work
  redistill_mtp.py                              # MTP head training
  llmcompressor_calibrate.py,                   # quantization
  build_calibration_corpus.py
  analyze_lmhead_svd.py                         # SVD on lm_head (exploration)
  download_weights.py, verify_checkpoint.py
  build_image.sh, verify_image.sh
  check_schemas.py                              # validate measurements JSON
eval/                         lm-eval-harness driver scripts
tests/                        pytest, 1:1 mirror of critical scripts
results/                      one-off measurement records (SVD, K-sweep, profiles)
Dockerfile                    container target
Makefile                      VARIANT-aware: make <target> VARIANT=name

Each variant directory in experiments/ is a self-contained record: hypothesis (README.md), serve flags (config.env), measurements JSON, dated eval outputs.


Approach notes

  • Why MTP over EAGLE / Medusa: lower training overhead. The MTP head initializes from the base model, and only ~13M parameters are trained. K=4 outperformed K=7 on this model and GPU because the acceptance rate dropped past 4 candidates on long-form generation.
  • Why FP8 KV over INT8: vLLM auto-selects FP8 with compressed-tensors AWQ on A10G (SM86), and it gave a measurable throughput gain; INT8 KV in vLLM required a custom kernel path that wasn't a net win here.
  • Cold-start trick: torch._inductor keys its cache on the GPU device name. If you bake the cache on an A40 and serve on an A10G, the keys mismatch and cold start recompiles everything (~697s). _cache_patch.py spoofs the device name at Python startup so the A10G hits the warm cache (~156s). vLLM's worker subprocesses need the same patch, so sitecustomize.py reapplies it via Python's site machinery.
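The device-name spoof described in the cold-start note can be illustrated with a small monkeypatch. This is a minimal sketch of the general technique, not the code in _cache_patch.py: `spoof_device_name` is a hypothetical helper, and a `SimpleNamespace` stands in for `torch.cuda` so the example runs without a GPU. Which torch call the real shim patches (for instance `get_device_name` versus the name field on device properties) is an assumption here.

```python
import functools
import types


def spoof_device_name(cuda_module, real_fragment, baked_name):
    """Wrap cuda_module.get_device_name so the serving GPU reports the bake GPU's name.

    Returns the original function so the patch can be undone.
    """
    original = cuda_module.get_device_name

    @functools.wraps(original)
    def patched(*args, **kwargs):
        name = original(*args, **kwargs)
        # Only rewrite the name when running on the expected serving GPU.
        return baked_name if real_fragment in name else name

    cuda_module.get_device_name = patched
    return original


# Stand-in for torch.cuda on an A10G host (assumption: no real torch needed here).
fake_cuda = types.SimpleNamespace(get_device_name=lambda device=None: "NVIDIA A10G")
original = spoof_device_name(fake_cuda, "A10G", "NVIDIA A40")
print(fake_cuda.get_device_name())
```

A patch like this only affects the process it runs in, which is why worker subprocesses need it reapplied at interpreter startup, for example from a `sitecustomize.py` on the import path.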

License

See LICENSE.

About

Minimizing Inference Latency for Qwen3.5-4B on A10G/A40
