
Commit 6076ec8
add deepseek_v3 model
1 parent 89580df

File tree

10 files changed: +1298 −1 lines changed

mindone/transformers/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -691,6 +691,7 @@
 from .utils import logging

 if version.parse(transformers.__version__) >= version.parse("4.51.0"):
+    from .models.deepseek_v3 import DeepseekV3ForCausalLM, DeepseekV3Model, DeepseekV3PreTrainedModel
     from .models.qwen3 import Qwen3ForCausalLM, Qwen3Model, Qwen3PreTrainedModel

 if version.parse(transformers.__version__) >= version.parse("4.51.3"):
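With this export in place, the new classes can be imported from the top-level package. A minimal usage sketch (requires transformers >= 4.51.0; the checkpoint path below is a placeholder, and exact loading kwargs may differ):

```python
from mindone.transformers import DeepseekV3ForCausalLM

# Placeholder path: substitute a real DeepSeek-V3 checkpoint directory.
model = DeepseekV3ForCausalLM.from_pretrained("/path/to/deepseek-v3-checkpoint")
```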
mindone/transformers/modeling_layers.py
Lines changed: 282 additions & 0 deletions

@@ -0,0 +1,282 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from functools import partial
from typing import Optional

from transformers.utils import logging

import mindspore as ms
import mindspore.nn as nn
from mindspore import mint

from .cache_utils import Cache
from .modeling_outputs import (
    BaseModelOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from .models.auto import AutoModel
from .processing_utils import Unpack
from .utils import TransformersKwargs

logger = logging.get_logger(__name__)


class GradientCheckpointingLayer(nn.Cell):
    """Base class for layers with gradient checkpointing.

    This class enables gradient checkpointing functionality for a layer. By default, gradient checkpointing is
    disabled (`gradient_checkpointing = False`). When `model.set_gradient_checkpointing()` is called, gradient
    checkpointing is enabled by setting `gradient_checkpointing = True` and assigning a checkpointing function to
    `_gradient_checkpointing_func`.

    Important:

        When using gradient checkpointing with `use_reentrant=True`, inputs that require gradients (e.g. hidden
        states) must be passed as positional arguments (`*args`) rather than keyword arguments to properly
        propagate gradients.

        Example:

            ```python
            >>> # Correct - hidden_states passed as positional arg
            >>> out = self.layer(hidden_states, attention_mask=attention_mask)

            >>> # Incorrect - hidden_states passed as keyword arg
            >>> out = self.layer(hidden_states=hidden_states, attention_mask=attention_mask)
            ```
    """

    gradient_checkpointing = False

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            do_warn = False
            layer_name = self.__class__.__name__
            message = f"Caching is incompatible with gradient checkpointing in {layer_name}. Setting"

            if "use_cache" in kwargs and kwargs["use_cache"]:
                kwargs["use_cache"] = False
                message += " `use_cache=False`,"
                do_warn = True

            # different names for the same thing in different layers
            if "past_key_value" in kwargs and kwargs["past_key_value"] is not None:
                kwargs["past_key_value"] = None
                message += " `past_key_value=None`,"
                do_warn = True

            if "past_key_values" in kwargs and kwargs["past_key_values"] is not None:
                kwargs["past_key_values"] = None
                message += " `past_key_values=None`,"
                do_warn = True

            if "layer_past" in kwargs and kwargs["layer_past"] is not None:
                kwargs["layer_past"] = None
                message += " `layer_past=None`,"
                do_warn = True

            # warn if anything was changed
            if do_warn:
                message = message.rstrip(",") + "."
                logger.warning(message)

            return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
        return super().__call__(*args, **kwargs)
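A quick sketch of the call path above. Everything here is illustrative: `DummyLayer` is hypothetical, and the passthrough lambda stands in for the real checkpoint function that `model.set_gradient_checkpointing()` would assign:

```python
import mindspore as ms
from mindspore import mint

class DummyLayer(GradientCheckpointingLayer):
    def __init__(self):
        super().__init__()
        self.proj = mint.nn.Linear(4, 4)

    def construct(self, hidden_states, use_cache=False):
        return self.proj(hidden_states)

layer = DummyLayer()
layer.set_train(True)  # the checkpointing branch only runs in training mode
layer.gradient_checkpointing = True
layer._gradient_checkpointing_func = lambda fn, *args: fn(*args)  # passthrough stand-in

# hidden_states is positional; use_cache=True is flipped to False with a warning.
out = layer(ms.Tensor([[1.0, 2.0, 3.0, 4.0]], ms.float32), use_cache=True)
```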
class GenericForSequenceClassification(ABC):
    base_model_prefix = "model"

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class
        setattr(self, self.base_model_prefix, AutoModel.from_config(config))
        self.score = mint.nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def construct(
        self,
        input_ids: Optional[ms.Tensor] = None,
        attention_mask: Optional[ms.Tensor] = None,
        position_ids: Optional[ms.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[ms.Tensor] = None,
        labels: Optional[ms.Tensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> SequenceClassifierOutputWithPast:
        transformer_outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = transformer_outputs.last_hidden_state
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, ms.int32)
            token_indices = mint.arange(input_ids.shape[-1], dtype=ms.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[mint.arange(batch_size), last_non_pad_token]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
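The pooling trick above deserves a standalone illustration: multiplying token indices by the non-pad mask and taking `argmax` selects the rightmost non-pad position, which handles both left- and right-padded batches. A sketch with made-up values (the class calls `.to(logits.device, ms.int32)`; `.astype` is used here to keep the snippet self-contained):

```python
import mindspore as ms
from mindspore import mint

pad_token_id = 0
input_ids = ms.Tensor([[5, 7, 0, 0],    # right-padded -> last real token at index 1
                       [0, 0, 9, 3]])   # left-padded  -> last real token at index 3
non_pad_mask = (input_ids != pad_token_id).astype(ms.int32)
token_indices = mint.arange(input_ids.shape[-1], dtype=ms.int32)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)  # [1 3]
```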
class GenericForQuestionAnswering(ABC):
    base_model_prefix = "model"

    def __init__(self, config):
        super().__init__(config)
        # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class
        setattr(self, self.base_model_prefix, AutoModel.from_config(config))
        self.qa_outputs = mint.nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return getattr(self, self.base_model_prefix).embed_tokens

    def set_input_embeddings(self, value):
        getattr(self, self.base_model_prefix).embed_tokens = value

    def construct(
        self,
        input_ids: Optional[ms.Tensor] = None,
        attention_mask: Optional[ms.Tensor] = None,
        position_ids: Optional[ms.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[ms.Tensor] = None,
        start_positions: Optional[ms.Tensor] = None,
        end_positions: Optional[ms.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> QuestionAnsweringModelOutput:
        outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        sequence_output = outputs.last_hidden_state

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        loss = None
        if start_positions is not None and end_positions is not None:
            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)

        return QuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
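The QA head emits two channels per token; splitting on the last axis yields separate start and end score vectors. A shape-only sketch (using `mint.split`, a torch-style split matching the `Tensor.split(1, dim=-1)` call above; values are illustrative):

```python
import mindspore as ms
from mindspore import mint

logits = ms.Tensor([[[0.1, 0.9], [0.8, 0.2]]])            # (batch=1, seq=2, channels=2)
start_logits, end_logits = mint.split(logits, 1, dim=-1)  # two (1, 2, 1) tensors
start_logits = start_logits.squeeze(-1)                   # (1, 2): start score per token
end_logits = end_logits.squeeze(-1)                       # (1, 2): end score per token
```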
class GenericForTokenClassification(ABC):
    base_model_prefix = "model"

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class
        setattr(self, self.base_model_prefix, AutoModel.from_config(config))
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = mint.nn.Dropout(classifier_dropout)
        self.score = mint.nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def construct(
        self,
        input_ids: Optional[ms.Tensor] = None,
        attention_mask: Optional[ms.Tensor] = None,
        position_ids: Optional[ms.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[ms.Tensor] = None,
        labels: Optional[ms.Tensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> TokenClassifierOutput:
        outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

mindone/transformers/models/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -103,7 +103,7 @@
 )

 if version.parse(transformers.__version__) >= version.parse("4.51.0"):
-    from . import qwen3
+    from . import deepseek_v3, qwen3

 if version.parse(transformers.__version__) >= version.parse("4.51.3"):
     from . import glm4

mindone/transformers/models/auto/configuration_auto.py
Lines changed: 2 additions & 0 deletions

@@ -296,7 +296,9 @@

 if version.parse(transformers.__version__) >= version.parse("4.51.0"):
     CONFIG_MAPPING_NAMES.update({"qwen3": "Qwen3Config"})
+    CONFIG_MAPPING_NAMES.update({"deepseek_v3": "DeepseekV3Config"})
     MODEL_NAMES_MAPPING.update({"qwen3": "Qwen3Model"})
+    MODEL_NAMES_MAPPING.update({"deepseek_v3": "DeepSeek-V3"})

 if version.parse(transformers.__version__) >= version.parse("4.51.3"):
     CONFIG_MAPPING_NAMES.update({"glm4": "Glm4Config"})
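An illustrative check of what the two mappings register (names taken directly from the diff above; `CONFIG_MAPPING_NAMES` maps a model type to its config class, `MODEL_NAMES_MAPPING` to a display name):

```python
from mindone.transformers.models.auto.configuration_auto import (
    CONFIG_MAPPING_NAMES,
    MODEL_NAMES_MAPPING,
)

# "deepseek_v3" now resolves to its config class and human-readable name.
assert CONFIG_MAPPING_NAMES["deepseek_v3"] == "DeepseekV3Config"
assert MODEL_NAMES_MAPPING["deepseek_v3"] == "DeepSeek-V3"
```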

mindone/transformers/models/auto/modeling_auto.py
Lines changed: 2 additions & 0 deletions

@@ -674,6 +674,8 @@

 if version.parse(transformers.__version__) >= version.parse("4.51.0"):
+    MODEL_MAPPING_NAMES.update({"deepseek_v3": "DeepseekV3Model"})
+    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update({"deepseek_v3": "DeepseekV3ForCausalLM"})
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update({"qwen3": "Qwen3Model"})
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update({"qwen3": "Qwen3ForCausalLM"})
     MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.update({"qwen3": "Qwen3ForSequenceClassification"})
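Similarly, the model registrations can be sanity-checked. A sketch (mapping names are taken from this file; `AutoModel.from_config(config)` consults these tables via `config.model_type`):

```python
from mindone.transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)

# The base-model and causal-LM auto classes can both resolve "deepseek_v3".
assert MODEL_MAPPING_NAMES["deepseek_v3"] == "DeepseekV3Model"
assert MODEL_FOR_CAUSAL_LM_MAPPING_NAMES["deepseek_v3"] == "DeepseekV3ForCausalLM"
```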
mindone/transformers/models/deepseek_v3/__init__.py
Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# This code is adapted from https://github.com/huggingface/transformers
# with modifications to run transformers on mindspore.
#
# This code is adapted from https://github.com/huggingface/transformers
# with modifications to run transformers on mindspore.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .modeling_deepseek_v3 import *
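With the star re-export in place, the model classes are importable from the subpackage as well as from the top level:

```python
# Both import paths resolve to the same classes (see the __init__.py diffs above).
from mindone.transformers.models.deepseek_v3 import DeepseekV3ForCausalLM, DeepseekV3Model
```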
