Skip to content

Commit 03708cc

Browse files
bzantium and json.bourne authored
add DeepseekV3ForTokenClassification (#40641)
* add DeepseekV3ForTokenClassification
* fix typo

---------

Co-authored-by: json.bourne <json.bourne@kakaocorp.com>
1 parent c485c52 commit 03708cc

File tree

5 files changed

+25
-2
lines changed

5 files changed

+25
-2
lines changed

docs/source/en/model_doc/deepseek_v3.md

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -188,3 +188,8 @@ error, it means NCCL was probably not loaded.
188188

189189
[[autodoc]] DeepseekV3ForSequenceClassification
190190
- forward
191+
192+
## DeepseekV3ForTokenClassification
193+
194+
[[autodoc]] DeepseekV3ForTokenClassification
195+
- forward

src/transformers/models/auto/modeling_auto.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1412,6 +1412,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
14121412
("data2vec-text", "Data2VecTextForTokenClassification"),
14131413
("deberta", "DebertaForTokenClassification"),
14141414
("deberta-v2", "DebertaV2ForTokenClassification"),
1415+
("deepseek_v3", "DeepseekV3ForTokenClassification"),
14151416
("diffllama", "DiffLlamaForTokenClassification"),
14161417
("distilbert", "DistilBertForTokenClassification"),
14171418
("electra", "ElectraForTokenClassification"),

src/transformers/models/deepseek_v3/modeling_deepseek_v3.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,11 @@
1717
from ...integrations import use_kernel_forward_from_hub
1818
from ...masking_utils import create_causal_mask
1919
from ...modeling_flash_attention_utils import FlashAttentionKwargs
20-
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
20+
from ...modeling_layers import (
21+
GenericForSequenceClassification,
22+
GenericForTokenClassification,
23+
GradientCheckpointingLayer,
24+
)
2125
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
2226
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
2327
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -676,9 +680,14 @@ class DeepseekV3ForSequenceClassification(GenericForSequenceClassification, Deep
676680
pass
677681

678682

683+
class DeepseekV3ForTokenClassification(GenericForTokenClassification, DeepseekV3PreTrainedModel):
684+
pass
685+
686+
679687
__all__ = [
680688
"DeepseekV3PreTrainedModel",
681689
"DeepseekV3Model",
682690
"DeepseekV3ForCausalLM",
683691
"DeepseekV3ForSequenceClassification",
692+
"DeepseekV3ForTokenClassification",
684693
]

src/transformers/models/deepseek_v3/modular_deepseek_v3.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
from ...activations import ACT2FN
1010
from ...cache_utils import Cache
1111
from ...modeling_flash_attention_utils import FlashAttentionKwargs
12-
from ...modeling_layers import GenericForSequenceClassification
12+
from ...modeling_layers import GenericForSequenceClassification, GenericForTokenClassification
1313
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
1414
from ...processing_utils import Unpack
1515
from ...utils import logging
@@ -361,9 +361,14 @@ class DeepseekV3ForSequenceClassification(GenericForSequenceClassification, Deep
361361
pass
362362

363363

364+
class DeepseekV3ForTokenClassification(GenericForTokenClassification, DeepseekV3PreTrainedModel):
365+
pass
366+
367+
364368
__all__ = [
365369
"DeepseekV3PreTrainedModel",
366370
"DeepseekV3Model",
367371
"DeepseekV3ForCausalLM",
368372
"DeepseekV3ForSequenceClassification",
373+
"DeepseekV3ForTokenClassification",
369374
]

tests/models/deepseek_v3/test_modeling_deepseek_v3.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@
4343
from transformers import (
4444
DeepseekV3ForCausalLM,
4545
DeepseekV3ForSequenceClassification,
46+
DeepseekV3ForTokenClassification,
4647
DeepseekV3Model,
4748
)
4849
from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
@@ -217,6 +218,7 @@ class DeepseekV3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
217218
DeepseekV3Model,
218219
DeepseekV3ForCausalLM,
219220
DeepseekV3ForSequenceClassification,
221+
DeepseekV3ForTokenClassification,
220222
)
221223
if is_torch_available()
222224
else ()
@@ -226,6 +228,7 @@ class DeepseekV3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
226228
{
227229
"feature-extraction": DeepseekV3Model,
228230
"text-classification": DeepseekV3ForSequenceClassification,
231+
"token-classification": DeepseekV3ForTokenClassification,
229232
"text-generation": DeepseekV3ForCausalLM,
230233
"zero-shot": DeepseekV3ForSequenceClassification,
231234
}

0 commit comments

Comments
 (0)