From 09533b0362627d91d3b047f44b9e2f5e80e35c07 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Tue, 7 Nov 2023 00:49:36 +0000
Subject: [PATCH 1/7] wip

---
 mlc_llm/core.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlc_llm/core.py b/mlc_llm/core.py
index 5920a37eb1..f5e9e32bb3 100644
--- a/mlc_llm/core.py
+++ b/mlc_llm/core.py
@@ -705,8 +705,9 @@ def build_model_from_args(args: argparse.Namespace):
                 "`--build-model-only` and `--convert-weight-only`"
             )
         use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"]
-        if use_ft_quant:
-            raise ValueError("Multi-GPU deployments are not available for ft quantization.")
+        if use_ft_quant and not args.use_presharded_weights:
+            raise ValueError("Multi-GPU deployments with FT quantization requires --use-presharded-weights.")
+
     os.makedirs(args.artifact_path, exist_ok=True)
     if args.debug_dump:
         os.makedirs(os.path.join(args.artifact_path, "debug"), exist_ok=True)

From 6285f1a9be384e68a3f97b6d480b31683937b20c Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Tue, 7 Nov 2023 05:40:00 +0000
Subject: [PATCH 2/7] works with ThreadedSession

---
 serve/mlc_serve/model/paged_cache_model.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py
index 4807429bf0..c5b5b81006 100644
--- a/serve/mlc_serve/model/paged_cache_model.py
+++ b/serve/mlc_serve/model/paged_cache_model.py
@@ -303,7 +303,7 @@ def sample(logits, sampling_params, vocab_size):
 
 
 def load_disco_module(artifact_path, lib_path, num_shards):
-    sess = di.ProcessSession(num_workers=num_shards)
+    sess = di.ThreadedSession(num_workers=num_shards)
     devices = range(num_shards)
     sess.init_ccl("nccl", *devices)
     module = sess.load_vm_module(lib_path)
@@ -314,7 +314,7 @@ def load_disco_module(artifact_path, lib_path, num_shards):
         ndarray_cache_metadata = f.read()
 
     loader = loader_create(metadata_path, ndarray_cache_metadata, "", module)
-    loader_load = sess.get_global_func("runtime.disco.ShardLoaderLoadAll")
+    loader_load = sess.get_global_func("runtime.disco.ShardLoaderLoadAllPresharded")
     params = loader_load(loader)
 
     return module, params, sess
@@ -327,7 +327,13 @@ def copy_to_worker_0(sess: di.Session, host_array):
 
 
 def get_tvm_model(artifact_path, model, quantization, num_shards, dev):
-    model_artifact_path = os.path.join(artifact_path, f"{model}-{quantization}")
+    if num_shards > 1:
+        model_artifact_path = os.path.join(
+            artifact_path, f"{model}-{quantization}-presharded-{num_shards}gpu"
+        )
+    else:
+        model_artifact_path = os.path.join(artifact_path, f"{model}-{quantization}")
+
     lib_path = os.path.join(model_artifact_path, f"{model}-{quantization}-cuda.so")
 
     if num_shards == 1:
@@ -470,7 +476,9 @@ def get_used_memory(self):
         )
 
         peak_memory = get_used_memory_func(self.dev)
-        param_bytes = sum(math.prod(param.shape) * np.dtype(param.dtype).itemsize for param in params)
+        param_bytes = sum(
+            math.prod(param.shape) * np.dtype(param.dtype).itemsize for param in params
+        )
 
         return peak_memory + param_bytes
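
With the path logic added in the patch above, a multi-GPU server resolves its artifacts
from a presharded directory and loads them with the presharded shard loader. As a worked
example only (the model name is hypothetical and the artifact root depends on the build
setup), a two-shard q4f16_ft build of "Llama-2-7b-chat-hf" would be looked up as

    dist/Llama-2-7b-chat-hf-q4f16_ft-presharded-2gpu/Llama-2-7b-chat-hf-q4f16_ft-cuda.so

for the library, with the per-shard weights in the same directory loaded through
runtime.disco.ShardLoaderLoadAllPresharded; single-GPU builds keep the old
{model}-{quantization} layout.
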
From 89dfe4db7c8b597e3e6d74bfad5f89117a834433 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Tue, 7 Nov 2023 09:28:45 +0000
Subject: [PATCH 3/7] wip

---
 mlc_llm/core.py                             |   9 +-
 .../quantization/ft_rowwise_quantization.py |  19 +--
 mlc_llm/relax_model/commons.py              | 131 +++++++++++++++++-
 3 files changed, 143 insertions(+), 16 deletions(-)

diff --git a/mlc_llm/core.py b/mlc_llm/core.py
index f5e9e32bb3..56d6b03d86 100644
--- a/mlc_llm/core.py
+++ b/mlc_llm/core.py
@@ -698,13 +698,15 @@ def build_model_from_args(args: argparse.Namespace):
             "WARNING: q4f16_1 is preferred to q4f16_0, "
             "and it is highly recommended to use q4f16_1 instead"
         )
+
+    use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"]
+
     if args.num_shards > 1:
         if (not args.build_model_only) and (not args.convert_weight_only):
             raise ValueError(
                 "`num_shards` should be used together with "
                 "`--build-model-only` and `--convert-weight-only`"
             )
-        use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"]
         if use_ft_quant and not args.use_presharded_weights:
             raise ValueError("Multi-GPU deployments with FT quantization requires --use-presharded-weights.")
 
@@ -795,6 +797,11 @@ def build_model_from_args(args: argparse.Namespace):
     mod_transform = seq(mod_transform)
 
     params = utils.convert_weights(mod_transform, param_manager, params, args)
+
+    if args.num_shards > 1 and use_ft_quant:
+        for p in params:
+            print(p.shape, p.dtype)
+
     utils.save_params(params, args.artifact_path, args.num_shards if args.use_presharded_weights else 1)
 
     if args.model_category != "minigpt":
diff --git a/mlc_llm/quantization/ft_rowwise_quantization.py b/mlc_llm/quantization/ft_rowwise_quantization.py
index a34b52bb82..d9d3d6b7d8 100644
--- a/mlc_llm/quantization/ft_rowwise_quantization.py
+++ b/mlc_llm/quantization/ft_rowwise_quantization.py
@@ -44,15 +44,16 @@ def f_quantize(bb: relax.BlockBuilder, inputs: List[relax.Expr]):
                 primfunc_name_hint="encode",
             )
             packed_weight = bb.normalize(encoded_data[0])
-            encoded_weight = bb.emit(
-                relax.call_pure_packed(
-                    "cutlass.ft_preprocess_weight",
-                    packed_weight,
-                    self.sm,
-                    self.nbit == 4,
-                    sinfo_args=packed_weight.struct_info,
-                )
-            )
+            # encoded_weight = bb.emit(
+            #     relax.call_pure_packed(
+            #         "cutlass.ft_preprocess_weight",
+            #         packed_weight,
+            #         self.sm,
+            #         self.nbit == 4,
+            #         sinfo_args=packed_weight.struct_info,
+            #     )
+            # )
+            encoded_weight = packed_weight
             return bb.emit(relax.Tuple([encoded_weight, encoded_data[1]]))
 
         return f_quantize
diff --git a/mlc_llm/relax_model/commons.py b/mlc_llm/relax_model/commons.py
index b00dcb0446..0f2f41d93b 100644
--- a/mlc_llm/relax_model/commons.py
+++ b/mlc_llm/relax_model/commons.py
@@ -88,6 +88,112 @@ def shard_gate_up_weight_scale(weight: relax.TensorStructInfo):
     }
 
 
+def _get_shard_strategies_ft(
+    model_config, num_shards: int, param_shape_is_already_sharded: bool
+) -> Dict[str, tvm.tir.PrimFunc]:
+    q_heads = model_config.num_attention_heads
+    kv_heads = model_config.get_num_key_value_heads()
+
+    def shard_qkv_weight(weight: relax.TensorStructInfo):
+        (red, spatial), dtype = weight.shape, weight.dtype
+        red, spatial = int(red), int(spatial)
+        if param_shape_is_already_sharded:
+            spatial *= num_shards
+        head_dim = spatial // (q_heads + 2 * kv_heads)
+        a = te.placeholder((red, spatial), dtype=dtype)
+        w = topi.reshape(a, (red, spatial // head_dim, head_dim))
+        q = te.compute((red, q_heads, head_dim), lambda i, j, k: w[i, j, k])
+        k = te.compute((red, kv_heads, head_dim), lambda i, j, k: w[i, q_heads + j, k])
+        v = te.compute((red, kv_heads, head_dim), lambda i, j, k: w[i, q_heads + kv_heads + j, k])
+        q = topi.reshape(q, (red, num_shards, q_heads // num_shards, head_dim))
+        k = topi.reshape(k, (red, num_shards, kv_heads // num_shards, head_dim))
+        v = topi.reshape(v, (red, num_shards, kv_heads // num_shards, head_dim))
+        w = topi.concatenate((q, k, v), axis=2)
+        w = topi.reshape(w, (red, num_shards, (q_heads + kv_heads * 2) // num_shards * head_dim))
+        w = topi.transpose(w, (1, 0, 2))
+        func = te.create_prim_func([a, w])
+        return func
+
+    def shard_qkv_scale(scale: relax.TensorStructInfo):
+        (spatial,), dtype = scale.shape, scale.dtype
+        spatial = int(spatial)
+        if param_shape_is_already_sharded:
+            spatial *= num_shards
+        head_dim = spatial // (q_heads + 2 * kv_heads)
+        a = te.placeholder((spatial,), dtype=dtype)
+        w = topi.reshape(a, (spatial // head_dim, head_dim))
+        q = te.compute((q_heads, head_dim), lambda i, j: w[i, j])
+        k = te.compute((kv_heads, head_dim), lambda i, j: w[q_heads + i, j])
+        v = te.compute((kv_heads, head_dim), lambda i, j: w[q_heads + kv_heads + i, j])
+        q = topi.reshape(q, (num_shards, q_heads // num_shards, head_dim))
+        k = topi.reshape(k, (num_shards, kv_heads // num_shards, head_dim))
+        v = topi.reshape(v, (num_shards, kv_heads // num_shards, head_dim))
+        w = topi.concatenate((q, k, v), axis=1)
+        w = topi.reshape(w, (num_shards, (q_heads + kv_heads * 2) // num_shards * head_dim))
+        func = te.create_prim_func([a, w])
+        return func
+
+    def shard_qkv_weight_scale(x: relax.TensorStructInfo):
+        if x.ndim == 2:
+            return shard_qkv_weight(x)
+        else:
+            return shard_qkv_scale(x)
+
+    def shard_k_weight(weight: relax.TensorStructInfo):
+        (red, spatial), dtype = weight.shape, weight.dtype
+        red, spatial = int(red), int(spatial)
+        if param_shape_is_already_sharded:
+            red *= num_shards
+        a = te.placeholder((red, spatial), dtype=dtype)
+        w = topi.reshape(a, (num_shards, red // num_shards, spatial))
+        func = te.create_prim_func([a, w])
+        return func
+
+    def shard_gate_up_weight(weight: relax.TensorStructInfo):
+        (red, spatial), dtype = weight.shape, weight.dtype
+        red, spatial = int(red), int(spatial)
+        if param_shape_is_already_sharded:
+            spatial *= num_shards
+        a = te.placeholder((red, spatial), dtype=dtype)
+        g = te.compute((red, spatial // 2), lambda i, j: a[i, j])
+        u = te.compute((red, spatial // 2), lambda i, j: a[i, spatial // 2 + j])
+        g = topi.reshape(g, (red, num_shards, spatial // 2 // num_shards))
+        u = topi.reshape(u, (red, num_shards, spatial // 2 // num_shards))
+        w = topi.concatenate((g, u), axis=2)
+        w = topi.reshape(w, (red, num_shards, spatial // num_shards))
+        w = topi.transpose(w, (1, 0, 2))
+        func = te.create_prim_func([a, w])
+        return func
+
+    def shard_gate_up_scale(weight: relax.TensorStructInfo):
+        (spatial,), dtype = weight.shape, weight.dtype
+        spatial = int(spatial)
+        if param_shape_is_already_sharded:
+            spatial *= num_shards
+        a = te.placeholder((spatial,), dtype=dtype)
+        g = te.compute((spatial // 2,), lambda i: a[i])
+        u = te.compute((spatial // 2,), lambda i: a[spatial // 2 + i])
+        g = topi.reshape(g, (num_shards, spatial // 2 // num_shards))
+        u = topi.reshape(u, (num_shards, spatial // 2 // num_shards))
+        w = topi.concatenate((g, u), axis=1)
+        w = topi.reshape(w, (num_shards, spatial // num_shards))
+        func = te.create_prim_func([a, w])
+        return func
+
+    def shard_gate_up_weight_scale(x: relax.TensorStructInfo):
+        if x.ndim == 2:
+            return shard_gate_up_weight(x)
+        else:
+            return shard_gate_up_scale(x)
+
+    return {
+        "shard_qkv": shard_qkv_weight_scale,
+        "shard_mlp_k": shard_k_weight,
+        "shard_o_proj_k": shard_k_weight,
+        "shard_gate_up": shard_gate_up_weight_scale,
+    }
+
+
 def create_shard_info_func(param_manager, args, model_config) -> tvm.IRModule:
     shard_strategy_to_func = _get_shard_strategies(
         model_config,
@@ -140,11 +246,20 @@ def add_to_shard_info(param_name: str, func_name: Optional[str]):
 
 
 def create_shard_transformation_func(param_manager, args, model_config) -> tvm.IRModule:
-    shard_strategy_to_func = _get_shard_strategies(
-        model_config,
-        num_shards=args.num_shards,
-        param_shape_is_already_sharded=args.build_model_only,
-    )
+    use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"]
+
+    if use_ft_quant:
+        shard_strategy_to_func = _get_shard_strategies_ft(
+            model_config,
+            num_shards=args.num_shards,
+            param_shape_is_already_sharded=args.build_model_only,
+        )
+    else:
+        shard_strategy_to_func = _get_shard_strategies(
+            model_config,
+            num_shards=args.num_shards,
+            param_shape_is_already_sharded=args.build_model_only,
+        )
 
     q_params = param_manager.get_quantized_param_info("prefill").fields
 
@@ -189,7 +304,11 @@ def create_shard_transformation_func(param_manager, args, model_config) -> tvm.I
 
             arg = relax.Var(arg_name, qparam_sinfo)
 
-            if param.shard_strategy is None:
+            if param.shard_strategy is None or (
+                use_ft_quant
+                and param.shard_strategy in ["shard_mlp_k", "shard_o_proj_k"]
+                and len(qparam_sinfo.shape) == 1
+            ):
                 sharded = arg
             else:
                 strategy_func = shard_strategy_to_func[param.shard_strategy](
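
The TE functions added above are easier to follow as plain array code. Below is a NumPy
sketch of what shard_qkv_weight computes (illustration only, not part of the patch; the
sizes are made-up example values, and FT stores linear weights as (reduction, spatial)):

import numpy as np

red, q_heads, kv_heads, head_dim, num_shards = 4096, 32, 32, 128, 2
spatial = (q_heads + 2 * kv_heads) * head_dim                 # 12288
a = np.arange(red * spatial, dtype=np.float32).reshape(red, spatial)

w = a.reshape(red, spatial // head_dim, head_dim)             # one column block per head
q = w[:, :q_heads]
k = w[:, q_heads:q_heads + kv_heads]
v = w[:, q_heads + kv_heads:]
# regroup Q/K/V so that the heads belonging to one shard become contiguous
q = q.reshape(red, num_shards, q_heads // num_shards, head_dim)
k = k.reshape(red, num_shards, kv_heads // num_shards, head_dim)
v = v.reshape(red, num_shards, kv_heads // num_shards, head_dim)
w = np.concatenate((q, k, v), axis=2)                         # per-shard [Q | K | V] block
w = w.reshape(red, num_shards, (q_heads + 2 * kv_heads) // num_shards * head_dim)
w = w.transpose(1, 0, 2)                                      # shard axis first
assert w.shape == (num_shards, red, spatial // num_shards)    # (2, 4096, 6144)

shard_qkv_scale applies the same regrouping to a 1-D (spatial,) scale, and shard_k_weight
simply splits the reduction axis, so a row-parallel weight's per-output-channel scale stays
whole; that is presumably why the create_shard_transformation_func change above leaves 1-D
shard_mlp_k / shard_o_proj_k parameters unsharded.
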
From 8149ea2f115aba5cb3f0f18cb20e71f3a1fb5c68 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Tue, 7 Nov 2023 09:36:33 +0000
Subject: [PATCH 4/7] wip

---
 mlc_llm/core.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/mlc_llm/core.py b/mlc_llm/core.py
index 56d6b03d86..83c668494e 100644
--- a/mlc_llm/core.py
+++ b/mlc_llm/core.py
@@ -799,8 +799,15 @@ def build_model_from_args(args: argparse.Namespace):
     params = utils.convert_weights(mod_transform, param_manager, params, args)
 
     if args.num_shards > 1 and use_ft_quant:
+        preprocessed = []
+        weight_preprocess_func = tvm.get_global_func("cutlass.ft_preprocess_weight")
         for p in params:
-            print(p.shape, p.dtype)
+            if p.dtype == "int8":
+                preprocessed.append(weight_preprocess_func(p, 80, True))
+            else:
+                preprocessed.append(p)
+
+        params = preprocessed
 
     utils.save_params(params, args.artifact_path, args.num_shards if args.use_presharded_weights else 1)
 

From 1330ac3f01e8b87e5eb7df4e29badd6d5bc46eb0 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Tue, 7 Nov 2023 19:55:40 +0000
Subject: [PATCH 5/7] automatically set use-presharded-weights for FT

---
 mlc_llm/core.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlc_llm/core.py b/mlc_llm/core.py
index 83c668494e..32d707e131 100644
--- a/mlc_llm/core.py
+++ b/mlc_llm/core.py
@@ -707,8 +707,9 @@ def build_model_from_args(args: argparse.Namespace):
                 "`num_shards` should be used together with "
                 "`--build-model-only` and `--convert-weight-only`"
             )
-        if use_ft_quant and not args.use_presharded_weights:
-            raise ValueError("Multi-GPU deployments with FT quantization requires --use-presharded-weights.")
+
+        if use_ft_quant:
+            args.use_presharded_weights = True
 
     os.makedirs(args.artifact_path, exist_ok=True)
     if args.debug_dump:
From 7ed660e4eaec43f190ba231fc792264b56c3483c Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Tue, 7 Nov 2023 20:17:34 +0000
Subject: [PATCH 6/7] conditionally apply preprocessing

---
 .../quantization/ft_rowwise_quantization.py | 26 ++++++++++++-------
 mlc_llm/utils.py                            |  7 +++++
 2 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/mlc_llm/quantization/ft_rowwise_quantization.py b/mlc_llm/quantization/ft_rowwise_quantization.py
index d9d3d6b7d8..e5ba8a0c5d 100644
--- a/mlc_llm/quantization/ft_rowwise_quantization.py
+++ b/mlc_llm/quantization/ft_rowwise_quantization.py
@@ -30,6 +30,8 @@ def __init__(self, dtype, nbit):
         else:
             self.sm = None
 
+        self.do_preprocess = True
+
     def get_quantize_func(self, param_info: relax.TensorStructInfo) -> Optional[FQuantize]:
         assert self.sm is not None
 
@@ -44,16 +46,20 @@ def f_quantize(bb: relax.BlockBuilder, inputs: List[relax.Expr]):
                 primfunc_name_hint="encode",
            )
             packed_weight = bb.normalize(encoded_data[0])
-            # encoded_weight = bb.emit(
-            #     relax.call_pure_packed(
-            #         "cutlass.ft_preprocess_weight",
-            #         packed_weight,
-            #         self.sm,
-            #         self.nbit == 4,
-            #         sinfo_args=packed_weight.struct_info,
-            #     )
-            # )
-            encoded_weight = packed_weight
+
+            if self.do_preprocess:
+                encoded_weight = bb.emit(
+                    relax.call_pure_packed(
+                        "cutlass.ft_preprocess_weight",
+                        packed_weight,
+                        self.sm,
+                        self.nbit == 4,
+                        sinfo_args=packed_weight.struct_info,
+                    )
+                )
+            else:
+                encoded_weight = packed_weight
+
             return bb.emit(relax.Tuple([encoded_weight, encoded_data[1]]))
 
         return f_quantize
diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py
index 5246b95c3d..f6858922cd 100644
--- a/mlc_llm/utils.py
+++ b/mlc_llm/utils.py
@@ -119,8 +119,15 @@ def argparse_postproc_common(args: argparse.Namespace) -> None:
     if args.quantization not in quantization_schemes:
         raise ValueError(f'Quantization "{args.quantization}" is not supported.')
+
+    use_ft_quant = args.quantization in ["q4f16_ft", "q8f16_ft"]
     args.quantization = quantization_schemes[args.quantization]
 
+    if use_ft_quant and args.num_shards > 1:
+        # Preprocess is done after sharding for this case.
+        args.quantization.linear_weight.do_preprocess = False
+        args.quantization.final_fc_weight.do_preprocess = False
+
 
 def debug_dump_script(mod, name, args: argparse.Namespace, show_meta=True):
     """Debug dump mode"""
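
With do_preprocess switched off for multi-GPU FT builds, the CUTLASS weight preprocessing
that normally happens inside the encode function is deferred until after the weights have
been sharded, where the loop added to build_model_from_args in patch 4 (and fixed up in the
next patch) applies it per shard. A condensed, stand-alone sketch of the resulting decision
(illustration only; this helper is not part of mlc_llm):

def preprocess_stage(quantization: str, num_shards: int) -> str:
    use_ft_quant = quantization in ("q4f16_ft", "q8f16_ft")
    if use_ft_quant and num_shards > 1:
        # encode() now emits the raw packed weight; cutlass.ft_preprocess_weight
        # runs later, on the already-sharded int8 params.
        return "after sharding, in build_model_from_args"
    return "at quantization time, inside encode()"

assert preprocess_stage("q4f16_ft", 2) == "after sharding, in build_model_from_args"
assert preprocess_stage("q8f16_ft", 1) == "at quantization time, inside encode()"
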
From 1c65fa6553b0ad33f336afe537742dccd805623e Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Tue, 7 Nov 2023 20:24:21 +0000
Subject: [PATCH 7/7] fix

---
 mlc_llm/core.py | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/mlc_llm/core.py b/mlc_llm/core.py
index 32d707e131..3a7dd9a1ca 100644
--- a/mlc_llm/core.py
+++ b/mlc_llm/core.py
@@ -461,6 +461,17 @@ def validate_config(model_path: str):
     ), f"Model type {config['model_type']} not supported."
 
 
+def get_cuda_sm_version():
+    major, minor = parse_compute_version(tvm.cuda(0).compute_version)
+
+    if major == 8:
+        sm = 80
+    else:
+        sm = 10 * major + minor
+
+    return sm
+
+
 def mod_transform_before_build(
     mod: tvm.IRModule,
     param_manager: param_manager.ParamManager,
@@ -550,13 +561,7 @@ def mod_transform_before_build(
     if len(patterns) > 0:
         os.makedirs("./tmp", exist_ok=True)
 
-        major, minor = parse_compute_version(tvm.cuda(0).compute_version)
-
-        if major == 8:
-            sm = 80
-        else:
-            sm = 10 * major + minor
-
+        sm = get_cuda_sm_version()
         options = {"cutlass": {"sm": sm, "find_first_valid": False}}
 
         if hasattr(config, "rms_norm_eps"):
@@ -802,9 +807,12 @@ def build_model_from_args(args: argparse.Namespace):
     if args.num_shards > 1 and use_ft_quant:
         preprocessed = []
         weight_preprocess_func = tvm.get_global_func("cutlass.ft_preprocess_weight")
+        is_int4 = args.quantization.name == "q4f16_ft"
+        sm = get_cuda_sm_version()
+
         for p in params:
             if p.dtype == "int8":
-                preprocessed.append(weight_preprocess_func(p, 80, True))
+                preprocessed.append(weight_preprocess_func(p, sm, is_int4))
             else:
                 preprocessed.append(p)
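
Taken together, the series assumes the usual two-step multi-GPU build flow, now extended to
the FT quantization schemes: weights are converted and presharded first, then the model
library is built against the presharded layout, and CUTLASS preprocessing runs after
sharding. A rough sketch of the intended invocation (the flag spellings are inferred from
the argparse attributes used in the diffs and the model name is only an example, so the
exact command line may differ):

# 1. convert and preshard the weights (presharding is now implied for q4f16_ft / q8f16_ft)
python3 build.py --model Llama-2-7b-chat-hf --quantization q4f16_ft --num-shards 2 --convert-weight-only
# 2. build the model library for the presharded weights
python3 build.py --model Llama-2-7b-chat-hf --quantization q4f16_ft --num-shards 2 --build-model-only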