scripts/checkpoint_conversion/convert_to_hf.py (26 changes: 23 additions & 3 deletions)
@@ -12,12 +12,18 @@
 import torchtitan.protocols.train_spec as train_spec_module
 from torch.distributed.checkpoint import HuggingFaceStorageWriter
 from torchtitan.components.checkpoint import ModelWrapper
+from torchtitan.config import TORCH_DTYPE_MAP
 
 
 @torch.inference_mode()
-def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_path):
-    if model_name == "flux":
Contributor:

Why are we removing this line here? Could you provide some test results (as simple as a screenshot or terminal output) to show it still works for the FLUX model?

Contributor Author:

Good catch. I removed it because the import no longer made sense: the flux folder was moved from experiments into the main models folder, so it no longer needs "special" treatment.

Contributor Author:

Regarding when to use low precision: it is more space-efficient, especially when uploading checkpoints to cloud storage. I kept the default export dtype at fp32, but IMO bf16 makes more sense as the default export type, since most models uploaded today are in bf16.
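To make the space argument concrete, here is a minimal, hedged sketch: a toy state dict stands in for the real checkpoint, and TORCH_DTYPE_MAP is re-declared locally rather than imported from torchtitan.config. Casting fp32 tensors to bf16 roughly halves the bytes that end up in the exported files.

    # Toy sketch only: local stand-in for TORCH_DTYPE_MAP, toy state dict.
    import torch

    TORCH_DTYPE_MAP = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }

    def nbytes(sd):
        # total tensor payload in bytes
        return sum(t.numel() * t.element_size() for t in sd.values())

    state_dict = {"w": torch.randn(1024, 1024), "b": torch.randn(1024)}  # fp32 by default
    before = nbytes(state_dict)

    target_dtype = TORCH_DTYPE_MAP["bfloat16"]
    if target_dtype != torch.float32:
        state_dict = {k: v.to(target_dtype) for k, v in state_dict.items()}

    print(before, nbytes(state_dict))  # roughly 4.2 MB (fp32) vs 2.1 MB (bf16)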

-        import torchtitan.experiments.flux  # noqa: F401
+def convert_to_hf(
+    input_dir,
+    output_dir,
+    model_name,
+    model_flavor,
+    hf_assets_path,
+    export_dtype,
+):
     # load model and model args so that we can get the state dict shape
     train_spec = train_spec_module.get_train_spec(model_name)
     model_args = train_spec.model_args[model_flavor]
@@ -49,6 +55,11 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_pat
         thread_count_consolidation=5,
     )
 
+    # map and apply export dtype if needed
+    target_dtype = TORCH_DTYPE_MAP[export_dtype]
+    if target_dtype != torch.float32:
+        hf_state_dict = {k: v.to(target_dtype) for k, v in hf_state_dict.items()}
+
     dcp.save(
         hf_state_dict,
         storage_writer=storage_writer,
@@ -71,6 +82,14 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_pat
     )
     parser.add_argument("--model_name", type=str, nargs="?", default="llama3")
     parser.add_argument("--model_flavor", type=str, nargs="?", default="8B")
+    parser.add_argument(
+        "--export_dtype",
+        type=str,
+        nargs="?",
+        choices=["float16", "bfloat16", "float32"],
+        default="float32",
+        help="Export dtype for HF checkpoint (default: float32)",
+    )
     args = parser.parse_args()
 
     convert_to_hf(
@@ -79,4 +98,5 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_pat
         args.model_name,
         args.model_flavor,
         args.hf_assets_path,
+        args.export_dtype,
     )
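As a hedged follow-up to the reviewer's request for test output: one way to sanity-check the exported dtype is to read the shards back. This assumes the export directory contains *.safetensors files (as written by HuggingFaceStorageWriter) and that the safetensors package is installed; the path is a placeholder.

    # Hypothetical dtype check over an exported HF checkpoint directory.
    import glob
    from safetensors import safe_open

    for shard in sorted(glob.glob("outputs/hf_export/*.safetensors")):  # placeholder path
        with safe_open(shard, framework="pt") as f:
            for key in f.keys():
                # Expect torch.bfloat16 when converted with --export_dtype bfloat16.
                print(shard, key, f.get_tensor(key).dtype)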