Skip to content

[Bug] Cannot build --quantization q0f32 model #987

@LeshengJin

Description

@LeshengJin

🐛 Bug

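Encountered an error while building the fp32 (`q0f32`) Llama-2 model. The build fails in the `fuse_split_rotary_embedding` pass: struct-info inference for a `matmul` rejects mismatched operand dtypes (float16 vs. float32).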

To Reproduce

python3 build.py --model /opt/models/llama-2/llama-2-7b-chat-hf/ --quantization q0f32 --use-cache 0 --no-cutlass-norm --max-seq-len 16000 --build-model-only

Additional context

Traceback (most recent call last):
  File "/opt/scratch/leshengjin/mlc-llm/build.py", line 4, in <module>
    main()
  File "/opt/scratch/leshengjin/mlc-llm/mlc_llm/build.py", line 10, in main
    core.build_model_from_args(parsed_args)
  File "/opt/scratch/leshengjin/mlc-llm/mlc_llm/core.py", line 640, in build_model_from_args
    mod = mod_transform_before_build(mod, param_manager, args, model_config)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/leshengjin/mlc-llm/mlc_llm/core.py", line 392, in mod_transform_before_build
    mod = fuse_split_rotary_embedding(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/leshengjin/mlc-llm/mlc_llm/transform/fuse_split_rotary_embedding.py", line 242, in fuse_split_rotary_embedding
    mod["decode"] = rewrite_bindings(ctx, rewriter, mod["decode"])
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/leshengjin/tvm/python/tvm/relax/dpl/rewrite.py", line 118, in rewrite_bindings
    return ffi.rewrite_bindings(ctx, rewriter, func)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/leshengjin/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/opt/scratch/leshengjin/tvm/python/tvm/_ffi/base.py", line 476, in raise_last_ffi_error
    raise py_err
  File "/opt/scratch/leshengjin/tvm/src/relax/ir/dataflow_matcher.cc", line 1083, in tvm::relax::RewriteBindings(tvm::relax::PatternContext const&, tvm::runtime::PackedFunc, tvm::relax::Function)
    return PatternRewriter::Run(ctx, rewriter, f);
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/leshengjin/tvm/src/relax/ir/dataflow_matcher.cc", line 937, in tvm::relax::Function tvm::relax::PatternRewriter::Run<tvm::relax::PatternContext>(tvm::relax::PatternContext, tvm::runtime::PackedFunc, tvm::relax::Function)
    return Downcast<Function>(RemoveAllUnused(rewriter.VisitExpr(f)));
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/leshengjin/tvm/src/relax/ir/dataflow_matcher.cc", line 951, in tvm::relax::PatternRewriter::VisitExpr(tvm::RelayExpr const&)
    auto node = ExprMutator::VisitExpr(expr);
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/leshengjin/tvm/src/relax/ir/dataflow_matcher.cc", line 951, in tvm::relax::PatternRewriter::VisitExpr(tvm::RelayExpr const&)
    auto node = ExprMutator::VisitExpr(expr);
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/leshengjin/tvm/src/relax/ir/dataflow_matcher.cc", line 979, in tvm::relax::PatternRewriter::VisitBindingBlock_(tvm::relax::DataflowBlockNode const*)
    return RewriteDataflowBlockFixedPoint(GetRef<DataflowBlock>(block_node));
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/leshengjin/tvm/src/relax/ir/dataflow_matcher.cc", line 1046, in tvm::relax::PatternRewriter::RewriteDataflowBlockFixedPoint(tvm::relax::BindingBlock)
    this->VisitBinding(binding);
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/leshengjin/tvm/src/relax/ir/dataflow_matcher.cc", line 942, in tvm::relax::PatternRewriter::VisitBinding_(tvm::relax::VarBindingNode const*)
    ExprMutator::VisitBinding_(binding);
                  ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/leshengjin/tvm/src/relax/ir/dataflow_matcher.cc", line 951, in tvm::relax::PatternRewriter::VisitExpr(tvm::RelayExpr const&)
    auto node = ExprMutator::VisitExpr(expr);
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/leshengjin/tvm/src/relax/ir/block_builder.cc", line 468, in tvm::relax::Normalizer::Normalize(tvm::RelayExpr const&)
    Expr normalized = this->VisitExpr(expr);
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/leshengjin/tvm/src/relax/ir/block_builder.cc", line 544, in tvm::relax::Normalizer::VisitExpr(tvm::RelayExpr const&)
    return ExprFunctor::VisitExpr(expr);
                  ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/leshengjin/tvm/src/relax/ir/block_builder.cc", line 599, in tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
    auto inferred_sinfo = InferStructInfo(call);
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/leshengjin/tvm/src/relax/ir/block_builder.cc", line 757, in tvm::relax::Normalizer::InferStructInfo(tvm::relax::Call const&)
    return op_map_infer_struct_info_[op](call, GetRef<BlockBuilder>(this));
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/leshengjin/tvm/src/relax/op/tensor/linear_algebra.cc", line 67, in tvm::relax::InferStructInfoMatmul(tvm::relax::Call const&, tvm::relax::BlockBuilder const&)
    ? InferBinaryArithOpOutDtype(call, ctx, x1_sinfo, x2_sinfo)
^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/leshengjin/tvm/src/relax/op/distributed/../op_common.h", line 197, in tvm::relax::InferBinaryArithOpOutDtype(tvm::relax::Call const&, tvm::relax::BlockBuilder const&, tvm::relax::TensorStructInfo const&, tvm::relax::TensorStructInfo const&)
    ctx->ReportFatal(Diagnostic::Error(call)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/scratch/leshengjin/tvm/src/relax/ir/block_builder.cc", line 138, in tvm::relax::BlockBuilderImpl::ReportFatal(tvm::Diagnostic const&)
    LOG(FATAL) << diagnostic->message;
                  ^^^^^^^^^^^^^^^^^^^^^
tvm._ffi.base.TVMError: Traceback (most recent call last):
  14: tvm::relax::RewriteBindings(tvm::relax::PatternContext const&, tvm::runtime::PackedFunc, tvm::relax::Function)
        at /opt/scratch/leshengjin/tvm/src/relax/ir/dataflow_matcher.cc:1083
  13: tvm::relax::Function tvm::relax::PatternRewriter::Run<tvm::relax::PatternContext>(tvm::relax::PatternContext, tvm::runtime::PackedFunc, tvm::relax::Function)
        at /opt/scratch/leshengjin/tvm/src/relax/ir/dataflow_matcher.cc:937
  12: tvm::relax::PatternRewriter::VisitExpr(tvm::RelayExpr const&)
        at /opt/scratch/leshengjin/tvm/src/relax/ir/dataflow_matcher.cc:951
  11: tvm::relax::PatternRewriter::VisitExpr(tvm::RelayExpr const&)
        at /opt/scratch/leshengjin/tvm/src/relax/ir/dataflow_matcher.cc:951
  10: tvm::relax::PatternRewriter::VisitBindingBlock_(tvm::relax::DataflowBlockNode const*)
        at /opt/scratch/leshengjin/tvm/src/relax/ir/dataflow_matcher.cc:979
  9: tvm::relax::PatternRewriter::RewriteDataflowBlockFixedPoint(tvm::relax::BindingBlock)
        at /opt/scratch/leshengjin/tvm/src/relax/ir/dataflow_matcher.cc:1046
  8: tvm::relax::PatternRewriter::VisitBinding_(tvm::relax::VarBindingNode const*)
        at /opt/scratch/leshengjin/tvm/src/relax/ir/dataflow_matcher.cc:942
  7: tvm::relax::PatternRewriter::VisitExpr(tvm::RelayExpr const&)
        at /opt/scratch/leshengjin/tvm/src/relax/ir/dataflow_matcher.cc:951
  6: tvm::relax::Normalizer::Normalize(tvm::RelayExpr const&)
        at /opt/scratch/leshengjin/tvm/src/relax/ir/block_builder.cc:468
  5: tvm::relax::Normalizer::VisitExpr(tvm::RelayExpr const&)
        at /opt/scratch/leshengjin/tvm/src/relax/ir/block_builder.cc:544
  4: tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
        at /opt/scratch/leshengjin/tvm/src/relax/ir/block_builder.cc:599
  3: tvm::relax::Normalizer::InferStructInfo(tvm::relax::Call const&)
        at /opt/scratch/leshengjin/tvm/src/relax/ir/block_builder.cc:757
  2: tvm::relax::InferStructInfoMatmul(tvm::relax::Call const&, tvm::relax::BlockBuilder const&)
        at /opt/scratch/leshengjin/tvm/src/relax/op/tensor/linear_algebra.cc:67
  1: tvm::relax::InferBinaryArithOpOutDtype(tvm::relax::Call const&, tvm::relax::BlockBuilder const&, tvm::relax::TensorStructInfo const&, tvm::relax::TensorStructInfo const&)
        at /opt/scratch/leshengjin/tvm/src/relax/op/distributed/../op_common.h:197
  0: tvm::relax::BlockBuilderImpl::ReportFatal(tvm::Diagnostic const&)
        at /opt/scratch/leshengjin/tvm/src/relax/ir/block_builder.cc:138
  File "/opt/scratch/leshengjin/tvm/src/relax/ir/block_builder.cc", line 138
TVMError: Data types float16 and float32 must be equal for binary operators

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Confirmed bugs

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions