diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index e53c9c8ae..3f6f68ece 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -1496,6 +1496,18 @@ def __call__(
             grammar=grammar,
         )
 
+    def create_generate(
+        self,
+        prompt: str,
+        max_tokens: int = 256
+    ) -> GenerateResponse:
+        response = self.create_completion(prompt=prompt, max_tokens=max_tokens)
+
+        return {
+            "generated_text": response["choices"][0]["text"],
+            "status": 200
+        }
+
     def _convert_text_completion_to_chat(
         self, completion: Completion
     ) -> ChatCompletion:
diff --git a/llama_cpp/llama_types.py b/llama_cpp/llama_types.py
index 6ee7ef914..e6d78f7d7 100644
--- a/llama_cpp/llama_types.py
+++ b/llama_cpp/llama_types.py
@@ -170,3 +170,9 @@ class ChatCompletionRequestMessage(TypedDict):
     content: Optional[str]
     name: NotRequired[str]
     funcion_call: NotRequired[ChatCompletionFunctionCall]
+
+class CreateGenerateResponse(TypedDict):
+    generated_text: str
+    status: int
+
+GenerateResponse = CreateGenerateResponse
diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index 3dd0a38fe..ddf826c22 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -800,6 +800,61 @@ def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
         return iterator_or_completion
 
 
+class GenerateRequestParameters(BaseModel):
+    max_new_tokens: int = max_tokens_field
+
+
+class CreateGenerateRequest(BaseModel):
+    inputs: str = Field(default="", description="Input string for the model")
+    parameters: Optional[GenerateRequestParameters] = None
+
+
+@router.post("/v1/generate")
+async def generate(
+    request: Request,
+    body: CreateGenerateRequest,
+    llama: llama_cpp.Llama = Depends(get_llama),
+) -> llama_cpp.CreateGenerateResponse:
+    input_dict = body.model_dump()
+    print(input_dict)
+    kwargs = {
+        "prompt": input_dict["inputs"]
+    }
+    if input_dict.get("parameters"):
+        for k, v in input_dict["parameters"].items():
+            print("k, v", k, v)
+            if k == "max_new_tokens":
+                kwargs["max_tokens"] = v
+
+    iterator_or_completion: Union[
+        llama_cpp.GenerateResponse, Iterator[llama_cpp.GenerateResponse]
+    ] = await run_in_threadpool(llama.create_generate, **kwargs)
+
+    if isinstance(iterator_or_completion, Iterator):
+        # EAFP: It's easier to ask for forgiveness than permission
+        first_response = await run_in_threadpool(next, iterator_or_completion)
+
+        # If no exception was raised from first_response, we can assume that
+        # the iterator is valid and we can use it to stream the response.
+        def iterator() -> Iterator[llama_cpp.GenerateResponse]:
+            yield first_response
+            yield from iterator_or_completion
+
+        print("responses", first_response)
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+        return EventSourceResponse(
+            recv_chan,
+            data_sender_callable=partial(  # type: ignore
+                get_event_publisher,
+                request=request,
+                inner_send_chan=send_chan,
+                iterator=iterator(),
+            ),
+        )
+    else:
+        return iterator_or_completion
+
+
 class ModelData(TypedDict):
     id: str
     object: Literal["model"]
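A minimal usage sketch, not part of the diff: it assumes the server was started as usual (python -m llama_cpp.server --model <path>) and is listening on the default localhost:8000, and it exercises the new /v1/generate route using the request and response shapes defined above.

# Sketch only: request body follows CreateGenerateRequest, response follows
# CreateGenerateResponse from this diff; host and port are assumptions.
import requests

payload = {
    "inputs": "Q: What is the capital of France? A:",
    "parameters": {"max_new_tokens": 32},
}
resp = requests.post("http://localhost:8000/v1/generate", json=payload)
print(resp.json())  # expected shape: {"generated_text": "...", "status": 200}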