Skip to content

Commit 26667d3

Browse files
committed
Fix Apertus model crash on float16 hardware
Initialize XIELU activation with correct dtype from config (using config.dtype instead of default bfloat16) to prevent promotion to float32 and subsequent crashes on Turing/float16 GPUs.
1 parent 40dc11c commit 26667d3

File tree

3 files changed

+12
-1
lines changed

3 files changed

+12
-1
lines changed

src/transformers/models/apertus/modeling_apertus.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ def __init__(self, config):
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]
+        if config.hidden_act == "xielu":
+            from ...activations import ACT2CLS
+
+            self.act_fn = ACT2CLS["xielu"](dtype=config.dtype)

     def forward(self, x):
         return self.down_proj(self.act_fn(self.up_proj(x)))

src/transformers/models/apertus/modular_apertus.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,13 @@ def __init__(

 class ApertusMLP(NemotronMLP):
     def __init__(self, config):
-        super().__init__()
+        super().__init__(config)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        if config.hidden_act == "xielu":
+            from ...activations import ACT2CLS
+
+            self.act_fn = ACT2CLS["xielu"](dtype=config.dtype)


 class ApertusRMSNorm(LlamaRMSNorm):

tests/models/apertus/test_modeling_apertus.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,6 @@ class ApertusModelTest(CausalLMModelTest, unittest.TestCase):
 @slow
 class ApertusIntegrationTest(unittest.TestCase):
     pass
+
+
+

0 commit comments

Comments (0)