Skip to content

Commit 0a9d6c7

Browse files
authored
[Utils] Remove conversion to numpy array in utils.save_params (mlc-ai#1083)
Prior to this commit, each parameter was converted to a numpy-owned array as part of a total size computation. This commit computes the size directly, removing the conversion.
1 parent 3cf5605 commit 0a9d6c7

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

mlc_llm/utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22
import argparse
33
import functools
44
import json
5+
import math
56
import os
67
import shutil
78
from typing import Any, Dict, List, Optional, Set
89

10+
import numpy as np
11+
912
import tvm
1013
from tvm import relax
1114

@@ -283,11 +286,12 @@ def save_params(params: List[tvm.nd.NDArray], artifact_path: str) -> None:
283286
meta_data["ParamSize"] = len(params)
284287
total_size = 0.0
285288
for i, nd in enumerate(params):
289+
assert nd is not None, f"Missing parameter at index {i}"
286290
param_dict[f"param_{i}"] = nd
287-
np_nd = nd.numpy()
288-
total_size += np_nd.size * np_nd.dtype.itemsize
289-
total_size = total_size / 1024.0 / 1024.0 / 1024.0
290-
print(f"Total param size: {total_size} GB")
291+
292+
total_size_bytes = sum(math.prod(param.shape) * np.dtype(param.dtype).itemsize for param in params)
293+
total_size_gb = total_size_bytes / (1024 ** 3)
294+
print(f"Total param size: {total_size_gb} GB")
291295
tvmjs.dump_ndarray_cache(
292296
param_dict, f"{artifact_path}/params", meta_data=meta_data, encode_format="raw"
293297
)

0 commit comments

Comments
 (0)