diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py index 9d8751e5d6..616a74a9c4 100644 --- a/mlc_llm/utils.py +++ b/mlc_llm/utils.py @@ -1,10 +1,13 @@ # pylint: disable=missing-docstring,invalid-name import argparse import json +import math import os import shutil from typing import Any, Dict, List, Optional, Set +import numpy as np + import tvm from tvm import relax @@ -269,11 +272,12 @@ def save_params(params: List[tvm.nd.NDArray], artifact_path: str) -> None: meta_data["ParamSize"] = len(params) total_size = 0.0 for i, nd in enumerate(params): + assert nd is not None, f"Missing parameter at index {i}" param_dict[f"param_{i}"] = nd - np_nd = nd.numpy() - total_size += np_nd.size * np_nd.dtype.itemsize - total_size = total_size / 1024.0 / 1024.0 / 1024.0 - print(f"Total param size: {total_size} GB") + + total_size_bytes = sum(math.prod(param.shape) * np.dtype(param.dtype).itemsize for param in params) + total_size_gb = total_size_bytes / (1024 ** 3) + print(f"Total param size: {total_size_gb} GB") tvmjs.dump_ndarray_cache( param_dict, f"{artifact_path}/params", meta_data=meta_data, encode_format="raw" )