|
18 | 18 | if is_torch_available(): |
19 | 19 | import torch |
20 | 20 | from torch import nn |
| 21 | +from typing import Optional |
| 22 | + |
| 23 | +from ..core_model_loading import ConversionOps |
| 24 | + |
21 | 25 |
|
22 | 26 | if is_accelerate_available(): |
23 | 27 | from accelerate import init_empty_weights |
24 | 28 |
|
25 | 29 | import re |
26 | 30 | from contextlib import contextmanager |
27 | 31 |
|
| 32 | +from ..quantizers.quantizers_utils import get_module_from_name |
| 33 | + |
28 | 34 |
|
29 | 35 | logger = logging.get_logger(__name__) |
30 | 36 |
|
@@ -70,6 +76,126 @@ def on_device(dev): |
70 | 76 | yield |
71 | 77 |
|
72 | 78 |
|
| 79 | +class Mxfp4Quantize(ConversionOps): |
| 80 | + def __init__(self, hf_quantizer): |
| 81 | + self.hf_quantizer = hf_quantizer |
| 82 | + |
| 83 | + def convert( |
| 84 | + self, |
| 85 | + input_dict: dict[str, torch.Tensor], |
| 86 | + model: Optional[torch.nn.Module] = None, |
| 87 | + missing_keys: Optional[set[str]] = None, |
| 88 | + full_layer_name: Optional[str] = None, |
| 89 | + **kwargs, |
| 90 | + ) -> dict[str, torch.Tensor]: |
| 91 | + _, value = tuple(input_dict.items())[0] |
| 92 | + value = value[0] if isinstance(value, list) else value |
| 93 | + |
| 94 | + module, _ = get_module_from_name(model, full_layer_name) |
| 95 | + |
| 96 | + with torch.device(value.device): |
| 97 | + if isinstance(module, Mxfp4GptOssExperts): |
| 98 | + triton_weight_tensor, weight_scale = quantize_to_mxfp4(value.transpose(-1, -2), triton_kernels_hub) |
| 99 | + PrecisionConfig, FlexCtx, InFlexData = ( |
| 100 | + triton_kernels_hub.matmul_ogs.PrecisionConfig, |
| 101 | + triton_kernels_hub.matmul_ogs.FlexCtx, |
| 102 | + triton_kernels_hub.matmul_ogs.InFlexData, |
| 103 | + ) |
| 104 | + triton_weight_tensor, weight_scale = swizzle_mxfp4( |
| 105 | + triton_weight_tensor, weight_scale, triton_kernels_hub |
| 106 | + ) |
| 107 | + |
| 108 | + proj = "gate_up_proj" if "gate_up_proj" in full_layer_name else "down_proj" |
| 109 | + |
| 110 | + if proj in module._parameters: |
| 111 | + # Remove the nn.Parameter registration so we can attach the Triton tensor |
| 112 | + del module._parameters[proj] |
| 113 | + |
| 114 | + setattr(module, proj, triton_weight_tensor) |
| 115 | + setattr( |
| 116 | + module, |
| 117 | + f"{proj}_precision_config", |
| 118 | + PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())), |
| 119 | + ) |
| 120 | + |
| 121 | + missing_keys.discard(full_layer_name) |
| 122 | + module._is_hf_initialized = True |
| 123 | + |
| 124 | + return {} |
| 125 | + |
| 126 | + |
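A rough usage sketch for the quantize op (the checkpoint key, `quantizer`, `model`, `missing_keys` and `bf16_weight` names below are illustrative assumptions, not the actual loader wiring): `convert` receives a one-entry mapping from checkpoint key to tensor plus the runtime parameter name, quantizes and swizzles the weight onto the module in place, and returns an empty dict so the loader has nothing left to assign for that key.

```python
op = Mxfp4Quantize(hf_quantizer=quantizer)  # `quantizer` assumed to be the active MXFP4 HF quantizer
out = op.convert(
    {"model.layers.0.mlp.experts.gate_up_proj": [bf16_weight]},  # hypothetical key -> bf16 tensor
    model=model,
    missing_keys=missing_keys,  # set of parameter names not yet materialized
    full_layer_name="model.layers.0.mlp.experts.gate_up_proj",
)
assert out == {}  # the Triton tensor and its precision config now live on the experts module
```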
| 127 | +class Mxfp4Dequantize(ConversionOps): |
| 128 | + def __init__(self, hf_quantizer): |
| 129 | + self.hf_quantizer = hf_quantizer |
| 130 | + |
| 131 | + def convert( |
| 132 | + self, |
| 133 | + input_dict: dict[str, torch.Tensor], |
| 134 | + model: Optional[torch.nn.Module] = None, |
| 135 | + full_layer_name: Optional[str] = None, |
| 136 | + missing_keys: Optional[set[str]] = None, |
| 137 | + **kwargs, |
| 138 | + ) -> dict[str, torch.Tensor]: |
| 139 | + param_data = {} |
| 140 | + if "_blocks" in input_dict.keys(): |
| 141 | + if isinstance(input_dict["_blocks"], list): |
| 142 | + param_data["_blocks"] = input_dict["_blocks"][0] |
| 143 | + else: |
| 144 | + param_data["_blocks"] = input_dict["_blocks"] |
| 145 | + if "_scales" in input_dict.keys(): |
| 146 | + if isinstance(input_dict["_scales"], list): |
| 147 | + param_data["_scales"] = input_dict["_scales"][0] |
| 148 | + else: |
| 149 | + param_data["_scales"] = input_dict["_scales"] |
| 150 | + |
| 151 | + # Dequantize the packed MXFP4 blocks back into a dense weight using the per-block scales |
| 152 | + dequantized = dequantize_convertops(param_data["_blocks"], param_data["_scales"], param_data["_blocks"].device) |
| 153 | + return {full_layer_name: dequantized} |
| 154 | + |
| 155 | + |
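Unlike the quantize path, the dequantize op hands a dense tensor back to the loader under the runtime parameter name. A minimal sketch, assuming packed `blocks_u8`/`scales_u8` tensors and an illustrative target name:

```python
op = Mxfp4Dequantize(hf_quantizer=quantizer)
out = op.convert(
    {"_blocks": blocks_u8, "_scales": scales_u8},  # packed uint8 payload and per-block scales
    model=model,
    full_layer_name="model.layers.0.mlp.experts.gate_up_proj",  # hypothetical target name
)
dense = out["model.layers.0.mlp.experts.gate_up_proj"]  # a plain nn.Parameter the loader assigns normally
```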
| 156 | +class Mxfp4Deserialize(ConversionOps): |
| 157 | + def __init__(self, hf_quantizer): |
| 158 | + self.hf_quantizer = hf_quantizer |
| 159 | + |
| 160 | + def convert( |
| 161 | + self, |
| 162 | + input_dict: dict[str, torch.Tensor], |
| 163 | + model: Optional[torch.nn.Module] = None, |
| 164 | + full_layer_name: Optional[str] = None, |
| 165 | + missing_keys: Optional[set[str]] = None, |
| 166 | + **kwargs, |
| 167 | + ) -> dict[str, torch.Tensor]: |
| 168 | + param_data = {} |
| 169 | + if "_blocks" in input_dict.keys(): |
| 170 | + if isinstance(input_dict["_blocks"], list): |
| 171 | + param_data["_blocks"] = input_dict["_blocks"][0] |
| 172 | + else: |
| 173 | + param_data["_blocks"] = input_dict["_blocks"] |
| 174 | + if "_scales" in input_dict.keys(): |
| 175 | + if isinstance(input_dict["_scales"], list): |
| 176 | + param_data["_scales"] = input_dict["_scales"][0] |
| 177 | + else: |
| 178 | + param_data["_scales"] = input_dict["_scales"] |
| 179 | + |
| 180 | + # Eagerly attach the swizzled tensors to the module instead of returning them to the loader |
| 181 | + module, _ = get_module_from_name(model, full_layer_name) |
| 182 | + proj = "gate_up_proj" if "gate_up_proj" in full_layer_name else "down_proj" |
| 183 | + swizzle_mxfp4_convertops( |
| 184 | + param_data["_blocks"], |
| 185 | + param_data["_scales"], |
| 186 | + module, |
| 187 | + proj, |
| 188 | + param_data["_blocks"].device, |
| 189 | + triton_kernels_hub, |
| 190 | + ) |
| 191 | + missing_keys.discard(full_layer_name) |
| 192 | + module._is_hf_initialized = True |
| 193 | + # We return an empty mapping since the module was updated in-place. This prevents |
| 194 | + # the loader from trying to materialize the original meta-parameter names again. |
| 195 | + # We don't use set_param_for_module here since it expects a torch.nn.Parameter or a safetensors pointer, not a Triton tensor |
| 196 | + return {} |
| 197 | + |
| 198 | + |
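Taken together, the three ops cover the three loading paths: `Mxfp4Quantize` quantizes a bf16 checkpoint on the fly, `Mxfp4Dequantize` expands a pre-quantized checkpoint back to dense weights, and `Mxfp4Deserialize` keeps the packed MXFP4 payload and only swizzles it for the Triton kernels. A hypothetical dispatch helper, purely for illustration (the selection flags are assumptions, not the actual quantizer API):

```python
def pick_conversion_op(hf_quantizer, checkpoint_is_quantized: bool, dequantize: bool):
    # Illustrative only: the real selection happens inside the quantizer/loader.
    if not checkpoint_is_quantized:
        return Mxfp4Quantize(hf_quantizer)    # bf16 checkpoint -> quantize while loading
    if dequantize:
        return Mxfp4Dequantize(hf_quantizer)  # packed checkpoint -> dense weights
    return Mxfp4Deserialize(hf_quantizer)     # packed checkpoint -> swizzled Triton tensors
```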
73 | 199 | # Copied from GPT_OSS repo and vllm |
74 | 200 | def quantize_to_mxfp4(w, triton_kernels_hub): |
75 | 201 | downcast_to_mxfp_torch = triton_kernels_hub.numerics_details.mxfp.downcast_to_mxfp_torch |
@@ -110,6 +236,7 @@ def convert_moe_packed_tensors( |
110 | 236 | """ |
111 | 237 | import math |
112 | 238 |
|
| 239 | + blocks = blocks.to(torch.uint8) |
113 | 240 | # Check if blocks and scales are on CPU, and move to GPU if so |
114 | 241 | if not blocks.is_cuda and torch.cuda.is_available(): |
115 | 242 | blocks = blocks.cuda() |
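The unpacking performed by `convert_moe_packed_tensors` follows the MXFP4 microscaling layout: each uint8 in `blocks` carries two 4-bit E2M1 values, and each uint8 in `scales` is a biased power-of-two (E8M0) exponent shared by a block of 32 values. A self-contained sketch of that arithmetic (the lookup table and bias come from the OCP MX spec; the nibble ordering is illustrative, not lifted from this file):

```python
import torch

# All 16 E2M1 code points: 0, 0.5, 1, 1.5, 2, 3, 4, 6 and their negatives.
FP4_LUT = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                        -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0])

def unpack_byte(byte: int, scale_u8: int) -> tuple[float, float]:
    lo, hi = byte & 0x0F, byte >> 4   # two FP4 codes packed into one byte
    scale = 2.0 ** (scale_u8 - 127)   # E8M0 scale: a pure power of two
    return FP4_LUT[lo].item() * scale, FP4_LUT[hi].item() * scale

# 0x42 -> low nibble 0x2 (1.0), high nibble 0x4 (2.0); scale byte 126 -> 2**-1
print(unpack_byte(0x42, 126))  # (0.5, 1.0)
```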
@@ -162,26 +289,20 @@ def __init__(self, config): |
162 | 289 | self.intermediate_size = config.intermediate_size |
163 | 290 | self.hidden_size = config.hidden_size |
164 | 291 |
|
165 | | - self.gate_up_proj_blocks = nn.Parameter( |
| 292 | + self.gate_up_proj = nn.Parameter( |
166 | 293 | torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8), |
167 | 294 | requires_grad=False, |
168 | 295 | ) |
169 | | - self.gate_up_proj_scales = nn.Parameter( |
170 | | - torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, dtype=torch.uint8), |
171 | | - requires_grad=False, |
172 | | - ) |
| 296 | + |
173 | 297 | self.gate_up_proj_bias = nn.Parameter( |
174 | 298 | torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False |
175 | 299 | ) |
176 | 300 |
|
177 | | - self.down_proj_blocks = nn.Parameter( |
| 301 | + self.down_proj = nn.Parameter( |
178 | 302 | torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8), |
179 | 303 | requires_grad=False, |
180 | 304 | ) |
181 | | - self.down_proj_scales = nn.Parameter( |
182 | | - torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8), |
183 | | - requires_grad=False, |
184 | | - ) |
| 305 | + |
185 | 306 | self.down_proj_bias = nn.Parameter( |
186 | 307 | torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False |
187 | 308 | ) |
@@ -361,6 +482,14 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, ** |
361 | 482 | delattr(module, scales_attr) |
362 | 483 |
|
363 | 484 |
|
| 485 | +def dequantize_convertops(blocks, scales, target_device): |
| 486 | + dequantized = convert_moe_packed_tensors(blocks, scales) |
| 487 | + if target_device == "cpu" and torch.cuda.is_available(): |
| 488 | + torch.cuda.empty_cache() |
| 489 | + dequantized = torch.nn.Parameter(dequantized.to(target_device)) |
| 490 | + return dequantized |
| 491 | + |
| 492 | + |
364 | 493 | def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, triton_kernels_hub, **kwargs): |
365 | 494 | """ |
366 | 495 | This transforms the weights obtained using `convert_gpt_oss.py` to load them into `Mxfp4GptOssExperts`. |
@@ -428,6 +557,53 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, trito |
428 | 557 | del blocks |
429 | 558 |
|
430 | 559 |
|
| 560 | +def swizzle_mxfp4_convertops(blocks, scales, module, proj, target_device, triton_kernels_hub): |
| 561 | + """ |
| 562 | + This transforms the weights obtained using `convert_gpt_oss.py` to load them into `Mxfp4GptOssExperts`. |
| 563 | + """ |
| 564 | + PrecisionConfig, FlexCtx, InFlexData = ( |
| 565 | + triton_kernels_hub.matmul_ogs.PrecisionConfig, |
| 566 | + triton_kernels_hub.matmul_ogs.FlexCtx, |
| 567 | + triton_kernels_hub.matmul_ogs.InFlexData, |
| 568 | + ) |
| 569 | + |
| 570 | + local_experts = blocks.size(0) |
| 571 | + if getattr(target_device, "type", target_device) == "cpu": |
| 572 | + target_device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda" |
| 573 | + |
| 574 | + blocks = blocks.to(target_device).contiguous() |
| 575 | + scales = scales.to(target_device).contiguous() |
| 576 | + |
| 577 | + if proj == "gate_up_proj": |
| 578 | + blocks = blocks.reshape(local_experts, module.intermediate_size * 2, -1) |
| 579 | + else: |
| 580 | + blocks = blocks.reshape(local_experts, -1, module.intermediate_size // 2) |
| 583 | + |
| 584 | + with on_device(target_device): |
| 585 | + triton_weight_tensor, weight_scale = swizzle_mxfp4( |
| 586 | + blocks.transpose(-2, -1), scales.transpose(-2, -1), triton_kernels_hub |
| 587 | + ) |
| 588 | + # need to overwrite the shapes for the kernels |
| 589 | + if proj == "gate_up_proj": |
| 590 | + triton_weight_tensor.shape = torch.Size([local_experts, module.hidden_size, module.intermediate_size * 2]) |
| 591 | + else: |
| 592 | + triton_weight_tensor.shape = torch.Size([local_experts, module.intermediate_size, module.hidden_size]) |
| 593 | + |
| 594 | + # triton_weight_tensor is what gets passed to the OAI Triton kernels: it wraps the data, the shapes and the layout metadata, much like a tensor subclass |
| 595 | + # Since the Experts module registers gate_up_proj and down_proj as nn.Parameters, we need to remove them so we can attach the Triton tensor |
| 596 | + if proj in module._parameters: |
| 597 | + # Remove the nn.Parameter registration so we can attach the Triton tensor |
| 598 | + del module._parameters[proj] |
| 599 | + setattr(module, proj, triton_weight_tensor) |
| 600 | + setattr( |
| 601 | + module, |
| 602 | + f"{proj}_precision_config", |
| 603 | + PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())), |
| 604 | + ) |
| 605 | + |
| 606 | + |
431 | 607 | def _replace_with_mxfp4_linear( |
432 | 608 | model, |
433 | 609 | modules_to_not_convert=None, |
|