56 changes: 49 additions & 7 deletions src/transformers/conversion_mapping.py
@@ -70,11 +70,56 @@ def _build_checkpoint_conversion_mapping():
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
),
WeightConverter(
source_patterns=["mlp.experts.*.down_proj.weight"],
source_patterns="mlp.experts.*.down_proj.weight",
target_patterns="mlp.experts.down_proj",
operations=[MergeModulelist(dim=0)],
),
],
"phimoe": [
WeightConverter(
source_patterns=[
"mlp.experts.*.w1.weight",
"mlp.experts.*.w3.weight",
],
target_patterns="mlp.experts.gate_up_proj",
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
),
WeightConverter(
source_patterns="mlp.experts.*.w2.weight",
target_patterns="mlp.experts.down_proj",
operations=[MergeModulelist(dim=0)],
),
],
"lfm2_moe": [
WeightConverter(
source_patterns=[
"feed_forward.experts.*.w1.weight",
"feed_forward.experts.*.w3.weight",
],
target_patterns="feed_forward.experts.gate_up_proj",
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
),
WeightConverter(
source_patterns="feed_forward.experts.*.w2.weight",
target_patterns="feed_forward.experts.down_proj",
operations=[MergeModulelist(dim=0)],
),
],
"jamba": [
WeightConverter(
source_patterns=[
"feed_forward.experts.*.gate_proj.weight",
"feed_forward.experts.*.up_proj.weight",
],
target_patterns="feed_forward.experts.gate_up_proj",
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
),
WeightConverter(
source_patterns="feed_forward.experts.*.down_proj.weight",
target_patterns="feed_forward.experts.down_proj",
operations=[MergeModulelist(dim=0)],
),
],
"timm_wrapper": [
# Simply add the prefix `timm_model`
# TODO: Would probably be much cleaner with an `add_prefix` argument in WeightRenaming
@@ -117,16 +162,13 @@ def _build_checkpoint_conversion_mapping():
),
]

mapping["phimoe"] = mapping["mixtral"].copy()
mapping["deepseek_v2"] = mapping["qwen2_moe"].copy()
mapping["deepseek_v3"] = mapping["qwen2_moe"].copy()
mapping["dot1"] = mapping["qwen2_moe"].copy()
mapping["ernie_4_5_moe"] = mapping["qwen2_moe"].copy()
mapping["dots1"] = mapping["qwen2_moe"].copy()
mapping["ernie4_5_moe"] = mapping["qwen2_moe"].copy()
mapping["glm4_moe"] = mapping["qwen2_moe"].copy()
mapping["glm4v_moe"] = mapping["qwen2_moe"].copy()
mapping["jamba"] = mapping["qwen2_moe"].copy()
mapping["lfm2_moe"] = mapping["mixtral"].copy()
mapping["long_cat_flash"] = mapping["qwen2_moe"].copy()
mapping["longcat_flash"] = mapping["qwen2_moe"].copy()
mapping["qwen3_moe"] = mapping["qwen2_moe"].copy()
mapping["qwen3_omni_moe"] = mapping["qwen2_moe"].copy()
mapping["qwen3_next"] = mapping["qwen2_moe"].copy()
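For reference, a minimal sketch of what the two operations in these new entries amount to for the phimoe patterns above, assuming `MergeModulelist(dim=0)` stacks each per-expert list into one tensor and `Concatenate(dim=1)` then fuses gate and up projections (toy shapes, plain PyTorch, not the library implementation):

```python
import torch

num_experts, hidden, intermediate = 4, 8, 16

# Hypothetical per-expert checkpoint tensors, as matched by
# "mlp.experts.*.w1.weight" (gate) and "mlp.experts.*.w3.weight" (up).
w1 = [torch.randn(intermediate, hidden) for _ in range(num_experts)]
w3 = [torch.randn(intermediate, hidden) for _ in range(num_experts)]

# MergeModulelist(dim=0): stack each expert list along a new leading dim.
stacked_w1 = torch.stack(w1, dim=0)  # (num_experts, intermediate, hidden)
stacked_w3 = torch.stack(w3, dim=0)  # (num_experts, intermediate, hidden)

# Concatenate(dim=1): fuse gate and up into a single gate_up_proj tensor.
gate_up_proj = torch.cat([stacked_w1, stacked_w3], dim=1)
print(gate_up_proj.shape)  # torch.Size([4, 32, 8])
```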
168 changes: 69 additions & 99 deletions src/transformers/core_model_loading.py
@@ -48,27 +48,6 @@
logger = logging.get_logger(__name__)


def compile_glob_rule(source_glob: str, target_glob: str) -> tuple[re.Pattern, str]:
"""
Convert a glob-style source + target into a full regex + replacement.

Rules:
- '*' in source_glob → (.*) capture group
- '*' in target_glob → \\1, \\2, ... backrefs
"""
regex = re.compile(source_glob)

counter = 0

def _star_to_backref(_: re.Match) -> str:
nonlocal counter
counter += 1
return rf"\{counter}"

replacement = re.sub(r"\*", _star_to_backref, target_glob)
return regex, replacement


def build_glob_alternation(
globs: list[Union[WeightRenaming, WeightConverter, str]],
) -> tuple[re.Pattern, dict[str, str], dict[str, str]]:
@@ -300,6 +279,7 @@ def convert(
class WeightTransform:
source_patterns: Union[str, list[str]] = field(init=True)
target_patterns: Union[str, list[str]] = field(init=True)
compiled_sources: re.Pattern = field(init=False)

distributed_operation: Optional[TensorParallelLayer] = None
quantization_operation: Optional[ConversionOps] = None
@@ -319,20 +299,27 @@ def __post_init__(self):
for i, pattern in enumerate(self.target_patterns):
# Some mappings contain `^` to anchor the start of the string when matching -> remove it during reverse mapping
pattern = pattern.removeprefix("^")
# This is ugly but needed for reverse mapping of Qwen2.5!
if r"(?!\.(language_model|visual))" in pattern:
pattern = pattern.replace(r"(?!\.(language_model|visual))", "")
# Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper)
# Remove negative lookahead if any. This is ugly but needed for reverse mapping of Qwen2.5 and Sam3!
pattern = re.sub(r"\(\?!.+\)", "", pattern)
# Allow capturing groups in patterns, i.e. to add/remove a prefix to all keys (e.g. timm_wrapper, sam3)
if r"(.+)" in pattern:
pattern = pattern.replace(r"(.+)", "")
pattern = pattern.replace(r"(.+)", r"\1")
self.target_patterns[i] = pattern

# We also need to check capturing groups in the sources during reverse mapping (e.g. timm_wrapper)
# We also need to check capturing groups in the sources during reverse mapping (e.g. timm_wrapper, sam3)
for i, pattern in enumerate(self.source_patterns):
if r"\1" in pattern:
pattern = pattern.replace(r"\1", "")
pattern = pattern.replace(r"\1", r"(.+)")
self.source_patterns[i] = pattern

# Construct the regex we will use to rename keys from the sources to the targets
branches = []
for i, source_pattern in enumerate(self.source_patterns):
group_name = f"g{i}"
pattern = source_pattern.replace(".*.", r"\..*\.")
branches.append(f"(?P<{group_name}>{pattern})")
self.compiled_sources = re.compile("|".join(branches))

def add_tensor(self, target_key: str, source_key: str, source_pattern: str, future: Future):
self.collected_tensors[source_pattern].append(future)
self.layer_targets[target_key].add(source_key)
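As a side note, the reverse-mapping cleanup above can be reproduced in isolation. A sketch with a hypothetical Qwen2.5-style target pattern (the pattern itself is invented for illustration; only the cleanup steps come from the diff):

```python
import re

# Negative lookahead removal, as applied to target patterns when reversing:
pattern = r"^model(?!\.(language_model|visual))\.layers"
pattern = pattern.removeprefix("^")
pattern = re.sub(r"\(\?!.+\)", "", pattern)
print(pattern)  # model\.layers

# And the capturing-group swap used for prefix rules (e.g. timm_wrapper):
# "(.+)" in targets becomes a backref, "\1" in sources becomes a capture.
print(r"timm_model.(.+)".replace(r"(.+)", r"\1"))  # timm_model.\1
print(r"timm_model.\1".replace(r"\1", r"(.+)"))    # timm_model.(.+)
```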
@@ -341,6 +328,32 @@ def reset(self) -> None:
"""Clean-up the collected tensors to make sure we don't keep references to past tensors in memory."""
self.collected_tensors = defaultdict(list)

def rename_source_key(self, source_key: str) -> tuple[str, str | None]:
"""
Try renaming `source_key` according to the source and target patterns of the current WeightTransform.
Returns a tuple (renamed_key, source_pattern_that_matched); the second element is None if no source
pattern matched. In case of a one-to-many transform (several target patterns), the matching source
pattern is replaced by the first target pattern (the others are then correctly expanded in the Operations).
"""
# Try matching one of the alternation branches
match_object = self.compiled_sources.search(source_key)
if match_object is None:
return source_key, None
# Find the source pattern that produced the match (the first named group that matched, since the alternation stops at the first matching branch)
matching_group_name = next(name for name, val in match_object.groupdict().items() if val is not None)
source_pattern_that_matched = self.source_patterns[int(matching_group_name[1:])]
# If we matched, we always replace with the first target pattern, in case we have several (one to many transform)
replacement = self.target_patterns[0]
# Allow capturing groups in patterns, i.e. to add/remove a prefix to all keys (e.g. timm_wrapper, sam3)
if r"\1" in replacement:
# The inner capturing group we need is the one immediately following the
# matched named group in the pattern, hence the +1 on its index
replaced_group_idx = self.compiled_sources.groupindex[matching_group_name] + 1
replacement = replacement.replace(r"\1", match_object.group(replaced_group_idx))
renamed_key = source_key.replace(match_object.group(0), replacement)

return renamed_key, source_pattern_that_matched

def reverse_transform(self) -> WeightTransform:
"""Reverse the current `WeightTransform` instance, to be able to save with the opposite weight transformations."""
# TODO: check this and relax when quantizers have `reverse_op`
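To make `rename_source_key` concrete, here is a hypothetical run on a Mixtral-style key. The compiled pattern below is what `__post_init__` would build for the two gate/up source patterns (a sketch reproducing the method's logic, not the library API):

```python
import re

source_patterns = ["mlp.experts.*.w1.weight", "mlp.experts.*.w3.weight"]
target_patterns = ["mlp.experts.gate_up_proj"]
compiled_sources = re.compile(
    r"(?P<g0>mlp.experts\..*\.w1.weight)|(?P<g1>mlp.experts\..*\.w3.weight)"
)

key = "model.layers.3.mlp.experts.7.w3.weight"
m = compiled_sources.search(key)
group_name = next(name for name, val in m.groupdict().items() if val is not None)
# Both w1 and w3 keys collapse onto the first (and only) target pattern.
print(source_patterns[int(group_name[1:])])         # mlp.experts.*.w3.weight
print(key.replace(m.group(0), target_patterns[0]))  # model.layers.3.mlp.experts.gate_up_proj
```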
@@ -610,54 +623,30 @@ class SkipLayer(Exception):
pass


def repl(m, repl_map: dict[str, str]) -> str:
# Collect all groups that matched
matched_groups = [name for name, val in m.groupdict().items() if val]

if len(matched_groups) == 0:
# Should never happen
return m.group(0)

if len(matched_groups) > 1:
raise ValueError(
"only a single match should happen, your regex patterns are tangled: "
f"groups matched = {matched_groups} for the patternsL {repl_map.keys()}"
)

# Exactly one match => return replacement
name = matched_groups[0]
replacement = repl_map[name]
# Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper)
if r"\1" in replacement and len(m.groups()) > 1:
replacement = replacement.replace(r"\1", m.group(1))

return replacement


def rename_source_key(
source_key: str,
rename_alternation: re.Pattern,
rename_by_group: dict,
weight_pattern_alternation: re.Pattern | None,
weight_pattern_by_group: dict | None,
weight_renamings: list[WeightRenaming],
weight_converters: list[WeightConverter],
prefix: str | None = None,
meta_state_dict: dict | None = None,
) -> tuple[str, re.Match | None]:
) -> tuple[str, str | None]:
"""
Rename a source key given all the renaming and weight conversion patterns we have. Also takes care of adding/removing
the base model prefix during loading if necessary.
"""
# 1. apply all renamings
renamed_key = rename_alternation.sub(lambda m: repl(m, rename_by_group), source_key).replace("\\", "")

# 2. apply renaming through weight conversions on the key if we have any WeightConverter
matched_converter_pattern = (
weight_pattern_alternation.search(renamed_key) if weight_pattern_alternation is not None else None
)
if matched_converter_pattern is not None:
renamed_key = weight_pattern_alternation.sub(lambda m: repl(m, weight_pattern_by_group), renamed_key).replace(
"\\", ""
)
renamed_key = source_key
# 1. apply all renamings in turn (if multiple match, it's the responsibility of the mappings to make sure they
# are coherent)
for renaming in weight_renamings:
renamed_key, _ = renaming.rename_source_key(renamed_key)

# 2. apply renaming through weight conversions on the key if we have any WeightConverter (here we stop after
# the first match, as we assume only 1 converter can match any source key)
source_pattern = None
for converter in weight_converters:
renamed_key, source_pattern = converter.rename_source_key(renamed_key)
if source_pattern is not None:
break

# 3. add or remove the base model prefix if necessary (only during loading, not saving)
if prefix is not None and meta_state_dict is not None:
Expand All @@ -669,7 +658,7 @@ def rename_source_key(
elif meta_state_dict.get(f"{prefix}.{renamed_key}") is not None:
renamed_key = f"{prefix}.{renamed_key}"

return renamed_key, matched_converter_pattern
return renamed_key, source_pattern


def convert_and_load_state_dict_in_model(
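The two loops above only require objects exposing `rename_source_key(key) -> (renamed_key, matched_pattern_or_None)`. A self-contained toy of the chaining behavior, where `ToyRule` is a stand-in for `WeightRenaming`/`WeightConverter` and substring matching replaces the real regexes:

```python
from dataclasses import dataclass

@dataclass
class ToyRule:
    old: str
    new: str

    def rename_source_key(self, key: str) -> tuple[str, str | None]:
        if self.old in key:
            return key.replace(self.old, self.new), self.old
        return key, None

renamings = [ToyRule("gamma", "weight")]
converters = [ToyRule("experts.0.w2", "experts.down_proj")]

key = "model.layers.0.mlp.experts.0.w2.gamma"
for r in renamings:  # all renamings applied in turn
    key, _ = r.rename_source_key(key)
for c in converters:  # stop at the first converter that matches
    key, pattern = c.rename_source_key(key)
    if pattern is not None:
        break
print(key)  # model.layers.0.mlp.experts.down_proj.weight
```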
@@ -796,10 +785,6 @@ def convert_and_load_state_dict_in_model(

# build '(?P<g0>.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'}
# and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched.
rename_alt, _, rename_by_group = build_glob_alternation(renamings)
weight_pattern_alt, src_group_to_glob, tgt_group_to_glob = None, None, None
if converters != []:
weight_pattern_alt, src_group_to_glob, tgt_group_to_glob = build_glob_alternation(converters)
if tp_plan != {}:
tp_plan_alt, tp_plan_by_group_name, _ = build_glob_alternation(list(tp_plan.keys()))
if dtype_plan != {}:
@@ -810,24 +795,19 @@
state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))
for original_key, tensor in state_dict:
# 1. Rename the key according to all renaming patterns and optional weight converter patterns
renamed_key, matched_pattern = rename_source_key(
original_key,
rename_alt,
rename_by_group,
weight_pattern_alt,
tgt_group_to_glob,
prefix,
meta_model_state_dict,
renamed_key, source_pattern = rename_source_key(
original_key, renamings, converters, prefix, meta_model_state_dict
)

# 2. finally, collect the tensor into the proper converter
if renamed_key in missing_keys:
empty_param = meta_model_state_dict.get(renamed_key)
if matched_pattern:
new_converter = deepcopy(pattern_to_converter[src_group_to_glob[matched_pattern.lastgroup]])
# If we enter here, we have a WeightConverter operation to perform
if source_pattern is not None:
new_converter = deepcopy(pattern_to_converter[source_pattern])
# each target key gets its own converter instance
mapping = param_name_to_load.setdefault(renamed_key, new_converter)
source_pattern = src_group_to_glob[matched_pattern.lastgroup]
# Otherwise, only potential renaming
else:
mapping = param_name_to_load.setdefault(renamed_key, WeightRenaming(original_key, renamed_key))
source_pattern = original_key
@@ -879,8 +859,8 @@ def convert_and_load_state_dict_in_model(
future = spawn_materialize(thread_pool, tensor, param_device, _dtype)

mapping.add_tensor(renamed_key, original_key, source_pattern, future)
elif matched_pattern: # add all target keys as unexpected
mapping = pattern_to_converter[src_group_to_glob[matched_pattern.lastgroup]]
elif source_pattern is not None: # add all target keys as unexpected
mapping = pattern_to_converter[source_pattern]
for k in mapping.target_patterns:
unexpected_keys.add(renamed_key.replace(mapping.target_patterns[0], k))
else:
@@ -961,24 +941,14 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch
pattern_to_converter = {k: converter for converter in converters for k in converter.source_patterns}
conversion_mapping = {}

# build '(?P<g0>.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'}
# and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched.
rename_alt, _, rename_by_group = build_glob_alternation(renamings)
weight_pattern_alt, src_group_to_glob, tgt_group_to_glob = None, None, None
if converters != []:
weight_pattern_alt, src_group_to_glob, tgt_group_to_glob = build_glob_alternation(converters)

state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))
for original_key, tensor in state_dict:
# Rename the key according to all renaming patterns and optional weight converter patterns
renamed_key, matched_pattern = rename_source_key(
original_key, rename_alt, rename_by_group, weight_pattern_alt, tgt_group_to_glob
)
if matched_pattern is not None:
new_converter = deepcopy(pattern_to_converter[src_group_to_glob[matched_pattern.lastgroup]])
renamed_key, source_pattern = rename_source_key(original_key, renamings, converters)
if source_pattern is not None:
new_converter = deepcopy(pattern_to_converter[source_pattern])
# each target key gets its own converter instance
mapping = conversion_mapping.setdefault(renamed_key, new_converter)
source_pattern = src_group_to_glob[matched_pattern.lastgroup]
else:
mapping = conversion_mapping.setdefault(renamed_key, WeightRenaming(original_key, renamed_key))
source_pattern = original_key
12 changes: 3 additions & 9 deletions src/transformers/integrations/accelerate.py
@@ -480,18 +480,15 @@ def accelerate_disk_offload(
renamed) will be mapped to where they already reside on disk. Otherwise, the parameters will be resaved inside
`disk_offload_folder` during loading.
"""
from ..core_model_loading import WeightRenaming, build_glob_alternation, repl
from ..core_model_loading import WeightRenaming, rename_source_key

if disk_offload_folder is not None:
os.makedirs(disk_offload_folder, exist_ok=True)
is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")

rename = False
renamings = []
if weight_mapping is not None:
renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
if len(renamings) > 0:
rename = True
rename_alt, _, rename_by_group = build_glob_alternation(renamings)

# In this case, the offload index is simply the existing safetensors (except if using custom weight loading
# Operation, e.g. the MoE models, where we need to resave the weights that were changed at loading time)
@@ -505,10 +502,7 @@
weight_map = {k: os.path.join(folder, v) for k, v in sharded_metadata["weight_map"].items()}

# Update the weight names according to the `weight_mapping`
weight_renaming_map = {
rename_alt.sub(lambda m: repl(m, rename_by_group), k).replace("\\", "") if rename else k: k
for k in weight_map
}
weight_renaming_map = {rename_source_key(k, renamings, [])[0]: k for k in weight_map}

# Prepare the index using existing safetensors files
disk_offload_index = {
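The simplified `weight_renaming_map` comprehension above maps each renamed (in-model) key back to its original on-disk key, so offloaded parameters can be located without resaving the shards. A sketch with a stand-in renamer and hypothetical keys and shard name (the real call is `rename_source_key(k, renamings, [])[0]`):

```python
# Hypothetical weight map: on-disk key -> safetensors shard it lives in.
weight_map = {
    "transformer.ln_f.gamma": "model-00001-of-00002.safetensors",
}

def rename(key: str) -> str:
    # Stand-in for rename_source_key(k, renamings, [])[0]
    return key.replace("gamma", "weight")

weight_renaming_map = {rename(k): k for k in weight_map}
print(weight_renaming_map)  # {'transformer.ln_f.weight': 'transformer.ln_f.gamma'}
```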