diff --git a/Matrix-Game-3/README.md b/Matrix-Game-3/README.md index 136affc..725caa8 100644 --- a/Matrix-Game-3/README.md +++ b/Matrix-Game-3/README.md @@ -22,12 +22,13 @@ In addition, the model trained on a combination of unreal and real-world data, a ## Requirements It supports one gpu or multi-gpu inference. We tested this repo on the following setup: -* A/H series GPUs are tested. +* NVIDIA A/H series GPUs are tested. +* AMD Instinct MI300X GPUs are also supported (ROCm 7.x + AITER). * Linux operating system. * 64 GB RAM. ## ⚙️ Quick Start -### Installation +### Installation (NVIDIA) Create a conda environment and install dependencies: ``` conda create -n matrix-game-3.0 python=3.12 -y @@ -39,6 +40,20 @@ cd Matrix-Game-3.0 pip install -r requirements.txt ``` +### Installation (AMD ROCm) +For AMD GPUs (e.g. MI300X) with ROCm 7.x: +```bash +conda create -n matrix-game-3.0 python=3.10 -y +conda activate matrix-game-3.0 +pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.2 +git clone https://github.com/SkyworkAI/Matrix-Game-3.0.git +cd Matrix-Game-3.0 +grep -v -E "^torch|flash.attn" requirements.txt | pip install -r /dev/stdin +pip install opencv-python-headless +# Install AITER (AMD flash attention CK backend) +pip install aiter # or: git clone https://github.com/ROCm/aiter && cd aiter && git submodule update --init 3rdparty/composable_kernel && pip install . 
+``` + ### Model Download ``` pip install "huggingface_hub[cli]" diff --git a/Matrix-Game-3/wan/modules/attention.py b/Matrix-Game-3/wan/modules/attention.py index 9c1f7e2..4648539 100644 --- a/Matrix-Game-3/wan/modules/attention.py +++ b/Matrix-Game-3/wan/modules/attention.py @@ -13,10 +13,18 @@ except ModuleNotFoundError: FLASH_ATTN_3_AVAILABLE = False +try: + import importlib as _il + _aiter_mha = _il.import_module('aiter.ops.mha') + _aiter_flash_attn_varlen = _aiter_mha.flash_attn_varlen_func + AITER_AVAILABLE = True +except Exception: + AITER_AVAILABLE = False + try: import flash_attn FLASH_ATTN_2_AVAILABLE = True -except ModuleNotFoundError: +except Exception: FLASH_ATTN_2_AVAILABLE = False import warnings @@ -110,8 +118,19 @@ def half(x): softmax_scale=softmax_scale, causal=causal, deterministic=deterministic).unflatten(0, (b, lq)) - else: - assert FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE + elif AITER_AVAILABLE: + cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( 0, dtype=torch.int32).to(q.device, non_blocking=True) + cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( 0, dtype=torch.int32).to(q.device, non_blocking=True) + x = _aiter_flash_attn_varlen( + q=q, k=k, v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=lq, max_seqlen_k=lk, + softmax_scale=softmax_scale, causal=causal, + ).unflatten(0, (b, lq)) + elif FLASH_ATTN_2_AVAILABLE: x = flash_attn.flash_attn_varlen_func( q=q, k=k, @@ -127,6 +146,11 @@ def half(x): causal=causal, window_size=window_size, deterministic=deterministic).unflatten(0, (b, lq)) + else: + raise RuntimeError( + 'No flash attention backend available. Install one of: ' + 'flash-attn (NVIDIA), or aiter (AMD ROCm 7.x: pip install aiter)' + ) return x.type(out_dtype) @@ -149,7 +173,7 @@ def attention( version=None, ): global _WARNED_FA_DISABLED - if version != '0' and (FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE): + if version != '0' and (FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE or AITER_AVAILABLE): return flash_attention( q=q, k=k,