diff --git a/test/test_ao_models.py b/test/test_ao_models.py index a658216a7e..e159ed9ea5 100644 --- a/test/test_ao_models.py +++ b/test/test_ao_models.py @@ -9,6 +9,9 @@ from torch.testing._internal import common_utils from torchao._models.llama.model import Transformer +from torchao.utils import get_current_accelerator_device + +_DEVICE = get_current_accelerator_device() def init_model(name="stories15M", device="cpu", precision=torch.bfloat16): @@ -22,7 +25,7 @@ class TorchAOBasicTestCase(unittest.TestCase): """Test suite for basic Transformer inference functionality.""" @common_utils.parametrize( - "device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] + "device", ["cpu", _DEVICE] if torch.accelerator.is_available() else ["cpu"] ) @common_utils.parametrize("batch_size", [1, 4]) @common_utils.parametrize("is_training", [True, False])