42 changes: 33 additions & 9 deletions mlc_llm/core.py
@@ -461,6 +461,17 @@ def validate_config(model_path: str):
), f"Model type {config['model_type']} not supported."


def get_cuda_sm_version():
major, minor = parse_compute_version(tvm.cuda(0).compute_version)

if major == 8:
sm = 80
else:
sm = 10 * major + minor

return sm


def mod_transform_before_build(
mod: tvm.IRModule,
param_manager: param_manager.ParamManager,
@@ -550,13 +561,7 @@ def mod_transform_before_build(
if len(patterns) > 0:
os.makedirs("./tmp", exist_ok=True)

major, minor = parse_compute_version(tvm.cuda(0).compute_version)

if major == 8:
sm = 80
else:
sm = 10 * major + minor

sm = get_cuda_sm_version()
options = {"cutlass": {"sm": sm, "find_first_valid": False}}

if hasattr(config, "rms_norm_eps"):
@@ -698,15 +703,19 @@ def build_model_from_args(args: argparse.Namespace):
"WARNING: q4f16_1 is preferred to q4f16_0, "
"and it is highly recommended to use q4f16_1 instead"
)

use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"]

if args.num_shards > 1:
if (not args.build_model_only) and (not args.convert_weight_only):
raise ValueError(
"`num_shards` should be used together with "
"`--build-model-only` and `--convert-weight-only`"
)
use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"]

if use_ft_quant:
raise ValueError("Multi-GPU deployments are not available for ft quantization.")
args.use_presharded_weights = True

os.makedirs(args.artifact_path, exist_ok=True)
if args.debug_dump:
os.makedirs(os.path.join(args.artifact_path, "debug"), exist_ok=True)
@@ -794,6 +803,21 @@ def build_model_from_args(args: argparse.Namespace):
mod_transform = seq(mod_transform)

params = utils.convert_weights(mod_transform, param_manager, params, args)

if args.num_shards > 1 and use_ft_quant:
preprocessed = []
weight_preprocess_func = tvm.get_global_func("cutlass.ft_preprocess_weight")
is_int4 = args.quantization.name == "q4f16_ft"
sm = get_cuda_sm_version()

for p in params:
if p.dtype == "int8":
preprocessed.append(weight_preprocess_func(p, sm, is_int4))
else:
preprocessed.append(p)

params = preprocessed

utils.save_params(params, args.artifact_path, args.num_shards if args.use_presharded_weights else 1)

if args.model_category != "minigpt":
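For reference, a minimal standalone sketch of the SM lookup factored out above (a sketch, not the module itself), assuming TVM built with CUDA and a visible device 0. Collapsing every Ampere part (compute major 8) to sm 80 presumably lets one CUTLASS kernel profile cover 8.0/8.6/8.9; the value feeds the CUTLASS BYOC options used in mod_transform_before_build.

import tvm
from tvm.contrib.nvcc import parse_compute_version

def cuda_sm_version() -> int:
    # Mirror of get_cuda_sm_version: compute capability -> CUTLASS "sm" tag.
    major, minor = parse_compute_version(tvm.cuda(0).compute_version)
    return 80 if major == 8 else 10 * major + minor

options = {"cutlass": {"sm": cuda_sm_version(), "find_first_valid": False}}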
23 changes: 15 additions & 8 deletions mlc_llm/quantization/ft_rowwise_quantization.py
@@ -30,6 +30,8 @@ def __init__(self, dtype, nbit):
else:
self.sm = None

self.do_preprocess = True

def get_quantize_func(self, param_info: relax.TensorStructInfo) -> Optional[FQuantize]:
assert self.sm is not None

@@ -44,15 +46,20 @@ def f_quantize(bb: relax.BlockBuilder, inputs: List[relax.Expr]):
primfunc_name_hint="encode",
)
packed_weight = bb.normalize(encoded_data[0])
encoded_weight = bb.emit(
relax.call_pure_packed(
"cutlass.ft_preprocess_weight",
packed_weight,
self.sm,
self.nbit == 4,
sinfo_args=packed_weight.struct_info,

if self.do_preprocess:
encoded_weight = bb.emit(
relax.call_pure_packed(
"cutlass.ft_preprocess_weight",
packed_weight,
self.sm,
self.nbit == 4,
sinfo_args=packed_weight.struct_info,
)
)
)
else:
encoded_weight = packed_weight

return bb.emit(relax.Tuple([encoded_weight, encoded_data[1]]))

return f_quantize
131 changes: 125 additions & 6 deletions mlc_llm/relax_model/commons.py
@@ -88,6 +88,112 @@ def shard_gate_up_weight_scale(weight: relax.TensorStructInfo):
}


def _get_shard_strategies_ft(
model_config, num_shards: int, param_shape_is_already_sharded: bool
) -> Dict[str, tvm.tir.PrimFunc]:
q_heads = model_config.num_attention_heads
kv_heads = model_config.get_num_key_value_heads()

def shard_qkv_weight(weight: relax.TensorStructInfo):
(red, spatial), dtype = weight.shape, weight.dtype
red, spatial = int(red), int(spatial)
if param_shape_is_already_sharded:
spatial *= num_shards
head_dim = spatial // (q_heads + 2 * kv_heads)
a = te.placeholder((red, spatial), dtype=dtype)
w = topi.reshape(a, (red, spatial // head_dim, head_dim))
q = te.compute((red, q_heads, head_dim), lambda i, j, k: w[i, j, k])
k = te.compute((red, kv_heads, head_dim), lambda i, j, k: w[i, q_heads + j, k])
v = te.compute((red, kv_heads, head_dim), lambda i, j, k: w[i, q_heads + kv_heads + j, k])
q = topi.reshape(q, (red, num_shards, q_heads // num_shards, head_dim))
k = topi.reshape(k, (red, num_shards, kv_heads // num_shards, head_dim))
v = topi.reshape(v, (red, num_shards, kv_heads // num_shards, head_dim))
w = topi.concatenate((q, k, v), axis=2)
w = topi.reshape(w, (red, num_shards, (q_heads + kv_heads * 2) // num_shards * head_dim))
w = topi.transpose(w, (1, 0, 2))
func = te.create_prim_func([a, w])
return func

def shard_qkv_scale(scale: relax.TensorStructInfo):
(spatial,), dtype = scale.shape, scale.dtype
spatial = int(spatial)
if param_shape_is_already_sharded:
spatial *= num_shards
head_dim = spatial // (q_heads + 2 * kv_heads)
a = te.placeholder((spatial,), dtype=dtype)
w = topi.reshape(a, (spatial // head_dim, head_dim))
q = te.compute((q_heads, head_dim), lambda i, j: w[i, j])
k = te.compute((kv_heads, head_dim), lambda i, j: w[q_heads + i, j])
v = te.compute((kv_heads, head_dim), lambda i, j: w[q_heads + kv_heads + i, j])
q = topi.reshape(q, (num_shards, q_heads // num_shards, head_dim))
k = topi.reshape(k, (num_shards, kv_heads // num_shards, head_dim))
v = topi.reshape(v, (num_shards, kv_heads // num_shards, head_dim))
w = topi.concatenate((q, k, v), axis=1)
w = topi.reshape(w, (num_shards, (q_heads + kv_heads * 2) // num_shards * head_dim))
func = te.create_prim_func([a, w])
return func

def shard_qkv_weight_scale(x: relax.TensorStructInfo):
if x.ndim == 2:
return shard_qkv_weight(x)
else:
return shard_qkv_scale(x)

def shard_k_weight(weight: relax.TensorStructInfo):
(red, spatial), dtype = weight.shape, weight.dtype
red, spatial = int(red), int(spatial)
if param_shape_is_already_sharded:
red *= num_shards
a = te.placeholder((red, spatial), dtype=dtype)
w = topi.reshape(a, (num_shards, red // num_shards, spatial))
func = te.create_prim_func([a, w])
return func

def shard_gate_up_weight(weight: relax.TensorStructInfo):
(red, spatial), dtype = weight.shape, weight.dtype
red, spatial = int(red), int(spatial)
if param_shape_is_already_sharded:
spatial *= num_shards
a = te.placeholder((red, spatial), dtype=dtype)
g = te.compute((red, spatial // 2), lambda i, j: a[i, j])
u = te.compute((red, spatial // 2), lambda i, j: a[i, spatial // 2 + j])
g = topi.reshape(g, (red, num_shards, spatial // 2 // num_shards))
u = topi.reshape(u, (red, num_shards, spatial // 2 // num_shards))
w = topi.concatenate((g, u), axis=2)
w = topi.reshape(w, (red, num_shards, spatial // num_shards))
w = topi.transpose(w, (1, 0, 2))
func = te.create_prim_func([a, w])
return func

def shard_gate_up_scale(weight: relax.TensorStructInfo):
(spatial,), dtype = weight.shape, weight.dtype
spatial = int(spatial)
if param_shape_is_already_sharded:
spatial *= num_shards
a = te.placeholder((spatial,), dtype=dtype)
g = te.compute((spatial // 2,), lambda i: a[i])
u = te.compute((spatial // 2,), lambda i: a[spatial // 2 + i])
g = topi.reshape(g, (num_shards, spatial // 2 // num_shards))
u = topi.reshape(u, (num_shards, spatial // 2 // num_shards))
w = topi.concatenate((g, u), axis=1)
w = topi.reshape(w, (num_shards, spatial // num_shards))
func = te.create_prim_func([a, w])
return func

def shard_gate_up_weight_scale(x: relax.TensorStructInfo):
if x.ndim == 2:
return shard_gate_up_weight(x)
else:
return shard_gate_up_scale(x)

return {
"shard_qkv": shard_qkv_weight_scale,
"shard_mlp_k": shard_k_weight,
"shard_o_proj_k": shard_k_weight,
"shard_gate_up": shard_gate_up_weight_scale,
}


def create_shard_info_func(param_manager, args, model_config) -> tvm.IRModule:
shard_strategy_to_func = _get_shard_strategies(
model_config,
@@ -140,11 +246,20 @@ def add_to_shard_info(param_name: str, func_name: Optional[str]):


def create_shard_transformation_func(param_manager, args, model_config) -> tvm.IRModule:
shard_strategy_to_func = _get_shard_strategies(
model_config,
num_shards=args.num_shards,
param_shape_is_already_sharded=args.build_model_only,
)
use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"]

if use_ft_quant:
shard_strategy_to_func = _get_shard_strategies_ft(
model_config,
num_shards=args.num_shards,
param_shape_is_already_sharded=args.build_model_only,
)
else:
shard_strategy_to_func = _get_shard_strategies(
model_config,
num_shards=args.num_shards,
param_shape_is_already_sharded=args.build_model_only,
)

q_params = param_manager.get_quantized_param_info("prefill").fields

@@ -189,7 +304,11 @@ def create_shard_transformation_func(param_manager, args, model_config) -> tvm.I

arg = relax.Var(arg_name, qparam_sinfo)

if param.shard_strategy is None:
if param.shard_strategy is None or (
use_ft_quant
and param.shard_strategy in ["shard_mlp_k", "shard_o_proj_k"]
and len(qparam_sinfo.shape) == 1
):
sharded = arg
else:
strategy_func = shard_strategy_to_func[param.shard_strategy](
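To make the new FT QKV strategy easier to follow, here is a NumPy sketch of the same permutation (a reference only, assuming q_heads and kv_heads divide evenly by num_shards): the fused (red, spatial) QKV weight is regrouped so that each shard receives a contiguous block of its own query, key, and value heads, matching shard_qkv_weight above.

import numpy as np

def shard_qkv_reference(w, q_heads, kv_heads, num_shards):
    # (red, spatial) -> (num_shards, red, spatial // num_shards), per-shard q/k/v blocks.
    red, spatial = w.shape
    head_dim = spatial // (q_heads + 2 * kv_heads)
    w = w.reshape(red, q_heads + 2 * kv_heads, head_dim)
    q, k, v = np.split(w, [q_heads, q_heads + kv_heads], axis=1)
    q = q.reshape(red, num_shards, -1, head_dim)
    k = k.reshape(red, num_shards, -1, head_dim)
    v = v.reshape(red, num_shards, -1, head_dim)
    out = np.concatenate([q, k, v], axis=2).reshape(red, num_shards, spatial // num_shards)
    return out.transpose(1, 0, 2)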
7 changes: 7 additions & 0 deletions mlc_llm/utils.py
@@ -119,8 +119,15 @@ def argparse_postproc_common(args: argparse.Namespace) -> None:

if args.quantization not in quantization_schemes:
raise ValueError(f'Quantization "{args.quantization}" is not supported.')

use_ft_quant = args.quantization in ["q4f16_ft", "q8f16_ft"]
args.quantization = quantization_schemes[args.quantization]

if use_ft_quant and args.num_shards > 1:
# Preprocess is done after sharding for this case.
args.quantization.linear_weight.do_preprocess = False
args.quantization.final_fc_weight.do_preprocess = False


def debug_dump_script(mod, name, args: argparse.Namespace, show_meta=True):
"""Debug dump mode"""
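Taken together with the do_preprocess flag added in ft_rowwise_quantization.py and the loop added in core.py, the intent appears to be: when sharding FT-quantized weights, quantization emits the packed but still unpermuted int8 weight, the shards are cut, and only then is the CUTLASS layout permutation applied per shard. A rough sketch of that final step, with a hypothetical helper name (not an API added by this PR):

import tvm

def preprocess_after_sharding(params, sm, is_int4):
    # Mirrors the loop in build_model_from_args: only int8 packed weights
    # go through the FasterTransformer layout permutation; scales pass through.
    preprocess = tvm.get_global_func("cutlass.ft_preprocess_weight")
    return [preprocess(p, sm, is_int4) if p.dtype == "int8" else p for p in params]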
16 changes: 12 additions & 4 deletions serve/mlc_serve/model/paged_cache_model.py
@@ -303,7 +303,7 @@ def sample(logits, sampling_params, vocab_size):


def load_disco_module(artifact_path, lib_path, num_shards):
sess = di.ProcessSession(num_workers=num_shards)
sess = di.ThreadedSession(num_workers=num_shards)
devices = range(num_shards)
sess.init_ccl("nccl", *devices)
module = sess.load_vm_module(lib_path)
@@ -314,7 +314,7 @@ def load_disco_module(artifact_path, lib_path, num_shards):
ndarray_cache_metadata = f.read()

loader = loader_create(metadata_path, ndarray_cache_metadata, "", module)
loader_load = sess.get_global_func("runtime.disco.ShardLoaderLoadAll")
loader_load = sess.get_global_func("runtime.disco.ShardLoaderLoadAllPresharded")
params = loader_load(loader)

return module, params, sess
@@ -327,7 +327,13 @@ def copy_to_worker_0(sess: di.Session, host_array):


def get_tvm_model(artifact_path, model, quantization, num_shards, dev):
model_artifact_path = os.path.join(artifact_path, f"{model}-{quantization}")
if num_shards > 1:
model_artifact_path = os.path.join(
artifact_path, f"{model}-{quantization}-presharded-{num_shards}gpu"
)
else:
model_artifact_path = os.path.join(artifact_path, f"{model}-{quantization}")

lib_path = os.path.join(model_artifact_path, f"{model}-{quantization}-cuda.so")

if num_shards == 1:
@@ -470,7 +476,9 @@ def get_used_memory(self):
)
peak_memory = get_used_memory_func(self.dev)

param_bytes = sum(math.prod(param.shape) * np.dtype(param.dtype).itemsize for param in params)
param_bytes = sum(
math.prod(param.shape) * np.dtype(param.dtype).itemsize for param in params
)

return peak_memory + param_bytes
