Commit fbfba6c

wasertech authored and ita.zaporozhets@huggingface.co committed
fix: Initialize ApertusMLP's xielu activation using torch_dtype (#42864)
* Fix Apertus model crash on float16 hardware: initialize the XIELU activation with the correct dtype from the config (using config.dtype instead of the default bfloat16) to prevent promotion to float32 and subsequent crashes on Turing/float16 GPUs.
* refactor: Move `ACT2CLS` import to top-level in Apertus models.
1 parent 2f81d58 commit fbfba6c
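
The crash path described above follows from PyTorch's type-promotion rules: a float16 tensor combined with a bfloat16 parameter is promoted to float32. Below is a minimal sketch with a toy module (not the actual XIELU implementation) showing why passing the model's dtype at construction keeps the activation in float16:

import torch
from torch import nn


class ToyActivation(nn.Module):
    # Toy stand-in for a learnable activation whose parameters carry their own dtype.
    def __init__(self, dtype=torch.bfloat16):
        super().__init__()
        # Learnable scale stored in whatever dtype the caller requests.
        self.alpha = nn.Parameter(torch.tensor(0.8, dtype=dtype))

    def forward(self, x):
        # A float16 input multiplied by a bfloat16 parameter is promoted to float32,
        # which is the promotion the commit avoids by passing config.dtype at construction.
        return x * self.alpha


x = torch.randn(4, 8, dtype=torch.float16)
print(ToyActivation()(x).dtype)                     # torch.float32 (promoted)
print(ToyActivation(dtype=torch.float16)(x).dtype)  # torch.float16 (matches the model)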

File tree

2 files changed: +7 -2 lines changed

src/transformers/models/apertus/modeling_apertus.py

Lines changed: 3 additions & 1 deletion
@@ -25,7 +25,7 @@
 import torch
 from torch import nn
 
-from ...activations import ACT2FN
+from ...activations import ACT2CLS, ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
@@ -49,6 +49,8 @@ def __init__(self, config):
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]
+        if config.hidden_act == "xielu":
+            self.act_fn = ACT2CLS["xielu"](dtype=config.dtype)
 
     def forward(self, x):
         return self.down_proj(self.act_fn(self.up_proj(x)))
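
For context on the import change: in the transformers activations module, indexing ACT2FN yields an activation built with its default arguments, while ACT2CLS maps the same names to the underlying classes, which is what allows passing a dtype here. A simplified sketch of the difference (the XIELU constructor takes more arguments than shown; per the commit message its default dtype is bfloat16):

import torch
from transformers.activations import ACT2CLS, ACT2FN

act_default = ACT2FN["xielu"]                     # built with default kwargs (bfloat16 per the commit message)
act_fp16 = ACT2CLS["xielu"](dtype=torch.float16)  # same class, constructed with the model's dtype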

src/transformers/models/apertus/modular_apertus.py

Lines changed: 4 additions & 1 deletion
@@ -19,6 +19,7 @@
 import torch
 from torch import nn
 
+from ...activations import ACT2CLS
 from ...cache_utils import Cache
 from ...configuration_utils import PreTrainedConfig
 from ...modeling_rope_utils import RopeParameters
@@ -192,9 +193,11 @@ def __init__(
 
 class ApertusMLP(NemotronMLP):
     def __init__(self, config):
-        super().__init__()
+        super().__init__(config)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        if config.hidden_act == "xielu":
+            self.act_fn = ACT2CLS["xielu"](dtype=config.dtype)
 
 
 class ApertusRMSNorm(LlamaRMSNorm):
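
As a hedged usage sketch of the scenario the fix targets (loading Apertus in float16 on float16-only hardware), the checkpoint name below is illustrative and not taken from this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "swiss-ai/Apertus-8B-Instruct-2509"  # assumed checkpoint name, for illustration only

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # the dtype that now reaches the xielu activation via the config
    device_map="auto",
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output[0], skip_special_tokens=True))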
