 from ml_dtypes import bfloat16

 import mindspore as ms
-from mindspore import mint, nn, ops
+from mindspore import nn, ops

 logger = logging.getLogger("ModelingsUnitTest")

@@ -131,10 +131,6 @@ def get_pt2ms_mappings(m):
             mappings[f"{name}.running_mean"] = f"{name}.moving_mean", lambda x: x
             mappings[f"{name}.running_var"] = f"{name}.moving_variance", lambda x: x
             mappings[f"{name}.num_batches_tracked"] = None, lambda x: x
-        elif isinstance(cell, (mint.nn.BatchNorm1d, mint.nn.BatchNorm2d, mint.nn.BatchNorm3d)):
-            # TODO: for mint.nn, the dtype of each param is expected to be the same between torch and mindspore;
-            # this is a temporary fix, delete this branch in the future.
-            mappings[f"{name}.num_batches_tracked"] = f"{name}.num_batches_tracked", lambda x: x.to(ms.float32)
    return mappings

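For context, here is a minimal sketch of what the remaining BatchNorm mapping entries look like after this change. It assumes get_pt2ms_mappings from this file iterates over the module's named cells, as the hunk suggests; TinyNet and the submodule name "bn" are hypothetical.

import mindspore.nn as nn

class TinyNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm2d(8)  # hypothetical BatchNorm submodule named "bn"

# mappings = get_pt2ms_mappings(TinyNet())  # function defined in this file
# Expected entries for "bn" (data mappings are identity lambdas):
#   mappings["bn.running_mean"]        -> ("bn.moving_mean", lambda x: x)
#   mappings["bn.running_var"]         -> ("bn.moving_variance", lambda x: x)
#   mappings["bn.num_batches_tracked"] -> (None, lambda x: x)  # dropped on the MindSpore side

With the mint.nn branch removed, mint BatchNorm parameters are no longer cast to float32 here; the dtype difference is handled in convert_state_dict below instead.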
@@ -150,6 +146,11 @@ def convert_state_dict(m, state_dict_pt):
     state_dict_ms = {}
     for name_pt, data_pt in state_dict_pt.items():
         name_ms, data_mapping = mappings.get(name_pt, (name_pt, lambda x: x))
+        # for torch backward compatibility:
+        # for torch<2.0 the dtype of num_batches_tracked is int32, for torch>=2.0 it is int64;
+        # mindspore.mint is aligned with torch>=2.0
+        if "num_batches_tracked" in name_pt and data_pt.dtype == torch.int32:
+            data_pt = data_pt.to(torch.int64)
         data_ms = ms.Parameter(
             data_mapping(ms.Tensor.from_numpy(data_pt.float().numpy()).to(dtype_mappings[data_pt.dtype])), name=name_ms
         )
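To illustrate the new compatibility branch in isolation, a minimal sketch follows. The scalar value, the parameter name, and the direct int64 round trip are assumptions for demonstration; the real code routes every tensor through dtype_mappings as shown in the hunk above.

import torch
import mindspore as ms

# Hypothetical buffer as stored by an older (torch<2.0-style) checkpoint: int32 num_batches_tracked.
data_pt = torch.tensor(3, dtype=torch.int32)

# Upcast to int64 so it matches torch>=2.0 and therefore mindspore.mint.
if data_pt.dtype == torch.int32:
    data_pt = data_pt.to(torch.int64)

# Convert to MindSpore; from_numpy preserves the int64 dtype.
data_ms = ms.Parameter(ms.Tensor.from_numpy(data_pt.numpy()), name="bn.num_batches_tracked")
print(data_ms.dtype)  # Int64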