From 7ee6a90bf6d73128cfb77c95a52de2be3fb8f7d8 Mon Sep 17 00:00:00 2001
From: zengxian
Date: Wed, 10 Dec 2025 15:09:37 +0800
Subject: [PATCH 1/3] port more cases in test_nf4.py to xpu

---
 test/dtypes/test_nf4.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py
index a42a209a38..356b548346 100644
--- a/test/dtypes/test_nf4.py
+++ b/test/dtypes/test_nf4.py
@@ -262,7 +262,7 @@ def test_smoketest_linear(self, dtype: torch.dtype):
         _ = torch.nn.functional.linear(inp, a)
         _ = torch.nn.functional.linear(inp, a_nf4)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_smoketest_linear_compile(self, dtype: torch.dtype):
         if (
@@ -634,7 +634,7 @@ def world_size(self) -> int:
         return 2

     @pytest.mark.skipif(
-        version.parse(torch.__version__).base_version < "2.4.0",
+        version.parse(torch.__version__) < version.parse("2.4.0"),
         reason="torch >= 2.4 required",
     )
     @skip_if_lt_x_gpu(2)

From 3b1ecf94aed0b48b1a6911b9b629e47748a8bba0 Mon Sep 17 00:00:00 2001
From: zengxian
Date: Wed, 10 Dec 2025 15:53:42 +0800
Subject: [PATCH 2/3] port more cases in test_quant_api.py to xpu

---
 test/quantization/test_quant_api.py           | 22 +++++++++----------
 .../int4/int4_tile_packed_to_4d_tensor.py     |  2 +-
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 5cd81ece90..91a8e0b000 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -643,20 +643,20 @@ def test_module_fqn_to_config_module_name(self):
         assert isinstance(model.linear2.weight, AffineQuantizedTensor)
         assert isinstance(model.linear2.weight._layout, PlainLayout)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_module_fqn_to_config_regex_basic(self):
         config1 = Int4WeightOnlyConfig(
             group_size=32, int4_packing_format="tile_packed_to_4d"
         )
         config = ModuleFqnToConfig({"re:linear.": config1})
-        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
-        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16)
         quantize_(model, config, filter_fn=None)
         model(*example_inputs)
         assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
         assert isinstance(model.linear2.weight, Int4TilePackedTo4dTensor)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_module_fqn_to_config_regex_precedence(self):
         """Testing that full path config takes precedence over regex
         config in ModuleFqnToConfig
@@ -666,14 +666,14 @@ def test_module_fqn_to_config_regex_precedence(self):
         )
         config2 = IntxWeightOnlyConfig()
         config = ModuleFqnToConfig({"linear1": config1, "re:linear.": config2})
-        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
-        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16)
         quantize_(model, config, filter_fn=None)
         model(*example_inputs)
         assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
         assert isinstance(model.linear2.weight, IntxUnpackedToInt8Tensor)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_module_fqn_to_config_regex_precedence2(self):
         """Testing that full path config takes precedence over regex
         config in ModuleFqnToConfig, swapping
@@ -685,14 +685,14 @@ def test_module_fqn_to_config_regex_precedence2(self):
         )
         config2 = IntxWeightOnlyConfig()
         config = ModuleFqnToConfig({"re:linear.": config2, "linear1": config1})
-        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
-        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16)
         quantize_(model, config, filter_fn=None)
         model(*example_inputs)
         assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
         assert isinstance(model.linear2.weight, IntxUnpackedToInt8Tensor)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_module_fqn_to_config_regex_fullmatch(self):
         """Testing that we will only match the fqns that fully
         matches the regex
@@ -731,7 +731,7 @@ def example_inputs(self):
                 "linear3_full_match.bias": None,
             }
         )
-        model = M(dtype=torch.bfloat16, device="cuda")
+        model = M(dtype=torch.bfloat16, device=_DEVICE)
         example_inputs = model.example_inputs()
         quantize_(model, config, filter_fn=None)
         model(*example_inputs)
diff --git a/torchao/quantization/quantize_/workflows/int4/int4_tile_packed_to_4d_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_tile_packed_to_4d_tensor.py
index 8d7291edcb..41c302855c 100644
--- a/torchao/quantization/quantize_/workflows/int4/int4_tile_packed_to_4d_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int4/int4_tile_packed_to_4d_tensor.py
@@ -120,7 +120,7 @@ def from_hp(
         # Validate kernel requirements
         orig_out_features, orig_in_features = hp_tensor.shape[-2:]
         # TODO: relax checks to enable quantizing in other platoforms and run in A100
-        if not torch.cuda.get_device_capability()[0] >= 8:
+        if torch.cuda.is_available() and not torch.cuda.get_device_capability()[0] >= 8:
             raise ValueError(
                 f"Cannot use tinygemm int4 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for tensor core kernels."
             )

From d01032ded650276d16848826aa1ead6776348d4c Mon Sep 17 00:00:00 2001
From: zengxian
Date: Fri, 12 Dec 2025 09:52:42 +0800
Subject: [PATCH 3/3] skip nf4 fsdp2

---
 test/dtypes/test_nf4.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py
index 356b548346..4f1d1d3038 100644
--- a/test/dtypes/test_nf4.py
+++ b/test/dtypes/test_nf4.py
@@ -634,7 +634,7 @@ def world_size(self) -> int:
         return 2

     @pytest.mark.skipif(
-        version.parse(torch.__version__) < version.parse("2.4.0"),
+        version.parse(torch.__version__).base_version < "2.4.0",
         reason="torch >= 2.4 required",
     )
     @skip_if_lt_x_gpu(2)
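
Note: the test_quant_api.py hunks above reference a module-level _DEVICE constant
whose definition lies outside these hunks. A minimal sketch of how such a constant
can be derived with the torch.accelerator API (the exact definition in the test
file may differ):

    import torch

    # Resolve a device string for the current accelerator ("cuda", "xpu", ...)
    # when one is present, falling back to CPU otherwise. The torch.accelerator
    # namespace ships with PyTorch >= 2.6.
    _DEVICE = (
        torch.accelerator.current_accelerator().type
        if torch.accelerator.is_available()
        else "cpu"
    )

With a constant like this, the same test runs unmodified on CUDA and XPU machines,
which is the point of replacing the hard-coded "cuda" strings in these patches.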
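
Note: PATCH 1/3 and PATCH 3/3 toggle the same skipif guard in test_nf4.py.
Version.base_version is a plain string, so "<" compares it lexicographically,
while comparing two parsed Version objects compares release numbers numerically.
A small illustration using packaging (values chosen only for the example):

    from packaging import version

    v = version.parse("2.10.0")
    # Lexicographic string comparison: the strings first differ at '1' vs '4',
    # so "2.10.0" < "2.4.0" evaluates True even though 2.10.0 is the newer release.
    print(v.base_version < "2.4.0")      # True
    # Numeric Version comparison: 2.10.0 > 2.4.0.
    print(v < version.parse("2.4.0"))    # False

PATCH 1/3 switches the guard to the numeric comparison; PATCH 3/3 restores the
string comparison, which makes the skipif fire again on versions like 2.10 and
later and thus skips the NF4 FSDP2 test, consistent with its subject line.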