Perseus14 commented on Jan 15, 2026

Summary

This PR introduces full Low-Rank Adaptation (LoRA) inference support for the WAN family of models in MaxDiffusion.

Unlike previous implementations in this codebase that rely on flax.linen, this implementation leverages flax.nnx. This allows for a more Pythonic, object-oriented approach to weight injection, enabling us to modify the transformer model in-place.

Key Features

1. Transition to flax.nnx

WAN models in MaxDiffusion are implemented using flax.nnx. To support LoRA, we implemented a native NNX loader rather than wrapping linen modules.

  • In-Place Merging: We iterate through the NNX graph (nnx.iter_graph) to identify target layers (nnx.Linear, nnx.Conv, nnx.Embed, nnx.LayerNorm) and merge LoRA weights directly into the kernel values (a sketch follows this list).
  • Graph Traversal: This approach avoids complex module replacement logic common in functional frameworks, allowing us to simply "visit" nodes and apply updates.
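
A minimal sketch of this visit-and-update pattern for Linear layers, assuming a pre-parsed `lora_state` dict keyed by dot-joined module paths in the Diffusers shape convention (the function name and key format here are illustrative, not the exact ones in lora_nnx.py):

```python
import jax.numpy as jnp
from flax import nnx

def merge_lora_linear(model: nnx.Module, lora_state: dict, scale: float = 1.0):
    """Walk the NNX graph and fold LoRA deltas into matching Linear kernels.

    `lora_state` is a hypothetical mapping from a dot-joined module path to a
    (lora_down, lora_up) pair: lora_down has shape (rank, in_features) and
    lora_up has shape (out_features, rank).
    """
    for path, node in nnx.iter_graph(model):
        if not isinstance(node, nnx.Linear):
            continue
        key = ".".join(str(p) for p in path)
        pair = lora_state.get(key)
        if pair is None:
            continue
        lora_down, lora_up = pair
        # nnx.Linear kernels are (in_features, out_features), so the
        # (out, in) product up @ down is expressed here as down.T @ up.T
        # to land directly in (in, out).
        delta = jnp.matmul(lora_down.T, lora_up.T)
        node.kernel.value = node.kernel.value + scale * delta
```

The loader extends the same pattern to the other target types listed above (nnx.Conv, nnx.Embed, nnx.LayerNorm).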

2. Robust Weight Merging Strategy

This implementation solves several critical distributed training/inference challenges:

  • Device-Side Merging (jax.jit): To avoid ShardingMismatch and DeviceArray errors that occur when mixing sharded TPU weights with CPU-based LoRA weights, all merge computations (kernel + delta) are performed within JIT-compiled functions (_compute_and_add_*_jit). This ensures weight updates occur efficiently on-device across the TPU mesh (see the sketch after this list).
  • Zero-Copy Transfer: Utilizes jax.dlpack where possible to efficiently move PyTorch tensors to JAX arrays without unnecessary memory overhead.
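
A hedged sketch of both pieces; the names `_merge_linear_kernel` and `torch_to_jax` are illustrative stand-ins, not the PR's exact `_compute_and_add_*_jit` helpers:

```python
import jax
import jax.numpy as jnp
import torch

@jax.jit
def _merge_linear_kernel(kernel, lora_down, lora_up, scale):
    # kernel: (in, out) and possibly sharded across the TPU mesh. The delta
    # is computed and added on-device, so the merged kernel keeps that
    # sharding instead of falling back to a host-side add.
    delta = jnp.matmul(lora_down.T, lora_up.T)  # (in, out)
    return kernel + scale * delta

def torch_to_jax(t: torch.Tensor) -> jax.Array:
    # Zero-copy handoff via the DLPack protocol where the dtype is
    # supported; otherwise fall back to an explicit host copy.
    try:
        return jax.dlpack.from_dlpack(t.contiguous())
    except Exception:
        return jnp.asarray(t.float().numpy())
```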

3. Advanced LoRA Support

Beyond standard Linear rank reduction, this PR supports:

  • LoCON / LoRA for Convolutions: Supports LoRA for both 1x1 and kxk convolutions. 1x1 convolutions are merged efficiently inside the JIT, like linear layers, while kxk convolution deltas (LoCON) are pre-calculated on the host and added to any existing diff weights before device-side merging (see the sketch after this list).
  • Full Weight & Bias Diffs (diff, diff_b): Supports checkpoints that include full-parameter fine-tuning offsets (difference injections) and bias tuning, which are common in high-fidelity WAN fine-tunes.
  • Embeddings & Norms: Includes support for patching text_embedding, time_embedding, and LayerNorm/RMSNorm scales and biases.
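
For the kxk case, the host-side delta computation might look like the sketch below. The checkpoint shape convention (LyCORIS-style OIHW factors) and the helper name are assumptions; the result is transposed to flax's HWIO kernel layout before the device-side add:

```python
import numpy as np

def locon_conv_delta(lora_down: np.ndarray, lora_up: np.ndarray, scale: float) -> np.ndarray:
    """Pre-compute a kxk conv LoRA delta on the host.

    Assumed shapes (common LoCON checkpoint convention):
      lora_down: (rank, in_ch, kh, kw)
      lora_up:   (out_ch, rank, 1, 1)
    Returns a delta in (kh, kw, in_ch, out_ch) to match flax HWIO kernels.
    """
    rank, in_ch, kh, kw = lora_down.shape
    out_ch = lora_up.shape[0]
    down_2d = lora_down.reshape(rank, -1)      # (rank, in_ch*kh*kw)
    up_2d = lora_up.reshape(out_ch, rank)      # (out_ch, rank)
    delta = (up_2d @ down_2d).reshape(out_ch, in_ch, kh, kw)
    return scale * np.transpose(delta, (2, 3, 1, 0))  # OIHW -> HWIO
```

As described above, this host-side delta can be summed with any existing diff offsets before being handed to the on-device merge.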

4. Scanned vs. Unscanned Layers

MaxDiffusion supports enabling jax.scan for transformer layers via the scan_layers: True configuration flag. This improves training memory efficiency by stacking weights of repeated layers (e.g., Attention, FFN) along a new leading dimension. Since users may run inference with or without this flag enabled, this LoRA implementation is designed to transparently support both modes.

The loader distinguishes between:

  • scan_layers: False: The model graph is "unrolled." The merge_lora() function is used, which iterates through each layer and merges weights individually via efficient, on-device JIT calls (_compute_and_add_single_jit).
  • scan_layers: True: The merge_lora_for_scanned() function is used. It detects which parameters are stacked (e.g., kernel.ndim > 2) and which are not.
    • For stacked parameters: It gathers all corresponding LoRA weights on the host CPU into stacked NumPy arrays and dispatches a single, batched call to _compute_and_add_scanned_jit, as sketched below. This updates all layers in the stack at once on-device, which is significantly more efficient than merging layer-by-layer.
    • For non-stacked parameters (e.g., embeddings, proj_out): It merges them individually using the single-layer JIT logic.

This dual approach ensures correct weight injection whether or not layers are scanned, while maximizing performance in scanned mode through batching.
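
A minimal sketch of the batched stacked merge, assuming a leading layer axis on the stacked kernel; `_merge_stacked_kernels` is an illustrative stand-in for `_compute_and_add_scanned_jit`:

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def _merge_stacked_kernels(kernel, lora_down, lora_up, scale):
    # Scanned layers stack weights along a leading layer axis:
    #   kernel:    (num_layers, in, out)
    #   lora_down: (num_layers, rank, in)
    #   lora_up:   (num_layers, out, rank)
    # One einsum contracts rank per layer, so all layers are merged in a
    # single on-device dispatch instead of num_layers separate JIT calls.
    delta = jnp.einsum("lri,lor->lio", lora_down, lora_up)
    return kernel + scale * delta

# Host side: gather the per-layer factors into stacked NumPy arrays first.
# downs = np.stack(per_layer_downs)   # (num_layers, rank, in)
# ups   = np.stack(per_layer_ups)     # (num_layers, out, rank)
# merged = _merge_stacked_kernels(stacked_kernel, downs, ups, 1.0)
```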

Files Added / Modified

  • src/maxdiffusion/models/lora_nnx.py: [NEW] Core logic. Contains the JIT merge functions, parse_lora_dict, and the graph traversal logic (merge_lora, merge_lora_for_scanned) to inject weights into NNX modules.
  • src/maxdiffusion/loaders/wan_lora_nnx_loader.py: [NEW] Orchestrates the loading process. Handles the download of safetensors, conversion of keys, and delegation to the merge functions.
  • src/maxdiffusion/generate_wan.py: Updated the generation pipeline to detect whether LoRA is enabled and trigger the loading sequence before inference.
  • src/maxdiffusion/lora_conversion_utils.py: Updated translate_wan_nnx_path_to_diffusers_lora to accurately map NNX paths (including embeddings and time projections) to Diffusers-style keys.
  • base_wan_lora_14b.yml & base_wan_lora_27b.yml: Added lora_config section to specify LoRA checkpoints and parameters during inference.

Testing

  • Scenario: validation of LoRA inference for WAN 2.1 and WAN 2.2, using LoRA weights designed to enable high-quality video generation in a reduced number of inference steps.
| Model | LoRA Type | Video Link | Inference Steps | Generation Time |
| --- | --- | --- | --- | --- |
| WAN 2.1 | T2V | Link to Video | 4 steps | ~13s |
| WAN 2.2 | T2V | Link to Video | 8 steps | ~20s |

Perseus14 marked this pull request as ready for review on January 16, 2026, 05:03.