1 change: 1 addition & 0 deletions mindone/transformers/__init__.py
@@ -691,6 +691,7 @@
from .utils import logging

if version.parse(transformers.__version__) >= version.parse("4.51.0"):
    from .models.deepseek_v3 import DeepseekV3ForCausalLM, DeepseekV3Model, DeepseekV3PreTrainedModel
    from .models.qwen3 import Qwen3ForCausalLM, Qwen3Model, Qwen3PreTrainedModel

if version.parse(transformers.__version__) >= version.parse("4.51.3"):
285 changes: 285 additions & 0 deletions mindone/transformers/modeling_layers.py
@@ -0,0 +1,285 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# This code is adapted from https://github.com/huggingface/transformers
# with modifications to run transformers on mindspore.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from functools import partial
from typing import Optional

from transformers.utils import logging

import mindspore as ms
import mindspore.nn as nn
from mindspore import mint

from .cache_utils import Cache
from .modeling_outputs import (
    BaseModelOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from .models.auto import AutoModel
from .processing_utils import Unpack
from .utils import TransformersKwargs

logger = logging.get_logger(__name__)


class GradientCheckpointingLayer(nn.Cell):
"""Base class for layers with gradient checkpointing.
This class enables gradient checkpointing functionality for a layer. By default, gradient checkpointing is disabled
(`gradient_checkpointing = False`). When `model.set_gradient_checkpointing()` is called, gradient checkpointing is
enabled by setting `gradient_checkpointing = True` and assigning a checkpointing function to `_gradient_checkpointing_func`.
Important:
When using gradient checkpointing with `use_reentrant=True`, inputs that require gradients (e.g. hidden states)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems use_reetrant is not implemented in this class?

must be passed as positional arguments (`*args`) rather than keyword arguments to properly propagate gradients.
Example:
```python
>>> # Correct - hidden_states passed as positional arg
>>> out = self.layer(hidden_states, attention_mask=attention_mask)
>>> # Incorrect - hidden_states passed as keyword arg
>>> out = self.layer(hidden_states=hidden_states, attention_mask=attention_mask)
```
"""

gradient_checkpointing = False

def __call__(self, *args, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why use __call__ instead of construct for this nn.Cell class

if self.gradient_checkpointing and self.training:
do_warn = False
layer_name = self.__class__.__name__
message = f"Caching is incompatible with gradient checkpointing in {layer_name}. Setting"

if "use_cache" in kwargs and kwargs["use_cache"]:
kwargs["use_cache"] = False
message += " `use_cache=False`,"
do_warn = True

# different names for the same thing in different layers
if "past_key_value" in kwargs and kwargs["past_key_value"] is not None:
kwargs["past_key_value"] = None
message += " `past_key_value=None`,"
do_warn = True

if "past_key_values" in kwargs and kwargs["past_key_values"] is not None:
kwargs["past_key_values"] = None
message += " `past_key_values=None`,"
do_warn = True

if "layer_past" in kwargs and kwargs["layer_past"] is not None:
kwargs["layer_past"] = None
message += " `layer_past=None`,"
do_warn = True

# warn if anything was changed
if do_warn:
message = message.rstrip(",") + "."
logger.warning(message)

return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
return super().__call__(*args, **kwargs)
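[Note] A minimal sketch of how a layer opts into this mechanism, assuming the real checkpointing function is installed by `model.set_gradient_checkpointing()` elsewhere; the toy layer and the identity stand-in below are hypothetical:

```python
from mindspore import mint

class ToyDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, hidden_size):
        super().__init__()
        self.mlp = mint.nn.Linear(hidden_size, hidden_size)

    def construct(self, hidden_states, attention_mask=None, use_cache=False):
        return self.mlp(hidden_states)

layer = ToyDecoderLayer(16)
layer.set_train(True)  # checkpointing is only applied while training
layer.gradient_checkpointing = True
# Hypothetical stand-in: the real function would recompute activations in backward.
layer._gradient_checkpointing_func = lambda fn, *args: fn(*args)

# `use_cache=True` is incompatible with checkpointing; the base class resets it
# to False, logs a warning, and passes hidden_states through positionally.
out = layer(mint.ones((2, 4, 16)), use_cache=True)
```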


class GenericForSequenceClassification(ABC):
    base_model_prefix = "model"

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # Similar to `self.model = AutoModel.from_config(config)`, but allows the child class to change the base model name if needed
        setattr(self, self.base_model_prefix, AutoModel.from_config(config))
        self.score = mint.nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def construct(
        self,
        input_ids: Optional[ms.Tensor] = None,
        attention_mask: Optional[ms.Tensor] = None,
        position_ids: Optional[ms.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[ms.Tensor] = None,
        labels: Optional[ms.Tensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> SequenceClassifierOutputWithPast:
        transformer_outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = transformer_outputs.last_hidden_state
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right-padding, take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(ms.int32)
            token_indices = mint.arange(input_ids.shape[-1], dtype=ms.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[mint.arange(batch_size), last_non_pad_token]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
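[Note] The pooling above relies on a small trick: multiplying token indices by a non-pad mask and taking `argmax` yields the rightmost non-pad position whether the batch is left- or right-padded. A toy illustration (values invented for this note):

```python
import mindspore as ms
from mindspore import mint

pad_token_id = 0
input_ids = ms.Tensor([[5, 7, 0, 0],   # right-padded: last real token at index 1
                       [0, 0, 3, 9]])  # left-padded: last real token at index 3

non_pad_mask = (input_ids != pad_token_id).to(ms.int32)           # [[1, 1, 0, 0], [0, 0, 1, 1]]
token_indices = mint.arange(input_ids.shape[-1], dtype=ms.int32)  # [0, 1, 2, 3]
# Pad positions are zeroed out, so argmax lands on the rightmost non-pad index.
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)  # [1, 3]
```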


class GenericForQuestionAnswering(ABC):
    base_model_prefix = "model"

    def __init__(self, config):
        super().__init__(config)
        # Similar to `self.model = AutoModel.from_config(config)`, but allows the child class to change the base model name if needed
        setattr(self, self.base_model_prefix, AutoModel.from_config(config))
        self.qa_outputs = mint.nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return getattr(self, self.base_model_prefix).embed_tokens

    def set_input_embeddings(self, value):
        getattr(self, self.base_model_prefix).embed_tokens = value

    def construct(
        self,
        input_ids: Optional[ms.Tensor] = None,
        attention_mask: Optional[ms.Tensor] = None,
        position_ids: Optional[ms.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[ms.Tensor] = None,
        start_positions: Optional[ms.Tensor] = None,
        end_positions: Optional[ms.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> QuestionAnsweringModelOutput:
        outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        sequence_output = outputs.last_hidden_state

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        loss = None
        if start_positions is not None and end_positions is not None:
            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)

        return QuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
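[Note] The `split`/`squeeze` pair above is easy to misread, so here is a shape-only sketch (toy sizes; `mint.split` is used as the functional twin of the tensor method in the diff):

```python
from mindspore import mint

batch_size, seq_len, hidden_size = 2, 8, 16
qa_outputs = mint.nn.Linear(hidden_size, 2)  # two scalars per token: start and end

logits = qa_outputs(mint.ones((batch_size, seq_len, hidden_size)))   # (2, 8, 2)
start_logits, end_logits = mint.split(logits, 1, dim=-1)             # 2 x (2, 8, 1)
print(start_logits.squeeze(-1).shape, end_logits.squeeze(-1).shape)  # (2, 8) (2, 8)
```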


class GenericForTokenClassification(ABC):
    base_model_prefix = "model"

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # Similar to `self.model = AutoModel.from_config(config)`, but allows the child class to change the base model name if needed
        setattr(self, self.base_model_prefix, AutoModel.from_config(config))
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = mint.nn.Dropout(classifier_dropout)
        self.score = mint.nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def construct(
        self,
        input_ids: Optional[ms.Tensor] = None,
        attention_mask: Optional[ms.Tensor] = None,
        position_ids: Optional[ms.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[ms.Tensor] = None,
        labels: Optional[ms.Tensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> TokenClassifierOutput:
        outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
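[Note] These `Generic*` heads are ABC mixins, not standalone models: their `__init__` expects to cooperate with a `PreTrainedModel` subclass that supplies `post_init`, `config`, and `loss_function`. A hedged sketch of the intended composition, mirroring how upstream transformers wires the same mixins (the sketch class name is illustrative):

```python
from mindone.transformers import Qwen3PreTrainedModel
from mindone.transformers.modeling_layers import GenericForSequenceClassification

# The mixin builds the backbone via AutoModel.from_config and attaches it under
# `base_model_prefix` ("model"); the pretrained base contributes weight init,
# config handling, and the loss function.
class Qwen3ForSequenceClassificationSketch(GenericForSequenceClassification, Qwen3PreTrainedModel):
    pass
```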
2 changes: 1 addition & 1 deletion mindone/transformers/models/__init__.py
@@ -103,7 +103,7 @@
)

if version.parse(transformers.__version__) >= version.parse("4.51.0"):
-    from . import qwen3
+    from . import deepseek_v3, qwen3

if version.parse(transformers.__version__) >= version.parse("4.51.3"):
    from . import glm4
2 changes: 2 additions & 0 deletions mindone/transformers/models/auto/configuration_auto.py
@@ -296,7 +296,9 @@

if version.parse(transformers.__version__) >= version.parse("4.51.0"):
    CONFIG_MAPPING_NAMES.update({"qwen3": "Qwen3Config"})
    CONFIG_MAPPING_NAMES.update({"deepseek_v3": "DeepseekV3Config"})
    MODEL_NAMES_MAPPING.update({"qwen3": "Qwen3Model"})
    MODEL_NAMES_MAPPING.update({"deepseek_v3": "DeepSeek-V3"})

if version.parse(transformers.__version__) >= version.parse("4.51.3"):
    CONFIG_MAPPING_NAMES.update({"glm4": "Glm4Config"})
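[Note] A quick sanity check of the registration above (illustrative; real `AutoConfig` lookups go through a lazy mapping built from these dicts, and the entries only exist when the transformers version gate is met):

```python
# Assumes transformers >= 4.51.0, matching the version gate above.
from mindone.transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES, MODEL_NAMES_MAPPING

assert CONFIG_MAPPING_NAMES["deepseek_v3"] == "DeepseekV3Config"
assert MODEL_NAMES_MAPPING["deepseek_v3"] == "DeepSeek-V3"
```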
2 changes: 2 additions & 0 deletions mindone/transformers/models/auto/modeling_auto.py
@@ -674,6 +674,8 @@


if version.parse(transformers.__version__) >= version.parse("4.51.0"):
    MODEL_MAPPING_NAMES.update({"deepseek_v3": "DeepseekV3Model"}),
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update({"deepseek_v3": "DeepseekV3ForCausalLM"}),
[Contributor, medium] Comment on lines +677 to +678: These lines have trailing commas, which create unnecessary tuples. While this is syntactically valid, it's cleaner to remove them for better code clarity.

Suggested change:

```diff
-    MODEL_MAPPING_NAMES.update({"deepseek_v3": "DeepseekV3Model"}),
-    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update({"deepseek_v3": "DeepseekV3ForCausalLM"}),
+    MODEL_MAPPING_NAMES.update({"deepseek_v3": "DeepseekV3Model"})
+    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update({"deepseek_v3": "DeepseekV3ForCausalLM"})
```
    MODEL_MAPPING_NAMES.update({"qwen3": "Qwen3Model"})
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update({"qwen3": "Qwen3ForCausalLM"})
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.update({"qwen3": "Qwen3ForSequenceClassification"})
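[Note] With both mappings registered, the Auto classes can route the `deepseek_v3` model type to the MindSpore implementation. A hedged sketch: `AutoModelForCausalLM` is assumed to be exported by mindone the way it is by transformers, and the real checkpoint is far too large for one device, so this only shows the shape of the API:

```python
from mindone.transformers import AutoModelForCausalLM

# model_type "deepseek_v3" in the checkpoint's config.json resolves through
# MODEL_FOR_CAUSAL_LM_MAPPING_NAMES to mindone's DeepseekV3ForCausalLM.
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-V3")
```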
21 changes: 21 additions & 0 deletions mindone/transformers/models/deepseek_v3/__init__.py
@@ -0,0 +1,21 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# This code is adapted from https://github.com/huggingface/transformers
# with modifications to run transformers on mindspore.
#
# This code is adapted from https://github.com/huggingface/transformers
# with modifications to run transformers on mindspore.
[Contributor, medium] Comment on lines +5 to +7: This comment block is duplicated. Please remove the redundant part for better readability.

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .modeling_deepseek_v3 import *
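[Note] An end-to-end usage sketch for the new class, hedged: the tokenizer flow follows the usual mindone pattern (HF tokenizer returning NumPy arrays, converted to MindSpore tensors), `mindspore_dtype` is assumed to be the dtype kwarg as in other mindone models, and the full checkpoint is impractical to load on a single device:

```python
import mindspore as ms
from transformers import AutoTokenizer

from mindone.transformers import DeepseekV3ForCausalLM

model_id = "deepseek-ai/DeepSeek-V3"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = DeepseekV3ForCausalLM.from_pretrained(model_id, mindspore_dtype=ms.bfloat16)

input_ids = ms.Tensor(tokenizer("Hello, DeepSeek!", return_tensors="np").input_ids)
output_ids = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```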