19 changes: 17 additions & 2 deletions Matrix-Game-3/README.md
@@ -22,12 +22,13 @@ In addition, the model trained on a combination of unreal and real-world data, a

## Requirements
It supports one gpu or multi-gpu inference. We tested this repo on the following setup:
-* A/H series GPUs are tested.
+* NVIDIA A/H series GPUs are tested.
+* AMD Instinct MI300X GPUs are also supported (ROCm 7.x + AITER).
* Linux operating system.
* 64 GB RAM.

## ⚙️ Quick Start
-### Installation
+### Installation (NVIDIA)
Create a conda environment and install dependencies:
```
conda create -n matrix-game-3.0 python=3.12 -y
@@ -39,6 +40,20 @@ cd Matrix-Game-3.0
pip install -r requirements.txt
```

### Installation (AMD ROCm)
For AMD GPUs (e.g. MI300X) with ROCm 7.x:
```bash
conda create -n matrix-game-3.0 python=3.10 -y
conda activate matrix-game-3.0
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.2
git clone https://github.com/SkyworkAI/Matrix-Game-3.0.git
cd Matrix-Game-3.0
grep -v -E "^torch|flash.attn" requirements.txt | pip install -r /dev/stdin
pip install opencv-python-headless
# Install AITER (AMD flash attention CK backend)
pip install aiter # or: git clone https://github.com/ROCm/aiter && cd aiter && git submodule update --init 3rdparty/composable_kernel && pip install .
```
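The `grep -v -E "^torch|flash.attn"` step above strips every requirement that starts with `torch` or mentions flash-attn before piping the rest to pip, so the ROCm wheels installed earlier are not overwritten by CUDA pins. A minimal stdlib sketch of what that filter matches, using hypothetical requirement lines (not the repo's actual pins):

```python
import re

# Illustrative requirements.txt lines (hypothetical, not the repo's real pins)
reqs = ["torch==2.4.0", "torchvision", "flash-attn==2.6.3", "numpy", "einops"]

# Same pattern as the grep step: drop lines starting with "torch",
# or containing "flash" + any one character + "attn"
pattern = re.compile(r"^torch|flash.attn")
kept = [r for r in reqs if not pattern.search(r)]
print(kept)  # torch*, torchvision, and flash-attn pins are filtered out
```

Note that the `.` in `flash.attn` is a regex wildcard, so the pattern covers both the `flash-attn` and `flash_attn` spellings.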

### Model Download
```
pip install "huggingface_hub[cli]"
32 changes: 28 additions & 4 deletions Matrix-Game-3/wan/modules/attention.py
@@ -13,10 +13,18 @@
 except ModuleNotFoundError:
     FLASH_ATTN_3_AVAILABLE = False

+try:
+    import importlib as _il
+    _aiter_mha = _il.import_module('aiter.ops.mha')
+    _aiter_flash_attn_varlen = _aiter_mha.flash_attn_varlen_func
+    AITER_AVAILABLE = True
+except Exception:
+    AITER_AVAILABLE = False
+
 try:
     import flash_attn
     FLASH_ATTN_2_AVAILABLE = True
-except ModuleNotFoundError:
+except Exception:
     FLASH_ATTN_2_AVAILABLE = False

 import warnings
@@ -110,8 +118,19 @@ def half(x):
             softmax_scale=softmax_scale,
             causal=causal,
             deterministic=deterministic).unflatten(0, (b, lq))
-    else:
-        assert FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE
+    elif AITER_AVAILABLE:
+        cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
+            0, dtype=torch.int32).to(q.device, non_blocking=True)
+        cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
+            0, dtype=torch.int32).to(q.device, non_blocking=True)
+        x = _aiter_flash_attn_varlen(
+            q=q, k=k, v=v,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=lq, max_seqlen_k=lk,
+            softmax_scale=softmax_scale, causal=causal,
+        ).unflatten(0, (b, lq))
+    elif FLASH_ATTN_2_AVAILABLE:
         x = flash_attn.flash_attn_varlen_func(
             q=q,
             k=k,
@@ -127,6 +146,11 @@ def half(x):
             causal=causal,
             window_size=window_size,
             deterministic=deterministic).unflatten(0, (b, lq))
+    else:
+        raise RuntimeError(
+            'No flash attention backend available. Install one of: '
+            'flash-attn (NVIDIA), or aiter (AMD ROCm 7.x: pip install aiter)'
+        )

     return x.type(out_dtype)
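The AITER branch builds `cu_seqlens_q`/`cu_seqlens_k` by prepending a zero to the per-sample lengths and taking a running sum, the usual packed-batch layout for varlen attention: entries `i` and `i+1` bound sample `i`'s tokens. A torch-free sketch of that cumulative sum, with made-up lengths:

```python
from itertools import accumulate

# Hypothetical per-sample sequence lengths in a packed batch
q_lens = [3, 5, 2]

# Prepend 0, then running-sum -- mirrors
# torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0) in the diff
cu_seqlens_q = list(accumulate(q_lens, initial=0))
print(cu_seqlens_q)  # → [0, 3, 8, 10]
```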

@@ -149,7 +173,7 @@ def attention(
     version=None,
 ):
     global _WARNED_FA_DISABLED
-    if version != '0' and (FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE):
+    if version != '0' and (FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE or AITER_AVAILABLE):
         return flash_attention(
             q=q,
             k=k,