Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/transformers/quantizers/quantizer_mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,16 @@ def validate_environment(self, *args, **kwargs):
"Using mxfp4 quantization requires torch"
"Please install the latest version of torch ( pip install --upgrade torch )"
)

if self.quantization_config.dequantize:
return

if not torch.cuda.is_available():
raise RuntimeError("Using MXFP4 quantized models requires a GPU")

if not is_accelerate_available():
raise ImportError("Using mxfp4 requires Accelerate: `pip install accelerate`")

if self.quantization_config.dequantize:
return

compute_capability = torch.cuda.get_device_capability()
major, minor = compute_capability

Expand Down
46 changes: 46 additions & 0 deletions tests/quantization/mxfp4/test_mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,52 @@ def test_quantizer_validation_low_compute_capability_with_dequantize(self):
if "compute capability" in str(e):
self.fail("Should not raise compute capability error when dequantize=True")

def test_quantizer_validation_dequantize_on_cpu(self):
    """Validation must succeed on a CUDA-less machine when dequantize is on."""
    with patch("torch.cuda.is_available", return_value=False):
        from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer

        quantizer = Mxfp4HfQuantizer(Mxfp4Config(dequantize=True))

        # With dequantize=True the GPU requirement is skipped, so a
        # "requires a GPU" RuntimeError here is a regression.
        try:
            quantizer.validate_environment()
        except RuntimeError as err:
            if "requires a GPU" in str(err):
                self.fail("Should not raise GPU requirement error when dequantize=True on CPU")

def test_quantizer_validation_order_dequantize_before_cuda_check(self):
    """The dequantize early-exit must run before the CUDA/accelerate checks."""
    # Simulate a machine with neither CUDA nor accelerate available.
    no_cuda = patch("torch.cuda.is_available", return_value=False)
    no_accelerate = patch(
        "transformers.quantizers.quantizer_mxfp4.is_accelerate_available",
        return_value=False,
    )
    with no_cuda, no_accelerate:
        from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer

        # dequantize=True: validation returns early, so neither the GPU nor
        # the accelerate requirement may be enforced.
        quantizer = Mxfp4HfQuantizer(Mxfp4Config(dequantize=True))
        try:
            quantizer.validate_environment()
        except (RuntimeError, ImportError) as exc:
            message = str(exc)
            if "requires a GPU" in message or "requires Accelerate" in message:
                self.fail(f"Should not raise error when dequantize=True: {exc}")

        # dequantize=False: the early exit is not taken, so the missing GPU
        # must still be reported.
        quantizer = Mxfp4HfQuantizer(Mxfp4Config(dequantize=False))
        with self.assertRaises(RuntimeError) as ctx:
            quantizer.validate_environment()
        self.assertIn("requires a GPU", str(ctx.exception))

def test_quantizer_validation_missing_triton(self):
"""Test quantizer validation when triton is not available"""
with (
Expand Down