## 🐛 Bug
I get an error when running Mistral-7B-Instruct-v0.1 through the Python API.
## To Reproduce

Steps to reproduce the behavior:

- `git clone https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1`
- `python -m mlc_llm.build --model Mistral-7B-Instruct-v0.1 --quantization q4f16_1 --target metal`
- `python sample_mlc_chat.py`
```python
# sample_mlc_chat.py
from python.mlc_chat import ChatModule
from python.mlc_chat.callback import StreamToStdout

cm = ChatModule(
    model="Mistral-7B-Instruct-v0.1-q4f16_1",
)
output = cm.generate(
    prompt="How to make a cake?",
    progress_callback=StreamToStdout(callback_interval=2),
)
```
I get the following error message:
```
  File "/Users/kartik/mlc/mlc-llm/sample_mlc_chat.py", line 16, in <module>
    output = cm.generate(
             ^^^^^^^^^^^^
  File "/Users/kartik/mlc/mlc-llm/python/mlc_chat/chat_module.py", line 846, in generate
    self._prefill(prompt, generation_config=generation_config)
  File "/Users/kartik/mlc/mlc-llm/python/mlc_chat/chat_module.py", line 1059, in _prefill
    self._prefill_func(
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 277, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "/Users/kartik/miniconda3/envs/mlc_test/lib/python3.11/site-packages/tvm/_ffi/base.py", line 476, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  File "/Users/catalyst/Workspace/mlc-ai-package-self-runner/_work/package/package/tvm/src/runtime/relax_vm/vm.cc", line 757
InternalError: Check failed: static_cast<size_t>(gfunc.num_args) == args.size() (6 vs. 4) : ValueError: Invoking function prefill requires 6 inputs but only 4 inputs are provided.
```
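For context, the failing check in `vm.cc` compares the number of arguments declared by the compiled `prefill` function (6) against the number actually supplied by the runtime at the call site (4). A minimal plain-Python sketch of that arity check (no TVM required; `invoke_vm_function` is a hypothetical stand-in for illustration, not a real TVM API):

```python
# Sketch of the arity check performed by the Relax VM in vm.cc:
# calling a VM function with a different number of arguments than
# its compiled signature declares raises this kind of error.
def invoke_vm_function(num_args_expected, args):
    """Hypothetical stand-in for the VM's function-invocation path."""
    if num_args_expected != len(args):
        raise ValueError(
            f"Invoking function prefill requires {num_args_expected} inputs "
            f"but only {len(args)} inputs are provided."
        )
    return "ok"
```

Since the compiled library declares 6 inputs while the installed runtime passes 4, the library produced by `mlc_llm.build` and the installed `mlc_chat`/TVM runtime appear to disagree about the `prefill` signature (e.g. a version mismatch); this reading is an inference from the trace, not confirmed.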
## Environment

- Platform (e.g. WebGPU/Vulkan/IOS/Android/CUDA): Metal
- Operating system (e.g. Ubuntu/Windows/MacOS/...): macOS
- Device (e.g. iPhone 12 Pro, PC+RTX 3090, ...): Metal
- How you installed MLC-LLM (`conda`, source): conda
- How you installed TVM-Unity (`pip`, source): conda
- Python version (e.g. 3.10): 3.11.5
- GPU driver version (if applicable):
- CUDA/cuDNN version (if applicable):
- TVM Unity Hash Tag (`python -c "import tvm; print('\n'.join(f'{k}: {v}' for k, v in tvm.support.libinfo().items()))"`, applicable if you compile models):
- Any other relevant information: