1111import torch .distributed .checkpoint as dcp
1212import torchtitan .protocols .train_spec as train_spec_module
1313from torch .distributed .checkpoint import HuggingFaceStorageWriter
14+ from torchtitan .config import TORCH_DTYPE_MAP
1415from torchtitan .components .checkpoint import ModelWrapper
1516
1617
1718@torch .inference_mode ()
18- def convert_to_hf (input_dir , output_dir , model_name , model_flavor , hf_assets_path ):
19- if model_name == "flux" :
20- import torchtitan .experiments .flux # noqa: F401
19+ def convert_to_hf (
20+ input_dir ,
21+ output_dir ,
22+ model_name ,
23+ model_flavor ,
24+ hf_assets_path ,
25+ export_dtype ,
26+ ):
2127 # load model and model args so that we can get the state dict shape
2228 train_spec = train_spec_module .get_train_spec (model_name )
2329 model_args = train_spec .model_args [model_flavor ]
@@ -49,6 +55,11 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_pat
4955 thread_count_consolidation = 5 ,
5056 )
5157
58+ # map and apply export dtype if needed
59+ target_dtype = TORCH_DTYPE_MAP [export_dtype ]
60+ if target_dtype != torch .float32 :
61+ hf_state_dict = {k : v .to (target_dtype ) for k , v in hf_state_dict .items ()}
62+
5263 dcp .save (
5364 hf_state_dict ,
5465 storage_writer = storage_writer ,
@@ -71,6 +82,14 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_pat
7182 )
7283 parser .add_argument ("--model_name" , type = str , nargs = "?" , default = "llama3" )
7384 parser .add_argument ("--model_flavor" , type = str , nargs = "?" , default = "8B" )
85+ parser .add_argument (
86+ "--export_dtype" ,
87+ type = str ,
88+ nargs = "?" ,
89+ choices = ["float16" , "bfloat16" , "float32" ],
90+ default = "float32" ,
91+ help = "Export dtype for HF checkpoint (default: float32)" ,
92+ )
7493 args = parser .parse_args ()
7594
7695 convert_to_hf (
@@ -79,4 +98,5 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_pat
7998 args .model_name ,
8099 args .model_flavor ,
81100 args .hf_assets_path ,
101+ args .export_dtype ,
82102 )
0 commit comments