diff --git a/src/Model/Language/Llama.php b/src/Model/Language/Llama.php index cd04186a..e3343a9c 100644 --- a/src/Model/Language/Llama.php +++ b/src/Model/Language/Llama.php @@ -5,7 +5,10 @@ namespace PhpLlm\LlmChain\Model\Language; use PhpLlm\LlmChain\LanguageModel; +use PhpLlm\LlmChain\Message\AssistantMessage; use PhpLlm\LlmChain\Message\MessageBag; +use PhpLlm\LlmChain\Message\MessageInterface; +use PhpLlm\LlmChain\Message\UserMessage; use PhpLlm\LlmChain\Platform\Ollama; use PhpLlm\LlmChain\Platform\Replicate; use PhpLlm\LlmChain\Response\TextResponse; @@ -24,12 +27,32 @@ public function call(MessageBag $messages, array $options = []): TextResponse $response = $this->platform->request('meta/meta-llama-3.1-405b-instruct', $endpoint, [ 'system' => $systemMessage?->content, - 'prompt' => $messages->withoutSystemMessage()->getIterator()->current()->content[0]->text, // @phpstan-ignore-line TODO: Multiple messages + 'prompt' => self::convertToPrompt($messages->withoutSystemMessage()), ]); return new TextResponse(implode('', $response['output'])); } + private static function convertToPrompt(MessageBag $messageBag): string + { + $messages = []; + + /** @var MessageInterface $message */ + foreach ($messageBag->getIterator() as $message) { + if ($message instanceof UserMessage) { + $content = $message->content[0]->text; + } elseif ($message instanceof AssistantMessage && null !== $message->content) { + $content = $message->content; + } else { + continue; + } + + $messages[] = sprintf('%s: %s', ucfirst($message->getRole()->value), $content); + } + + return implode(PHP_EOL, $messages); + } + public function supportsToolCalling(): bool { return false; // it does, but implementation here is still open. diff --git a/tests/Fixture/StructuredOutput/MathReasoning.php b/tests/Fixture/StructuredOutput/MathReasoning.php new file mode 100644 index 00000000..f0f6c87d --- /dev/null +++ b/tests/Fixture/StructuredOutput/MathReasoning.php @@ -0,0 +1,17 @@ +