Torch-MIGraphX integrates AMD's graph inference engine with the PyTorch ecosystem.
It provides utilities and APIs for generating an mgx_module that can be invoked in the same manner as any other torch module but uses the MIGraphX inference engine internally.
This library currently supports two paths for lowering:
- FX Tracing: Uses the tracing API provided by the torch.fx library.
- Dynamo Backend: Importing torch_migraphx automatically registers the "migraphx" backend that can be used with the torch.compile API.
The simplest and recommended way to get started is using the provided Dockerfile. Build using:
./build_image.sh
Start container using:
sudo docker run -it --network=host --device=/dev/kfd --device=/dev/dri --group-add=video --ipc=host --cap-add=SYS_PTRACE --security-opt seccomp=unconfined torch_migraphx
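The flags in the command above are easier to audit when written out one per line. The following is a minimal sketch and not part of this repository: `DOCKER_FLAGS` and `docker_run_command` are hypothetical names, and the purpose notes reflect general Docker and ROCm usage rather than anything stated by this project.

```python
import shlex

# Flags copied from the `docker run` command above. The purpose notes are
# general Docker/ROCm background (an assumption, not project documentation).
DOCKER_FLAGS = [
    ("-it", "interactive session with a terminal attached"),
    ("--network=host", "share the host's network stack"),
    ("--device=/dev/kfd", "expose the ROCm compute (KFD) device"),
    ("--device=/dev/dri", "expose the GPU render device nodes"),
    ("--group-add=video", "join the group that typically owns the GPU device nodes"),
    ("--ipc=host", "share the host's IPC namespace, including shared memory"),
    ("--cap-add=SYS_PTRACE", "allow ptrace, as used by debuggers and profilers"),
    ("--security-opt seccomp=unconfined", "disable the default seccomp filter"),
]


def docker_run_command(image="torch_migraphx", use_sudo=True):
    """Return the argv list for starting the container (hypothetical helper)."""
    argv = ["sudo"] if use_sudo else []
    argv += ["docker", "run"]
    for flag, _purpose in DOCKER_FLAGS:
        # "--security-opt seccomp=unconfined" becomes two separate arguments.
        argv.extend(shlex.split(flag))
    argv.append(image)
    return argv


if __name__ == "__main__":
    # Prints the same command line as shown above.
    print(shlex.join(docker_run_command()))
```

On systems where the user already belongs to the `docker` group, `docker_run_command(use_sudo=False)` drops the `sudo` prefix.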
The default Dockerfile builds on the nightly PyTorch container and installs the latest source versions of MIGraphX and torch_migraphx. For other builds, refer to the docker directory.
Install Pre-reqs:
Build and install from source
git clone https://github.com/ROCmSoftwarePlatform/torch_migraphx.git
cd ./torch_migraphx/py
pip install .
# FX Tracing
torch_migraphx.fx.lower_to_mgx(torch_model, sample_inputs)
# Dynamo Backend
torch.compile(torch_model, backend="migraphx")
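Because the "migraphx" backend is registered when torch_migraphx is imported, a script that must also run on machines without MIGraphX can guard that import. The sketch below is one possible approach under that assumption; `migraphx_available` and `compile_or_eager` are hypothetical helper names, not part of torch_migraphx.

```python
import importlib.util


def migraphx_available():
    """Return True if both torch and torch_migraphx can be found for import."""
    return all(
        importlib.util.find_spec(name) is not None
        for name in ("torch", "torch_migraphx")
    )


def compile_or_eager(model, backend="migraphx"):
    """Compile `model` with the given Dynamo backend, or return it unchanged.

    Hypothetical helper: the unchanged (eager) model is returned when the
    required packages are not installed.
    """
    if not migraphx_available():
        return model

    import torch
    import torch_migraphx  # noqa: F401  (imported for its backend registration)

    return torch.compile(model, backend=backend)
```

Returning the original model keeps the calling code identical in both environments, at the cost of silently running without MIGraphX; logging the fallback may be preferable in practice.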
import torch
import torchvision
import torch_migraphx
resnet = torchvision.models.resnet50()
sample_input = torch.randn(2, 3, 64, 64)
resnet_mgx = torch_migraphx.fx.lower_to_mgx(resnet, [sample_input])
result = resnet_mgx(sample_input)
import torch
import torchvision
import torch_migraphx
densenet = torchvision.models.densenet161().cuda()
sample_input = torch.randn(2, 3, 512, 512).cuda()
densenet_mgx = torch.compile(densenet, backend="migraphx")
result = densenet_mgx(sample_input)
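After lowering or compiling a model, it can be useful to compare its output with the original PyTorch model on the same input. The following is a minimal sketch, assuming each callable returns a single tensor; `max_abs_difference` is a hypothetical helper, not part of torch_migraphx.

```python
def max_abs_difference(reference, lowered, sample_inputs):
    """Return the largest absolute element-wise difference between the outputs
    of two callables run on the same inputs (hypothetical helper).

    Assumes each callable returns a single tensor.
    """
    import torch  # imported here so the helper can be defined without PyTorch

    with torch.no_grad():
        expected = reference(*sample_inputs)
        actual = lowered(*sample_inputs)

    # Compare on the CPU in float32 so device and dtype differences do not matter.
    diff = actual.detach().cpu().float() - expected.detach().cpu().float()
    return diff.abs().max().item()
```

With the names from the example above, this could be called as `max_abs_difference(densenet, densenet_mgx, [sample_input])`. Calling `.eval()` on the model before lowering and comparing is advisable so that batch normalization and dropout behave as they do at inference time. What difference counts as acceptable depends on the model and precision; no threshold is implied here.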
For more examples, refer to the examples directory.
We welcome contributions! Please read CONTRIBUTING.md for development setup, branch strategy, coding standards, and the pull request process. For bugs and feature requests, open a GitHub Issue.
To report a security vulnerability, do not open a public GitHub issue. See SECURITY.md for our responsible disclosure policy.
For questions, issues, or contributions, please reach out to the maintainers:
- Shivad Bhavsar — @shivadbhavsar · Shivad.Bhavsar@amd.com
Note: For internal or private AMD repositories, maintainers must list their AMD email address. See CODEOWNERS for the full ownership list.
This project is licensed under the BSD 3-Clause License.