diff --git a/.env b/.env index fbeafeba..1a6a7e78 100644 --- a/.env +++ b/.env @@ -9,6 +9,12 @@ ANTHROPIC_API_KEY= # For using Voyage VOYAGE_API_KEY= +# For using Replicate +REPLICATE_API_KEY= + +# For using Ollama +OLLAMA_HOST_URL= + # For using GPT on Azure AZURE_OPENAI_BASEURL= AZURE_OPENAI_DEPLOYMENT= diff --git a/examples/chat-claude-anthropic.php b/examples/chat-claude-anthropic.php index 84a449a3..7aa2516c 100755 --- a/examples/chat-claude-anthropic.php +++ b/examples/chat-claude-anthropic.php @@ -1,10 +1,10 @@ 0.5, // default options for the model ]); diff --git a/examples/chat-llama-ollama.php b/examples/chat-llama-ollama.php new file mode 100755 index 00000000..551b8f2d --- /dev/null +++ b/examples/chat-llama-ollama.php @@ -0,0 +1,29 @@ +loadEnv(dirname(__DIR__).'/.env'); + +if (empty($_ENV['OLLAMA_HOST_URL'])) { + echo 'Please set the OLLAMA_HOST_URL environment variable.'.PHP_EOL; + exit(1); +} + +$platform = new Ollama(HttpClient::create(), $_ENV['OLLAMA_HOST_URL']); +$llm = new Llama($platform); + +$chain = new Chain($llm); +$messages = new MessageBag( + Message::forSystem('You are a helpful assistant.'), + Message::ofUser('Tina has one brother and one sister. How many sisters do Tina\'s siblings have?'), +); +$response = $chain->call($messages); + +echo $response->getContent().PHP_EOL; diff --git a/examples/chat-llama-replicate.php b/examples/chat-llama-replicate.php new file mode 100755 index 00000000..491129af --- /dev/null +++ b/examples/chat-llama-replicate.php @@ -0,0 +1,29 @@ +loadEnv(dirname(__DIR__).'/.env'); + +if (empty($_ENV['REPLICATE_API_KEY'])) { + echo 'Please set the REPLICATE_API_KEY environment variable.'.PHP_EOL; + exit(1); +} + +$platform = new Replicate(HttpClient::create(), $_ENV['REPLICATE_API_KEY']); +$llm = new Llama($platform); + +$chain = new Chain($llm); +$messages = new MessageBag( + Message::forSystem('You are a helpful assistant.'), + Message::ofUser('Tina has one brother and one sister. How many sisters do Tina\'s siblings have?'), +); +$response = $chain->call($messages); + +echo $response->getContent().PHP_EOL; diff --git a/examples/chat-o1-openai.php b/examples/chat-o1-openai.php index ef86ee1f..a6b62181 100755 --- a/examples/chat-o1-openai.php +++ b/examples/chat-o1-openai.php @@ -3,9 +3,8 @@ use PhpLlm\LlmChain\Chain; use PhpLlm\LlmChain\Message\Message; use PhpLlm\LlmChain\Message\MessageBag; -use PhpLlm\LlmChain\OpenAI\Model\Gpt; -use PhpLlm\LlmChain\OpenAI\Model\Gpt\Version; -use PhpLlm\LlmChain\OpenAI\Platform\OpenAI; +use PhpLlm\LlmChain\Model\Language\Gpt; +use PhpLlm\LlmChain\Platform\OpenAI\OpenAI; use Symfony\Component\Dotenv\Dotenv; use Symfony\Component\HttpClient\HttpClient; @@ -23,7 +22,7 @@ } $platform = new OpenAI(HttpClient::create(), $_ENV['OPENAI_API_KEY']); -$llm = new Gpt($platform, Version::o1Preview()); +$llm = new Gpt($platform, Gpt::O1_PREVIEW); $prompt = <<create(<<create(<<embed($documents); // initialize the index $store->initialize(); -$llm = new Gpt($platform, Version::gpt4oMini()); +$llm = new Gpt($platform, Gpt::GPT_4O_MINI); $similaritySearch = new SimilaritySearch($embeddings, $store); $toolBox = new ToolBox(new ToolAnalyzer(), [$similaritySearch]); diff --git a/examples/store-pinecone-similarity-search.php b/examples/store-pinecone-similarity-search.php index 547beb7b..5aa09144 100755 --- a/examples/store-pinecone-similarity-search.php +++ b/examples/store-pinecone-similarity-search.php @@ -6,10 +6,9 @@ use PhpLlm\LlmChain\DocumentEmbedder; use PhpLlm\LlmChain\Message\Message; use PhpLlm\LlmChain\Message\MessageBag; -use PhpLlm\LlmChain\OpenAI\Model\Embeddings; -use PhpLlm\LlmChain\OpenAI\Model\Gpt; -use PhpLlm\LlmChain\OpenAI\Model\Gpt\Version; -use PhpLlm\LlmChain\OpenAI\Platform\OpenAI; +use PhpLlm\LlmChain\Model\Embeddings\OpenAI as Embeddings; +use PhpLlm\LlmChain\Model\Language\Gpt; +use PhpLlm\LlmChain\Platform\OpenAI\OpenAI as Platform; use PhpLlm\LlmChain\Store\Pinecone\Store; use PhpLlm\LlmChain\ToolBox\ChainProcessor; use PhpLlm\LlmChain\ToolBox\Tool\SimilaritySearch; @@ -48,11 +47,11 @@ } // create embeddings for documents -$platform = new OpenAI(HttpClient::create(), $_ENV['OPENAI_API_KEY']); +$platform = new Platform(HttpClient::create(), $_ENV['OPENAI_API_KEY']); $embedder = new DocumentEmbedder($embeddings = new Embeddings($platform), $store); $embedder->embed($documents); -$llm = new Gpt($platform, Version::gpt4oMini()); +$llm = new Gpt($platform, Gpt::GPT_4O_MINI); $similaritySearch = new SimilaritySearch($embeddings, $store); $toolBox = new ToolBox(new ToolAnalyzer(), [$similaritySearch]); diff --git a/examples/stream-claude-anthropic.php b/examples/stream-claude-anthropic.php index 08f12fcf..9c6836ea 100644 --- a/examples/stream-claude-anthropic.php +++ b/examples/stream-claude-anthropic.php @@ -1,10 +1,10 @@ $body - * - * @return array - */ - public function request(array $body): iterable; -} diff --git a/src/OpenAI/Model/Embeddings.php b/src/Model/Embeddings/OpenAI.php similarity index 71% rename from src/OpenAI/Model/Embeddings.php rename to src/Model/Embeddings/OpenAI.php index 249c83c8..ad9b37fd 100644 --- a/src/OpenAI/Model/Embeddings.php +++ b/src/Model/Embeddings/OpenAI.php @@ -2,20 +2,22 @@ declare(strict_types=1); -namespace PhpLlm\LlmChain\OpenAI\Model; +namespace PhpLlm\LlmChain\Model\Embeddings; use PhpLlm\LlmChain\Document\Vector; use PhpLlm\LlmChain\EmbeddingsModel; -use PhpLlm\LlmChain\OpenAI\Model\Embeddings\Version; -use PhpLlm\LlmChain\OpenAI\Platform; +use PhpLlm\LlmChain\Platform\OpenAI\Platform; -final class Embeddings implements EmbeddingsModel +final readonly class OpenAI implements EmbeddingsModel { + public const TEXT_ADA_002 = 'text-embedding-ada-002'; + public const TEXT_3_LARGE = 'text-embedding-3-large'; + public const TEXT_3_SMALL = 'text-embedding-3-small'; + public function __construct( - private readonly Platform $platform, - private ?Version $version = null, + private Platform $platform, + private string $version = self::TEXT_3_SMALL, ) { - $this->version ??= Version::textEmbedding3Small(); } public function create(string $text, array $options = []): Vector @@ -43,7 +45,7 @@ public function multiCreate(array $texts, array $options = []): array private function createBody(string $text): array { return [ - 'model' => $this->version->name, + 'model' => $this->version, 'input' => $text, ]; } diff --git a/src/Voyage/Model/Voyage.php b/src/Model/Embeddings/Voyage.php similarity index 51% rename from src/Voyage/Model/Voyage.php rename to src/Model/Embeddings/Voyage.php index 98f694fa..2cae5262 100644 --- a/src/Voyage/Model/Voyage.php +++ b/src/Model/Embeddings/Voyage.php @@ -2,20 +2,25 @@ declare(strict_types=1); -namespace PhpLlm\LlmChain\Voyage\Model; +namespace PhpLlm\LlmChain\Model\Embeddings; use PhpLlm\LlmChain\Document\Vector; use PhpLlm\LlmChain\EmbeddingsModel; -use PhpLlm\LlmChain\Voyage\Model\Voyage\Version; -use PhpLlm\LlmChain\Voyage\Platform; +use PhpLlm\LlmChain\Platform\Voyage as Platform; -final class Voyage implements EmbeddingsModel +final readonly class Voyage implements EmbeddingsModel { + public const VERSION_V3 = 'voyage-3'; + public const VERSION_V3_LITE = 'voyage-3-lite'; + public const VERSION_FINANCE_2 = 'voyage-finance-2'; + public const VERSION_MULTILINGUAL_2 = 'voyage-multilingual-2'; + public const VERSION_LAW_2 = 'voyage-law-2'; + public const VERSION_CODE_2 = 'voyage-code-2'; + public function __construct( - private readonly Platform $platform, - private ?Version $version = null, + private Platform $platform, + private string $version = self::VERSION_V3, ) { - $this->version ??= Version::v3(); } public function create(string $text, array $options = []): Vector @@ -28,7 +33,7 @@ public function create(string $text, array $options = []): Vector public function multiCreate(array $texts, array $options = []): array { $response = $this->platform->request(array_merge($options, [ - 'model' => $this->version->name, + 'model' => $this->version, 'input' => $texts, ])); diff --git a/src/Anthropic/Model/Claude.php b/src/Model/Language/Claude.php similarity index 75% rename from src/Anthropic/Model/Claude.php rename to src/Model/Language/Claude.php index fda96696..baa0e733 100644 --- a/src/Anthropic/Model/Claude.php +++ b/src/Model/Language/Claude.php @@ -2,27 +2,30 @@ declare(strict_types=1); -namespace PhpLlm\LlmChain\Anthropic\Model; +namespace PhpLlm\LlmChain\Model\Language; -use PhpLlm\LlmChain\Anthropic\Model\Claude\Version; -use PhpLlm\LlmChain\Anthropic\Platform; use PhpLlm\LlmChain\LanguageModel; use PhpLlm\LlmChain\Message\MessageBag; +use PhpLlm\LlmChain\Platform\Anthropic; use PhpLlm\LlmChain\Response\ResponseInterface; use PhpLlm\LlmChain\Response\StreamResponse; use PhpLlm\LlmChain\Response\TextResponse; -final class Claude implements LanguageModel +final readonly class Claude implements LanguageModel { + public const VERSION_3_HAIKU = 'claude-3-haiku-20240307'; + public const VERSION_3_SONNET = 'claude-3-sonnet-20240229'; + public const VERSION_35_SONNET = 'claude-3-5-sonnet-20240620'; + public const VERSION_3_OPUS = 'claude-3-opus-20240229'; + /** * @param array $options The default options for the model usage */ public function __construct( - private readonly Platform $platform, - private ?Version $version = null, - private readonly array $options = ['temperature' => 1.0, 'max_tokens' => 1000], + private Anthropic $platform, + private string $version = self::VERSION_35_SONNET, + private array $options = ['temperature' => 1.0, 'max_tokens' => 1000], ) { - $this->version ??= Version::sonnet35(); } /** @@ -33,7 +36,7 @@ public function call(MessageBag $messages, array $options = []): ResponseInterfa { $system = $messages->getSystemMessage(); $body = array_merge($this->options, $options, [ - 'model' => $this->version->name, + 'model' => $this->version, 'system' => $system->content, 'messages' => $messages->withoutSystemMessage(), ]); diff --git a/src/OpenAI/Model/Gpt.php b/src/Model/Language/Gpt.php similarity index 79% rename from src/OpenAI/Model/Gpt.php rename to src/Model/Language/Gpt.php index 2389a386..425b318a 100644 --- a/src/OpenAI/Model/Gpt.php +++ b/src/Model/Language/Gpt.php @@ -2,13 +2,12 @@ declare(strict_types=1); -namespace PhpLlm\LlmChain\OpenAI\Model; +namespace PhpLlm\LlmChain\Model\Language; use PhpLlm\LlmChain\Exception\RuntimeException; use PhpLlm\LlmChain\LanguageModel; use PhpLlm\LlmChain\Message\MessageBag; -use PhpLlm\LlmChain\OpenAI\Model\Gpt\Version; -use PhpLlm\LlmChain\OpenAI\Platform; +use PhpLlm\LlmChain\Platform\OpenAI\Platform; use PhpLlm\LlmChain\Response\Choice; use PhpLlm\LlmChain\Response\ChoiceResponse; use PhpLlm\LlmChain\Response\ResponseInterface; @@ -19,25 +18,42 @@ final class Gpt implements LanguageModel { + public const GPT_35_TURBO = 'gpt-3.5-turbo'; + public const GPT_35_TURBO_INSTRUCT = 'gpt-3.5-turbo-instruct'; + public const GPT_4 = 'gpt-4'; + public const GPT_4_TURBO = 'gpt-4-turbo'; + public const GPT_4O = 'gpt-4o'; + public const GPT_4O_MINI = 'gpt-4o-mini'; + public const O1_MINI = 'o1-mini'; + public const O1_PREVIEW = 'o1-preview'; + /** * @param array $options The default options for the model usage */ public function __construct( private readonly Platform $platform, - private ?Version $version = null, + private readonly string $version = self::GPT_4O, private readonly array $options = ['temperature' => 1.0], + private bool $supportsImageInput = false, + private bool $supportsStructuredOutput = false, ) { - $this->version ??= Version::gpt4o(); + if (false === $this->supportsImageInput) { + $this->supportsImageInput = in_array($this->version, [self::GPT_4_TURBO, self::GPT_4O, self::GPT_4O_MINI, self::O1_MINI, self::O1_PREVIEW], true); + } + + if (false === $this->supportsStructuredOutput) { + $this->supportsStructuredOutput = in_array($this->version, [self::GPT_4O, self::GPT_4O_MINI], true); + } } /** - * @param array $options The options to be used for this specific call. - * Can overwrite default options. + * @param array $options The options to be used for this specific call. + * Can overwrite default options. */ public function call(MessageBag $messages, array $options = []): ResponseInterface { $body = array_merge($this->options, $options, [ - 'model' => $this->version->name, + 'model' => $this->version, 'messages' => $messages, ]); @@ -76,12 +92,12 @@ public function supportsToolCalling(): bool public function supportsImageInput(): bool { - return $this->version->supportImageInput; + return $this->supportsImageInput; } public function supportsStructuredOutput(): bool { - return $this->version->supportStructuredOutput; + return $this->supportsStructuredOutput; } private function streamIsToolCall(\Generator $response): bool diff --git a/src/Model/Language/Llama.php b/src/Model/Language/Llama.php new file mode 100644 index 00000000..9d87660d --- /dev/null +++ b/src/Model/Language/Llama.php @@ -0,0 +1,137 @@ +platform instanceof Replicate) { + $response = $this->platform->request('meta/meta-llama-3.1-405b-instruct', 'predictions', [ + 'system' => self::convertMessage($messages->getSystemMessage() ?? new SystemMessage('')), + 'prompt' => self::convertToPrompt($messages->withoutSystemMessage()), + ]); + + return new TextResponse(implode('', $response['output'])); + } + + $response = $this->platform->request('llama3.2', 'chat', ['messages' => $messages, 'stream' => false]); + + return new TextResponse($response['message']['content']); + } + + /** + * @todo make method private, just for testing, or create a MessageBag to LLama convert class :thinking: + */ + public static function convertToPrompt(MessageBag $messageBag): string + { + $messages = []; + + /** @var UserMessage|SystemMessage|AssistantMessage $message */ + foreach ($messageBag->getIterator() as $message) { + $messages[] = self::convertMessage($message); + } + + $messages = array_filter($messages, fn ($message) => '' !== $message); + + return trim(implode(PHP_EOL.PHP_EOL, $messages)).PHP_EOL.PHP_EOL.'<|start_header_id|>assistant<|end_header_id|>'; + } + + /** + * @todo make method private, just for testing + */ + public static function convertMessage(UserMessage|SystemMessage|AssistantMessage $message): string + { + if ($message instanceof SystemMessage) { + return trim(<<<|start_header_id|>system<|end_header_id|> + +{$message->content}<|eot_id|> +SYSTEM); + } + + if ($message instanceof AssistantMessage) { + if ('' === $message->content || null === $message->content) { + return ''; + } + + return trim(<<{$message->getRole()->value}<|end_header_id|> + +{$message->content}<|eot_id|> +ASSISTANT); + } + + if ($message instanceof UserMessage) { + $count = count($message->content); + + $contentParts = []; + if ($count > 1) { + foreach ($message->content as $value) { + if ($value instanceof Text) { + $contentParts[] = $value->text; + } + + if ($value instanceof Image) { + $contentParts[] = $value->url; + } + } + } elseif (1 === $count) { + $value = $message->content[0]; + if ($value instanceof Text) { + $contentParts[] = $value->text; + } + + if ($value instanceof Image) { + $contentParts[] = $value->url; + } + } else { + throw new RuntimeException('Unsupported message type.'); + } + + $content = implode(PHP_EOL, $contentParts); + + return trim(<<{$message->getRole()->value}<|end_header_id|> + +{$content}<|eot_id|> +USER); + } + + throw new RuntimeException('Unsupported message type.'); // @phpstan-ignore-line + } + + public function supportsToolCalling(): bool + { + return false; // it does, but implementation here is still open. + } + + public function supportsImageInput(): bool + { + return false; // it does, but implementation here is still open. + } + + public function supportsStructuredOutput(): bool + { + return false; + } +} diff --git a/src/OpenAI/Model/Embeddings/Version.php b/src/OpenAI/Model/Embeddings/Version.php deleted file mode 100644 index f56c0ff8..00000000 --- a/src/OpenAI/Model/Embeddings/Version.php +++ /dev/null @@ -1,34 +0,0 @@ - $body + * + * @return array + */ + public function request(string $model, string $endpoint, array $body): array + { + $url = sprintf('%s/api/%s', $this->hostUrl, $endpoint); + + $response = $this->httpClient->request('POST', $url, [ + 'headers' => ['Content-Type' => 'application/json'], + 'json' => array_merge($body, [ + 'model' => $model, + ]), + ]); + + return $response->toArray(); + } +} diff --git a/src/OpenAI/Platform/AbstractPlatform.php b/src/Platform/OpenAI/AbstractPlatform.php similarity index 95% rename from src/OpenAI/Platform/AbstractPlatform.php rename to src/Platform/OpenAI/AbstractPlatform.php index 7949442c..496ee8b9 100644 --- a/src/OpenAI/Platform/AbstractPlatform.php +++ b/src/Platform/OpenAI/AbstractPlatform.php @@ -2,10 +2,9 @@ declare(strict_types=1); -namespace PhpLlm\LlmChain\OpenAI\Platform; +namespace PhpLlm\LlmChain\Platform\OpenAI; use PhpLlm\LlmChain\Exception\RuntimeException; -use PhpLlm\LlmChain\OpenAI\Platform; use Symfony\Component\HttpClient\Chunk\ServerSentEvent; use Symfony\Component\HttpClient\EventSourceHttpClient; use Symfony\Component\HttpClient\Exception\ClientException; diff --git a/src/OpenAI/Platform/Azure.php b/src/Platform/OpenAI/Azure.php similarity index 95% rename from src/OpenAI/Platform/Azure.php rename to src/Platform/OpenAI/Azure.php index 5d6e247e..b00409e8 100644 --- a/src/OpenAI/Platform/Azure.php +++ b/src/Platform/OpenAI/Azure.php @@ -2,9 +2,8 @@ declare(strict_types=1); -namespace PhpLlm\LlmChain\OpenAI\Platform; +namespace PhpLlm\LlmChain\Platform\OpenAI; -use PhpLlm\LlmChain\OpenAI\Platform; use Symfony\Component\HttpClient\EventSourceHttpClient; use Symfony\Contracts\HttpClient\HttpClientInterface; use Symfony\Contracts\HttpClient\ResponseInterface; diff --git a/src/OpenAI/Platform/OpenAI.php b/src/Platform/OpenAI/OpenAI.php similarity index 93% rename from src/OpenAI/Platform/OpenAI.php rename to src/Platform/OpenAI/OpenAI.php index 74833bac..790fad75 100644 --- a/src/OpenAI/Platform/OpenAI.php +++ b/src/Platform/OpenAI/OpenAI.php @@ -2,9 +2,8 @@ declare(strict_types=1); -namespace PhpLlm\LlmChain\OpenAI\Platform; +namespace PhpLlm\LlmChain\Platform\OpenAI; -use PhpLlm\LlmChain\OpenAI\Platform; use Symfony\Component\HttpClient\EventSourceHttpClient; use Symfony\Contracts\HttpClient\HttpClientInterface; use Symfony\Contracts\HttpClient\ResponseInterface; diff --git a/src/OpenAI/Platform.php b/src/Platform/OpenAI/Platform.php similarity index 90% rename from src/OpenAI/Platform.php rename to src/Platform/OpenAI/Platform.php index 28759bc4..dd12e656 100644 --- a/src/OpenAI/Platform.php +++ b/src/Platform/OpenAI/Platform.php @@ -2,7 +2,7 @@ declare(strict_types=1); -namespace PhpLlm\LlmChain\OpenAI; +namespace PhpLlm\LlmChain\Platform\OpenAI; interface Platform { diff --git a/src/Platform/Replicate.php b/src/Platform/Replicate.php new file mode 100644 index 00000000..37968fca --- /dev/null +++ b/src/Platform/Replicate.php @@ -0,0 +1,56 @@ + $body + * + * @return array + */ + public function request(string $model, string $endpoint, array $body): array + { + $url = sprintf('https://api.replicate.com/v1/models/%s/%s', $model, $endpoint); + + $response = $this->httpClient->request('POST', $url, [ + 'headers' => ['Content-Type' => 'application/json'], + 'auth_bearer' => $this->apiKey, + 'json' => ['input' => $body], + ])->toArray(); + + while (!in_array($response['status'], ['succeeded', 'failed', 'canceled'], true)) { + sleep(1); + + $response = $this->getResponse($response['id']); + } + + return $response; + } + + /** + * @return array + */ + private function getResponse(string $id): array + { + $url = sprintf('https://api.replicate.com/v1/predictions/%s', $id); + + $response = $this->httpClient->request('GET', $url, [ + 'headers' => ['Content-Type' => 'application/json'], + 'auth_bearer' => $this->apiKey, + ]); + + return $response->toArray(); + } +} diff --git a/src/Voyage/Platform/Voyage.php b/src/Platform/Voyage.php similarity index 76% rename from src/Voyage/Platform/Voyage.php rename to src/Platform/Voyage.php index 330e73e3..0671596e 100644 --- a/src/Voyage/Platform/Voyage.php +++ b/src/Platform/Voyage.php @@ -2,12 +2,11 @@ declare(strict_types=1); -namespace PhpLlm\LlmChain\Voyage\Platform; +namespace PhpLlm\LlmChain\Platform; -use PhpLlm\LlmChain\Voyage\Platform; use Symfony\Contracts\HttpClient\HttpClientInterface; -final readonly class Voyage implements Platform +final readonly class Voyage { public function __construct( private HttpClientInterface $httpClient, @@ -15,6 +14,11 @@ public function __construct( ) { } + /** + * @param array $body + * + * @return array + */ public function request(array $body): array { $response = $this->httpClient->request('POST', 'https://api.voyageai.com/v1/embeddings', [ diff --git a/src/Voyage/Model/Voyage/Version.php b/src/Voyage/Model/Voyage/Version.php deleted file mode 100644 index 77ffd2f3..00000000 --- a/src/Voyage/Model/Voyage/Version.php +++ /dev/null @@ -1,46 +0,0 @@ - $body - * - * @return array - */ - public function request(array $body): array; -} diff --git a/tests/Model/Language/LlamaTest.php b/tests/Model/Language/LlamaTest.php new file mode 100644 index 00000000..94ff87fa --- /dev/null +++ b/tests/Model/Language/LlamaTest.php @@ -0,0 +1,133 @@ +append($message[1]); + } + + self::assertSame(<<<|start_header_id|>system<|end_header_id|> + +You are a helpful chatbot.<|eot_id|> + +<|start_header_id|>user<|end_header_id|> + +Hello, how are you?<|eot_id|> + +<|start_header_id|>user<|end_header_id|> + +Hello, how are you? +What is your name?<|eot_id|> + +<|start_header_id|>user<|end_header_id|> + +Hello, how are you? +What is your name? +https://example.com/image.jpg<|eot_id|> + +<|start_header_id|>assistant<|end_header_id|> + +I am an assistant.<|eot_id|> + +<|start_header_id|>assistant<|end_header_id|> +EXPECTED, + (new Llama(new Ollama(new MockHttpClient(), 'http://example.com')))->convertToPrompt($messageBag) + ); + } + + #[Test] + #[DataProvider('provideMessages')] + public function convertMessage(string $expected, UserMessage|SystemMessage|AssistantMessage $message): void + { + self::assertSame( + $expected, + (new Llama(new Ollama(new MockHttpClient(), 'http://example.com')))->convertMessage($message) + ); + } + + /** + * @return iterable + */ + public static function provideMessages(): iterable + { + yield 'System message' => [ + <<<|start_header_id|>system<|end_header_id|> + +You are a helpful chatbot.<|eot_id|> +SYSTEM, + Message::forSystem('You are a helpful chatbot.'), + ]; + + yield 'UserMessage' => [ + <<user<|end_header_id|> + +Hello, how are you?<|eot_id|> +USER, + Message::ofUser('Hello, how are you?'), + ]; + + yield 'UserMessage with two texts' => [ + <<user<|end_header_id|> + +Hello, how are you? +What is your name?<|eot_id|> +USER, + Message::ofUser('Hello, how are you?', 'What is your name?'), + ]; + + yield 'UserMessage with two texts and one image' => [ + <<user<|end_header_id|> + +Hello, how are you? +What is your name? +https://example.com/image.jpg<|eot_id|> +USER, + Message::ofUser('Hello, how are you?', 'What is your name?', new Image('https://example.com/image.jpg')), + ]; + + yield 'AssistantMessage' => [ + <<assistant<|end_header_id|> + +I am an assistant.<|eot_id|> +ASSISTANT, + new AssistantMessage('I am an assistant.'), + ]; + + yield 'AssistantMessage with null content' => [ + '', + new AssistantMessage(), + ]; + } +}