Commit cb739f8

[core] fix mxfp4 (#42382)
* initial commit
* fix import
* fix
* add ops
* style
* decouple dequantize & deserialize logic
* up
1 parent a95d997 commit cb739f8

File tree

2 files changed, +217 -10 lines changed

src/transformers/integrations/mxfp4.py
src/transformers/quantizers/quantizer_mxfp4.py

src/transformers/integrations/mxfp4.py

Lines changed: 186 additions & 10 deletions
@@ -18,13 +18,19 @@
 if is_torch_available():
     import torch
     from torch import nn
+    from typing import Optional
+
+    from ..core_model_loading import ConversionOps
+
 
 if is_accelerate_available():
     from accelerate import init_empty_weights
 
 import re
 from contextlib import contextmanager
 
+from ..quantizers.quantizers_utils import get_module_from_name
+
 
 logger = logging.get_logger(__name__)
 
@@ -70,6 +76,126 @@ def on_device(dev):
     yield
 
 
+class Mxfp4Quantize(ConversionOps):
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer
+
+    def convert(
+        self,
+        input_dict: dict[str, torch.Tensor],
+        model: Optional[torch.nn.Module] = None,
+        missing_keys: Optional[list[str]] = None,
+        full_layer_name: str | None = None,
+        **kwargs,
+    ) -> dict[str, torch.Tensor]:
+        _, value = tuple(input_dict.items())[0]
+        value = value[0] if isinstance(value, list) else value
+
+        module, _ = get_module_from_name(model, full_layer_name)
+
+        with torch.device(value.device):
+            if isinstance(module, Mxfp4GptOssExperts):
+                triton_weight_tensor, weight_scale = quantize_to_mxfp4(value.transpose(-1, -2), triton_kernels_hub)
+                PrecisionConfig, FlexCtx, InFlexData = (
+                    triton_kernels_hub.matmul_ogs.PrecisionConfig,
+                    triton_kernels_hub.matmul_ogs.FlexCtx,
+                    triton_kernels_hub.matmul_ogs.InFlexData,
+                )
+                triton_weight_tensor, weight_scale = swizzle_mxfp4(
+                    triton_weight_tensor, weight_scale, triton_kernels_hub
+                )
+
+                proj = "gate_up_proj" if "gate_up_proj" in full_layer_name else "down_proj"
+
+                if proj in module._parameters:
+                    # Remove the nn.Parameter registration so we can attach the Triton tensor
+                    del module._parameters[proj]
+
+                setattr(module, proj, triton_weight_tensor)
+                setattr(
+                    module,
+                    f"{proj}_precision_config",
+                    PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())),
+                )
+
+        missing_keys.discard(f"{full_layer_name}")
+        module._is_hf_initialized = True
+
+        return {}
+
+
+class Mxfp4Dequantize(ConversionOps):
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer
+
+    def convert(
+        self,
+        input_dict: dict[str, torch.Tensor],
+        model: Optional[torch.nn.Module] = None,
+        full_layer_name: str | None = None,
+        missing_keys=None,
+        **kwargs,
+    ) -> dict[str, torch.Tensor]:
+        param_data = {}
+        if "_blocks" in input_dict.keys():
+            if isinstance(input_dict["_blocks"], list):
+                param_data["_blocks"] = input_dict["_blocks"][0]
+            else:
+                param_data["_blocks"] = input_dict["_blocks"]
+        if "_scales" in input_dict.keys():
+            if isinstance(input_dict["_scales"], list):
+                param_data["_scales"] = input_dict["_scales"][0]
+            else:
+                param_data["_scales"] = input_dict["_scales"]
+
+        # Here we are dequantizing the weights
+        dequantized = dequantize_convertops(param_data["_blocks"], param_data["_scales"], param_data["_blocks"].device)
+        return {full_layer_name: dequantized}
+
+
+class Mxfp4Deserialize(ConversionOps):
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer
+
+    def convert(
+        self,
+        input_dict: dict[str, torch.Tensor],
+        model: Optional[torch.nn.Module] = None,
+        full_layer_name: str | None = None,
+        missing_keys: Optional[list[str]] = None,
+        **kwargs,
+    ) -> dict[str, torch.Tensor]:
+        param_data = {}
+        if "_blocks" in input_dict.keys():
+            if isinstance(input_dict["_blocks"], list):
+                param_data["_blocks"] = input_dict["_blocks"][0]
+            else:
+                param_data["_blocks"] = input_dict["_blocks"]
+        if "_scales" in input_dict.keys():
+            if isinstance(input_dict["_scales"], list):
+                param_data["_scales"] = input_dict["_scales"][0]
+            else:
+                param_data["_scales"] = input_dict["_scales"]
+
+        # Eagerly set tensors on the module and perform swizzle
+        module, _ = get_module_from_name(model, full_layer_name)
+        proj = "gate_up_proj" if "gate_up_proj" in full_layer_name else "down_proj"
+        swizzle_mxfp4_convertops(
+            param_data["_blocks"],
+            param_data["_scales"],
+            module,
+            proj,
+            param_data["_blocks"].device,
+            triton_kernels_hub,
+        )
+        missing_keys.discard(f"{full_layer_name}")
+        module._is_hf_initialized = True
+        # We return an empty mapping since the module was updated in-place. This prevents
+        # the loader from trying to materialize the original meta-parameter names again.
+        # We don't use set_param_for_module since it expects mainly a torch.nn.Parameter or a safetensors pointer
+        return {}
+
+
 # Copied from GPT_OSS repo and vllm
 def quantize_to_mxfp4(w, triton_kernels_hub):
     downcast_to_mxfp_torch = triton_kernels_hub.numerics_details.mxfp.downcast_to_mxfp_torch
@@ -110,6 +236,7 @@ def convert_moe_packed_tensors(
     """
     import math
 
+    blocks = blocks.to(torch.uint8)
     # Check if blocks and scales are on CPU, and move to GPU if so
     if not blocks.is_cuda and torch.cuda.is_available():
         blocks = blocks.cuda()
@@ -162,26 +289,20 @@ def __init__(self, config):
         self.intermediate_size = config.intermediate_size
         self.hidden_size = config.hidden_size
 
-        self.gate_up_proj_blocks = nn.Parameter(
+        self.gate_up_proj = nn.Parameter(
             torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8),
             requires_grad=False,
         )
-        self.gate_up_proj_scales = nn.Parameter(
-            torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, dtype=torch.uint8),
-            requires_grad=False,
-        )
+
         self.gate_up_proj_bias = nn.Parameter(
             torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False
         )
 
-        self.down_proj_blocks = nn.Parameter(
+        self.down_proj = nn.Parameter(
             torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8),
             requires_grad=False,
         )
-        self.down_proj_scales = nn.Parameter(
-            torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8),
-            requires_grad=False,
-        )
+
         self.down_proj_bias = nn.Parameter(
             torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False
         )
@@ -361,6 +482,14 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, **kwargs):
         delattr(module, scales_attr)
 
 
+def dequantize_convertops(blocks, scales, target_device):
+    dequantized = convert_moe_packed_tensors(blocks, scales)
+    if target_device == "cpu" and torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    dequantized = torch.nn.Parameter(dequantized.to(target_device))
+    return dequantized
+
+
 def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, triton_kernels_hub, **kwargs):
     """
     This transforms the weights obtained using `convert_gpt_oss.py` to load them into `Mxfp4GptOssExperts`.
@@ -428,6 +557,53 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, triton_kernels_hub, **kwargs):
     del blocks
 
 
+def swizzle_mxfp4_convertops(blocks, scales, module, proj, target_device, triton_kernels_hub):
+    """
+    This transforms the weights obtained using `convert_gpt_oss.py` to load them into `Mxfp4GptOssExperts`.
+    """
+    PrecisionConfig, FlexCtx, InFlexData = (
+        triton_kernels_hub.matmul_ogs.PrecisionConfig,
+        triton_kernels_hub.matmul_ogs.FlexCtx,
+        triton_kernels_hub.matmul_ogs.InFlexData,
+    )
+
+    local_experts = blocks.size(0)
+    if getattr(target_device, "type", target_device) == "cpu":
+        target_device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
+
+    blocks = blocks.to(target_device).contiguous()
+    scales = scales.to(target_device).contiguous()
+
+    if proj == "gate_up_proj":
+        blocks = blocks.reshape(local_experts, module.intermediate_size * 2, -1)
+    else:
+        blocks = blocks.reshape(local_experts, -1, module.intermediate_size // 2)
+    if getattr(target_device, "type", target_device) == "cpu":
+        target_device = "cuda"
+
+    with on_device(target_device):
+        triton_weight_tensor, weight_scale = swizzle_mxfp4(
+            blocks.transpose(-2, -1), scales.transpose(-2, -1), triton_kernels_hub
+        )
+        # need to overwrite the shapes for the kernels
+        if proj == "gate_up_proj":
+            triton_weight_tensor.shape = torch.Size([local_experts, module.hidden_size, module.intermediate_size * 2])
+        else:
+            triton_weight_tensor.shape = torch.Size([local_experts, module.intermediate_size, module.hidden_size])
+
+        # triton_weight_tensor is what needs to be passed in oai kernels. It stores the data, the shapes and any more objects. It's like a subtensor
+        # Since the Experts module registers gate_up_proj and down_proj as nn.Parameters, we need to remove them so we can attach the Triton tensor
+        if proj in module._parameters:
+            # Remove the nn.Parameter registration so we can attach the Triton tensor
+            del module._parameters[proj]
+        setattr(module, proj, triton_weight_tensor)
+        setattr(
+            module,
+            f"{proj}_precision_config",
+            PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())),
+        )
+
+
 def _replace_with_mxfp4_linear(
     model,
     modules_to_not_convert=None,
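
For orientation (not part of the diff): the packed tensors these converters consume follow the MXFP4 layout visible in the shapes above. Every group of 32 weights is stored as 16 uint8 bytes in `_blocks` (two 4-bit values per byte) plus one uint8 scale per group in `_scales`. A minimal bookkeeping sketch with made-up config sizes (illustrative only, not taken from a real checkpoint):

import torch

# Hypothetical sizes, for illustration only.
num_experts, hidden_size, intermediate_size = 4, 2880, 2880

# Blocks: 16 packed bytes per group of 32 input features; scales: 1 byte per group.
gate_up_blocks = torch.zeros(num_experts, 2 * intermediate_size, hidden_size // 32, 16, dtype=torch.uint8)
gate_up_scales = torch.zeros(num_experts, 2 * intermediate_size, hidden_size // 32, dtype=torch.uint8)

packed_bytes = gate_up_blocks.numel() + gate_up_scales.numel()
# Assuming bfloat16 (2 bytes per element) after dequantization.
dequantized_bytes = num_experts * 2 * intermediate_size * hidden_size * 2
print(f"packed: {packed_bytes / 2**20:.1f} MiB, dequantized: {dequantized_bytes / 2**20:.1f} MiB")

In these terms, Mxfp4Dequantize expands the `_blocks`/`_scales` pair back into full-size expert weights via convert_moe_packed_tensors, while Mxfp4Deserialize keeps the pair packed and only swizzles it for the Triton kernels.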

src/transformers/quantizers/quantizer_mxfp4.py

Lines changed: 31 additions & 0 deletions
@@ -32,6 +32,8 @@
 if is_torch_available():
     import torch
 
+from ..core_model_loading import WeightConverter
+
 logger = logging.get_logger(__name__)
 triton_kernels_hub = None
 
@@ -157,6 +159,8 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs):
         from ..integrations import Mxfp4GptOssExperts
         from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts
 
+        if self.pre_quantized:
+            return False
         # if we are dequantizing, the model doesn't have scales, and blocks only params like gate_up_proj and down_proj so we need to handle this case differently
         if self.quantization_config.dequantize and ("blocks" in param_name or "scales" in param_name):
            module, tensor_name = get_module_from_name(model, param_name[: -len("_blocks")])
@@ -426,3 +430,30 @@ def is_trainable(self) -> bool:
             "MXFP4 quantization don't support training, please consider dequantizing the model first by passing quantization_config=Mxfp4Config(dequantize=True) to .from_pretrained()"
         )
         return False
+
+    def get_quantize_ops(self):
+        from ..integrations.mxfp4 import Mxfp4Quantize
+
+        return Mxfp4Quantize(self)
+
+    def get_weight_conversions(self):
+        from ..integrations.mxfp4 import Mxfp4Dequantize, Mxfp4Deserialize
+
+        if self.pre_quantized:
+            if self.quantization_config.dequantize:
+                return [
+                    WeightConverter(
+                        source_keys=["_blocks", "_scales"],
+                        target_keys="",
+                        operations=[Mxfp4Dequantize(self)],
+                    )
+                ]
+            else:
+                return [
+                    WeightConverter(
+                        source_keys=["_blocks", "_scales"],
+                        target_keys="",
+                        operations=[Mxfp4Deserialize(self)],
+                    )
+                ]
+        return []
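
As a usage sketch (not part of this diff; it assumes the openai/gpt-oss-20b MXFP4 checkpoint, a CUDA GPU, and the Triton kernels hub installed), the two branches above correspond to the two ways a pre-quantized checkpoint can be loaded:

from transformers import AutoModelForCausalLM, Mxfp4Config

# Keep the weights packed: the _blocks/_scales pairs go through Mxfp4Deserialize,
# which swizzles them in place for the Triton matmul_ogs kernels.
model = AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b", device_map="auto")

# Dequantize on load: the same pairs go through Mxfp4Dequantize and the experts
# are materialized as regular higher-precision parameters.
model_fp = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    quantization_config=Mxfp4Config(dequantize=True),
    device_map="auto",
)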
