From fa5e017c48b6b196a6cdf923e57cc639936cc21a Mon Sep 17 00:00:00 2001 From: Pisarenko Grigoriy Date: Sun, 19 Oct 2025 17:27:05 +0300 Subject: [PATCH 1/7] Added ydb ai command --- ydb/public/lib/ydb_cli/commands/ya.make | 3 + ydb/public/lib/ydb_cli/commands/ydb_ai.cpp | 53 +++++++++++++++ ydb/public/lib/ydb_cli/commands/ydb_ai.h | 18 +++++ .../ydb_cli/commands/ydb_ai/line_reader.cpp | 66 +++++++++++++++++++ .../lib/ydb_cli/commands/ydb_ai/line_reader.h | 30 +++++++++ .../lib/ydb_cli/commands/ydb_ai/ya.make | 12 ++++ .../lib/ydb_cli/commands/ydb_root_common.cpp | 2 + 7 files changed, 184 insertions(+) create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai.cpp create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai.h create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai/line_reader.cpp create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai/line_reader.h create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai/ya.make diff --git a/ydb/public/lib/ydb_cli/commands/ya.make b/ydb/public/lib/ydb_cli/commands/ya.make index 86e60e67e7f7..9b0a5dca1bb1 100644 --- a/ydb/public/lib/ydb_cli/commands/ya.make +++ b/ydb/public/lib/ydb_cli/commands/ya.make @@ -11,6 +11,7 @@ SRCS( topic_write_scenario.cpp topic_readwrite_scenario.cpp ydb_admin.cpp + ydb_ai.cpp ydb_benchmark.cpp ydb_bridge.cpp ydb_cluster.cpp @@ -57,6 +58,7 @@ PEERDIR( ydb/public/lib/ydb_cli/commands/sdk_core_access ydb/public/lib/ydb_cli/commands/topic_workload ydb/public/lib/ydb_cli/commands/transfer_workload + ydb/public/lib/ydb_cli/commands/ydb_ai ydb/public/lib/ydb_cli/commands/ydb_discovery ydb/public/lib/ydb_cli/common ydb/public/lib/ydb_cli/dump @@ -92,5 +94,6 @@ RECURSE( sdk_core_access topic_workload transfer_workload + ydb_ai ydb_discovery ) diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp new file mode 100644 index 000000000000..63856aa6f88a --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp @@ -0,0 +1,53 @@ +#include "ydb_ai.h" + +#include + +#include + +namespace NYdb::NConsoleClient { + +namespace { + +void PrintExitMessage() { + Cout << "\nBye" << Endl; +} + +} // anonymous namespace + +TCommandAi::TCommandAi() + : TBase("ai", {}, "AI-TODO: KIKIMR-24198 -- description") +{} + +void TCommandAi::Config(TConfig& config) { + TClientCommand::Config(config); + config.Opts->SetTitle("AI-TODO: KIKIMR-24198 -- title"); + config.Opts->SetFreeArgsNum(0); +} + +int TCommandAi::Run(TConfig& config) { + Y_UNUSED(config); + + Cout << "AI-TODO: KIKIMR-24198 -- welcome message" << Endl; + + // AI-TODO: KIKIMR-24202 - robust file creation + NAi::TLineReader lineReader("ydb-ai> ", (TFsPath(HomeDir) / ".ydb-ai/history").GetPath()); + + while (const auto& maybeLine = lineReader.ReadLine()) { + const auto& input = *maybeLine; + if (input.empty()) { + continue; + } + + if (IsIn({"quit", "exit"}, to_lower(input))) { + PrintExitMessage(); + return EXIT_SUCCESS; + } + + Cout << "Input value: " << input << Endl; + } + + PrintExitMessage(); + return EXIT_SUCCESS; +} + +} // namespace NYdb::NConsoleClient \ No newline at end of file diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai.h b/ydb/public/lib/ydb_cli/commands/ydb_ai.h new file mode 100644 index 000000000000..4ba27d2cd43e --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai.h @@ -0,0 +1,18 @@ +#pragma once + +#include "ydb_command.h" + +namespace NYdb::NConsoleClient { + +class TCommandAi : public TYdbCommand { + using TBase = TYdbCommand; + +public: + TCommandAi(); + + void Config(TConfig& config) override; + + int Run(TConfig& config) override; +}; + +} // namespace NYdb::NConsoleClient diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/line_reader.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/line_reader.cpp new file mode 100644 index 000000000000..67d5c6b6109b --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/line_reader.cpp @@ -0,0 +1,66 @@ +#include "line_reader.h" + +#include + +namespace NYdb::NConsoleClient::NAi { + +TLineReader::TLineReader(TString prompt, TString historyFilePath) + : Prompt(std::move(prompt)) + , HistoryFilePath(historyFilePath) + , HistoryFileLock(historyFilePath) +{ + Rx.install_window_change_handler(); + + Rx.bind_key(replxx::Replxx::KEY::control('N'), [this](char32_t code) { + return Rx.invoke(replxx::Replxx::ACTION::HISTORY_NEXT, code); + }); + Rx.bind_key(replxx::Replxx::KEY::control('P'), [this](char32_t code) { + return Rx.invoke(replxx::Replxx::ACTION::HISTORY_PREVIOUS, code); + }); + Rx.bind_key(replxx::Replxx::KEY::control('D'), [](char32_t) { + return replxx::Replxx::ACTION_RESULT::BAIL; + }); + Rx.bind_key(replxx::Replxx::KEY::control('J'), [this](char32_t code) { + return Rx.invoke(replxx::Replxx::ACTION::COMMIT_LINE, code); + }); + + Rx.enable_bracketed_paste(); + Rx.set_unique_history(true); + + if (const auto guard = TryLockHistory(); guard && !Rx.history_load(HistoryFilePath)) { + Rx.print("Loading history failed: %s\n", strerror(errno)); + } +} + +std::optional TLineReader::ReadLine() { + do { + if (const char* input = Rx.input(Prompt.c_str())) { + auto result = Strip(input); + AddToHistory(result); + return std::move(result); + } + } while (errno == EAGAIN); + + return std::nullopt; +} + +TTryGuard TLineReader::TryLockHistory() { + // AI-TODO: KIKIMR-24202 - robust file creation and handling + TTryGuard guard(HistoryFileLock); + + if (!guard) { + Rx.print("Lock of history file failed: %s\n", strerror(errno)); + } + + return guard; +} + +void TLineReader::AddToHistory(const TString& line) { + Rx.history_add(line); + + if (const auto guard = TryLockHistory(); guard && !Rx.history_save(HistoryFilePath)) { + Rx.print("Save history failed: %s\n", strerror(errno)); + } +} + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/line_reader.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/line_reader.h new file mode 100644 index 000000000000..f156a0f70a14 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/line_reader.h @@ -0,0 +1,30 @@ +#pragma once + +#include + +#include +#include + +#include + +namespace NYdb::NConsoleClient::NAi { + +class TLineReader { +public: + TLineReader(TString prompt, TString historyFilePath); + + std::optional ReadLine(); + +private: + TTryGuard TryLockHistory(); + + void AddToHistory(const TString& line); + +private: + TString Prompt; + TString HistoryFilePath; + TFileLock HistoryFileLock; + replxx::Replxx Rx; +}; + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/ya.make b/ydb/public/lib/ydb_cli/commands/ydb_ai/ya.make new file mode 100644 index 000000000000..2ce9b7aa0b07 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/ya.make @@ -0,0 +1,12 @@ +LIBRARY() + +SRCS( + line_reader.cpp +) + +PEERDIR( + contrib/restricted/patched/replxx + util +) + +END() diff --git a/ydb/public/lib/ydb_cli/commands/ydb_root_common.cpp b/ydb/public/lib/ydb_cli/commands/ydb_root_common.cpp index fdf4bf86a8f1..e48e854404cf 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_root_common.cpp +++ b/ydb/public/lib/ydb_cli/commands/ydb_root_common.cpp @@ -1,6 +1,7 @@ #include "ydb_root_common.h" #include "ydb_profile.h" #include "ydb_admin.h" +#include "ydb_ai.h" #include "ydb_debug.h" #include "ydb_service_auth.h" #include "ydb_service_discovery.h" @@ -38,6 +39,7 @@ TClientCommandRootCommon::TClientCommandRootCommon(const TString& name, const TC { ValidateSettings(); AddDangerousCommand(std::make_unique()); + AddCommand(std::make_unique()); AddCommand(std::make_unique()); AddCommand(std::make_unique()); AddCommand(std::make_unique()); From 595958244f294f432d30b4408247f579489c7945 Mon Sep 17 00:00:00 2001 From: Pisarenko Grigoriy Date: Fri, 24 Oct 2025 11:19:13 +0300 Subject: [PATCH 2/7] Supported model calls --- .../common/http_gateway/yql_http_gateway.h | 12 +- ydb/public/lib/ydb_cli/commands/ya.make | 1 + ydb/public/lib/ydb_cli/commands/ydb_ai.cpp | 9 +- ydb/public/lib/ydb_cli/commands/ydb_ai.h | 2 +- .../commands/ydb_ai/models/model_interface.h | 16 ++ .../commands/ydb_ai/models/model_openai.cpp | 175 ++++++++++++++++++ .../commands/ydb_ai/models/model_openai.h | 17 ++ .../ydb_cli/commands/ydb_ai/models/ya.make | 15 ++ .../lib/ydb_cli/commands/ydb_ai/ya.make | 4 + 9 files changed, 243 insertions(+), 8 deletions(-) create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_interface.h create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.h create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai/models/ya.make diff --git a/ydb/library/yql/providers/common/http_gateway/yql_http_gateway.h b/ydb/library/yql/providers/common/http_gateway/yql_http_gateway.h index 2095419a6270..9ad44b19ef81 100644 --- a/ydb/library/yql/providers/common/http_gateway/yql_http_gateway.h +++ b/ydb/library/yql/providers/common/http_gateway/yql_http_gateway.h @@ -2,21 +2,21 @@ #include "yql_http_header.h" -#include - #include + +#include + #include #include #include -#include - #include -#include #include namespace NYql { +class THttpGatewayConfig; + class IHTTPGateway { public: using TPtr = std::shared_ptr; @@ -138,4 +138,4 @@ class IHTTPGateway { const TString& awsSigV4 = {}); }; -} +} // namespace NYql diff --git a/ydb/public/lib/ydb_cli/commands/ya.make b/ydb/public/lib/ydb_cli/commands/ya.make index 9b0a5dca1bb1..256c0240d2d7 100644 --- a/ydb/public/lib/ydb_cli/commands/ya.make +++ b/ydb/public/lib/ydb_cli/commands/ya.make @@ -59,6 +59,7 @@ PEERDIR( ydb/public/lib/ydb_cli/commands/topic_workload ydb/public/lib/ydb_cli/commands/transfer_workload ydb/public/lib/ydb_cli/commands/ydb_ai + ydb/public/lib/ydb_cli/commands/ydb_ai/models ydb/public/lib/ydb_cli/commands/ydb_discovery ydb/public/lib/ydb_cli/common ydb/public/lib/ydb_cli/dump diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp index 63856aa6f88a..94230e8f8744 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp @@ -1,8 +1,10 @@ #include "ydb_ai.h" #include +#include #include +#include namespace NYdb::NConsoleClient { @@ -31,6 +33,11 @@ int TCommandAi::Run(TConfig& config) { // AI-TODO: KIKIMR-24202 - robust file creation NAi::TLineReader lineReader("ydb-ai> ", (TFsPath(HomeDir) / ".ydb-ai/history").GetPath()); + const auto model = NAi::CreateOpenAiModel({ + .BaseUrl = "https://api.eliza.yandex.net/raw/internal/deepseek/v1", // AI-TODO: KIKIMR-24214 -- configure it + .ModelId = "deepseek-0324", // AI-TODO: KIKIMR-24214 -- configure it + .ApiKey = GetEnv("MODEL_TOKEN"), // AI-TODO: KIKIMR-24214 -- configure it + }); while (const auto& maybeLine = lineReader.ReadLine()) { const auto& input = *maybeLine; @@ -43,7 +50,7 @@ int TCommandAi::Run(TConfig& config) { return EXIT_SUCCESS; } - Cout << "Input value: " << input << Endl; + Cout << "Model answer:\n" << model->Chat(input) << Endl; } PrintExitMessage(); diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai.h b/ydb/public/lib/ydb_cli/commands/ydb_ai.h index 4ba27d2cd43e..87d0d825a2d3 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai.h +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai.h @@ -4,7 +4,7 @@ namespace NYdb::NConsoleClient { -class TCommandAi : public TYdbCommand { +class TCommandAi final : public TYdbCommand { using TBase = TYdbCommand; public: diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_interface.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_interface.h new file mode 100644 index 000000000000..03d4e853eb54 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_interface.h @@ -0,0 +1,16 @@ +#pragma once + +#include + +#include + +namespace NYdb::NConsoleClient::NAi { + +class IModel { +public: + using TPtr = std::shared_ptr; + + virtual TString Chat(const TString& input) = 0; +}; + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp new file mode 100644 index 000000000000..c8875a1f4e4f --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp @@ -0,0 +1,175 @@ +#include "model_openai.h" + +#include + +#include +#include +#include +#include +#include + +#include + +namespace NYdb::NConsoleClient::NAi { + +namespace { + +class TModelOpenAi final : public IModel { + class TConversationPart { + public: + enum class ERole { + User, + AI, + }; + + TConversationPart(const TString& content, ERole role) + : Content(content) + , Role(role) + {} + + NJson::TJsonValue ToJson() const { + NJson::TJsonValue result; + result["content"] = Content; + + auto& roleJson = result["role"]; + switch (Role) { + case ERole::User: + roleJson = "user"; + break; + case ERole::AI: + roleJson = "assistant"; + break; + } + + return result; + } + + private: + TString Content; + ERole Role; + }; + +public: + TModelOpenAi(NYql::IHTTPGateway::TPtr httpGateway, const TOpenAiModelSettings& settings) + : HttpGateway(httpGateway) + , Settings(settings) + { + TStringBuf sanitizedUrl; + TStringBuf query; + TStringBuf fragment; + SeparateUrlFromQueryAndFragment(Settings.BaseUrl, sanitizedUrl, query, fragment); + + if (query || fragment) { + throw yexception() << "BaseUrl must not contain query or fragment, got url: '" << Settings.BaseUrl << "' with query: '" << query << "' or fragment: '" << fragment << "'"; + } + + Settings.BaseUrl = RemoveFinalSlash(sanitizedUrl); + } + + TString Chat(const TString& input) final { + NJson::TJsonValue bodyJson; + bodyJson["model"] = Settings.ModelId; + + Conversation.emplace_back(input, TConversationPart::ERole::User); + auto& inputJson = bodyJson["input"]; + inputJson.SetType(NJson::JSON_ARRAY); + + auto& conversationJson = inputJson.GetArraySafe(); + for (const auto& part : Conversation) { + conversationJson.push_back(part.ToJson()); + } + + NJsonWriter::TBuf bodyWriter; + bodyWriter.WriteJsonValue(&bodyJson); + + NYql::THttpHeader headers = {.Fields = {"Content-Type: application/json"}}; + + if (Settings.ApiKey) { + headers.Fields.emplace_back(TStringBuilder() << "Authorization: Bearer " << Settings.ApiKey); + } + + auto answer = NThreading::NewPromise(); + HttpGateway->Upload( + TStringBuilder() << Settings.BaseUrl << "/responses", + std::move(headers), + bodyWriter.Str(), + [&answer](NYql::IHTTPGateway::TResult result) { + if (result.CurlResponseCode != CURLE_OK) { + answer.SetException(TStringBuilder() << "Request model failed: " << result.Issues.ToOneLineString() << ", internal code: " << curl_easy_strerror(result.CurlResponseCode) << ", response: " << result.Content.Extract()); + return; + } + + auto& content = result.Content; + if (content.HttpResponseCode < 200 || content.HttpResponseCode >= 300) { + answer.SetException(TStringBuilder() << "Request model failed, internal code: " << content.HttpResponseCode << ", response: " << result.Content.Extract()); + return; + } + + answer.SetValue(content.Extract()); + } + ); + + const auto result = answer.GetFuture().ExtractValueSync(); + NJson::TJsonValue resultJson; + if (!NJson::ReadJsonTree(result, &resultJson)) { + throw yexception() << "Response of model is not JSON, got response: " << result; + } + + ValidateJsonType(resultJson, NJson::JSON_MAP); + + const auto& output = ValidateJsonKey(resultJson, "output"); + ValidateJsonType(output, NJson::JSON_ARRAY, "output"); + ValidateJsonArraySize(output, 1, "output"); + + const auto& outputVal = output.GetArray()[0]; + ValidateJsonType(outputVal, NJson::JSON_MAP, "output[0]"); + + const auto& content = ValidateJsonKey(outputVal, "content", "output[0]"); + ValidateJsonType(content, NJson::JSON_ARRAY, "output[0].content"); + ValidateJsonArraySize(content, 1, "output[0].content"); + + const auto& contentVal = content.GetArray()[0]; + ValidateJsonType(contentVal, NJson::JSON_MAP, "output[0].content[0]"); + + const auto& text = ValidateJsonKey(contentVal, "text", "output[0].content[0]"); + ValidateJsonType(text, NJson::JSON_STRING, "output[0].content[0].text"); + + return text.GetString(); + } + +private: + void ValidateJsonType(const NJson::TJsonValue& value, NJson::EJsonValueType expectedType, const std::optional& fieldName = std::nullopt) const { + if (const auto valueType = value.GetType(); valueType != expectedType) { + throw yexception() << "Response" << (fieldName ? " field '" + *fieldName + "'" : "") << " of model has unexpected type: " << valueType << ", expected type: " << expectedType; + } + } + + void ValidateJsonArraySize(const NJson::TJsonValue& value, size_t expectedSize, const std::optional& fieldName = std::nullopt) const { + if (const auto valueSize = value.GetArray().size(); valueSize != expectedSize) { + throw yexception() << "Response" << (fieldName ? " field '" + *fieldName + "'" : "") << " of model has unexpected size: " << valueSize << ", expected size: " << expectedSize; + } + } + + const NJson::TJsonValue& ValidateJsonKey(const NJson::TJsonValue& value, const TString& key, const std::optional& fieldName = std::nullopt) const { + const auto* output = value.GetMap().FindPtr(key); + if (!output) { + throw yexception() << "Response of model does not contain '" << key << "' field" << (fieldName ? " in '" + *fieldName + "'" : ""); + } + + return *output; + } + +private: + const NYql::IHTTPGateway::TPtr HttpGateway; + TOpenAiModelSettings Settings; + + std::vector Conversation; +}; + +} // anonymous namespace + +IModel::TPtr CreateOpenAiModel(const TOpenAiModelSettings& settings) { + return std::make_shared(NYql::IHTTPGateway::Make(), settings); +} + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.h new file mode 100644 index 000000000000..9c09109e9b86 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.h @@ -0,0 +1,17 @@ +#pragma once + +#include "model_interface.h" + +#include + +namespace NYdb::NConsoleClient::NAi { + +struct TOpenAiModelSettings { + TString BaseUrl; // AI-TODO KIKIMR-24211 add default value + TString ModelId; + TString ApiKey; +}; + +IModel::TPtr CreateOpenAiModel(const TOpenAiModelSettings& settings); + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/ya.make b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/ya.make new file mode 100644 index 000000000000..1420896a56df --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/ya.make @@ -0,0 +1,15 @@ +LIBRARY() + +SRCS( + model_openai.cpp +) + +PEERDIR( + library/cpp/json + library/cpp/json/writer + library/cpp/string_utils/url + library/cpp/threading/future + ydb/library/yql/providers/common/http_gateway +) + +END() diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/ya.make b/ydb/public/lib/ydb_cli/commands/ydb_ai/ya.make index 2ce9b7aa0b07..b452545e9473 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/ya.make +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/ya.make @@ -10,3 +10,7 @@ PEERDIR( ) END() + +RECURSE( + models +) From 0ad8cac6b5c876cfaef271096ba773365c4fc5a6 Mon Sep 17 00:00:00 2001 From: Pisarenko Grigoriy Date: Sun, 23 Nov 2025 13:57:31 +0300 Subject: [PATCH 3/7] Supported exec query tool calling --- ydb/public/lib/ydb_cli/commands/ya.make | 1 + ydb/public/lib/ydb_cli/commands/ydb_ai.cpp | 37 ++++- ydb/public/lib/ydb_cli/commands/ydb_ai.h | 4 +- .../commands/ydb_ai/models/model_interface.h | 21 ++- .../commands/ydb_ai/models/model_openai.cpp | 146 +++++++++++++++--- .../commands/ydb_ai/tools/exec_query_tool.cpp | 99 ++++++++++++ .../commands/ydb_ai/tools/exec_query_tool.h | 11 ++ .../commands/ydb_ai/tools/tool_interface.h | 26 ++++ .../lib/ydb_cli/commands/ydb_ai/tools/ya.make | 14 ++ .../lib/ydb_cli/commands/ydb_ai/ya.make | 1 + 10 files changed, 330 insertions(+), 30 deletions(-) create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.cpp create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.h create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai/tools/tool_interface.h create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai/tools/ya.make diff --git a/ydb/public/lib/ydb_cli/commands/ya.make b/ydb/public/lib/ydb_cli/commands/ya.make index 256c0240d2d7..c060b79fe6fa 100644 --- a/ydb/public/lib/ydb_cli/commands/ya.make +++ b/ydb/public/lib/ydb_cli/commands/ya.make @@ -60,6 +60,7 @@ PEERDIR( ydb/public/lib/ydb_cli/commands/transfer_workload ydb/public/lib/ydb_cli/commands/ydb_ai ydb/public/lib/ydb_cli/commands/ydb_ai/models + ydb/public/lib/ydb_cli/commands/ydb_ai/tools ydb/public/lib/ydb_cli/commands/ydb_discovery ydb/public/lib/ydb_cli/common ydb/public/lib/ydb_cli/dump diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp index 94230e8f8744..7d07faf56efc 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -21,7 +22,7 @@ TCommandAi::TCommandAi() {} void TCommandAi::Config(TConfig& config) { - TClientCommand::Config(config); + TBase::Config(config); config.Opts->SetTitle("AI-TODO: KIKIMR-24198 -- title"); config.Opts->SetFreeArgsNum(0); } @@ -39,8 +40,12 @@ int TCommandAi::Run(TConfig& config) { .ApiKey = GetEnv("MODEL_TOKEN"), // AI-TODO: KIKIMR-24214 -- configure it }); + const auto sqlTool = NAi::CreateExecQueryTool(config); // AI-TODO: more generic tools registration + model->RegisterTool(sqlTool->GetName(), sqlTool->GetParametersSchema(), sqlTool->GetDescription()); + + // AI-TODO: there is strange highlighting of brackets while (const auto& maybeLine = lineReader.ReadLine()) { - const auto& input = *maybeLine; + TString input = *maybeLine; if (input.empty()) { continue; } @@ -50,7 +55,33 @@ int TCommandAi::Run(TConfig& config) { return EXIT_SUCCESS; } - Cout << "Model answer:\n" << model->Chat(input) << Endl; + // AI-TODO: limit interaction number + std::optional toolCallId; + while (input) { + // AI-TODO: progress visualization + auto output = model->HandleMessage(input); + Y_ENSURE(output.Text || output.ToolCall); + + if (output.Text) { + // AI-TODO: proper answer format + Cout << "Model answer:\n" << *output.Text << Endl; + } + + if (!output.ToolCall) { + break; + } + + const auto& toolCall = *output.ToolCall; + if (toolCall.Name != sqlTool->GetName()) { + // AI-TODO: proper wrong tool handling + Cout << "Unsupported tool: " << toolCall.Name << Endl; + return EXIT_FAILURE; + } + + // AI-TODO: ask permission before call and show progress + toolCallId = toolCall.Id; + input = sqlTool->Execute(toolCall.Parameters); + } } PrintExitMessage(); diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai.h b/ydb/public/lib/ydb_cli/commands/ydb_ai.h index 87d0d825a2d3..1385cea31908 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai.h +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai.h @@ -10,9 +10,9 @@ class TCommandAi final : public TYdbCommand { public: TCommandAi(); - void Config(TConfig& config) override; + void Config(TConfig& config) final; - int Run(TConfig& config) override; + int Run(TConfig& config) final; }; } // namespace NYdb::NConsoleClient diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_interface.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_interface.h index 03d4e853eb54..cf7846c6fa82 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_interface.h +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_interface.h @@ -1,6 +1,8 @@ #pragma once -#include +#include + +#include #include @@ -10,7 +12,22 @@ class IModel { public: using TPtr = std::shared_ptr; - virtual TString Chat(const TString& input) = 0; + virtual ~IModel() = default; + + struct TResponse { + struct TToolCall { + TString Id; + TString Name; + NJson::TJsonValue Parameters; + }; + + std::optional Text; + std::optional ToolCall; + }; + + virtual TResponse HandleMessage(const TString& input, std::optional toolCallId = std::nullopt) = 0; + + virtual void RegisterTool(const TString& name, const NJson::TJsonValue& parametersSchema, const TString& description) = 0; }; } // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp index c8875a1f4e4f..2f12c5a2578b 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp @@ -15,11 +15,39 @@ namespace NYdb::NConsoleClient::NAi { namespace { class TModelOpenAi final : public IModel { + class TToolInfo { + public: + TToolInfo(const TString& name, const NJson::TJsonValue& parametersSchema, const TString& description) + : Name(name) + , ParametersSchema(parametersSchema) + , Description(description) + {} + + NJson::TJsonValue ToJson() const { + NJson::TJsonValue result; + result["type"] = "function"; + + auto& toolJson = result["function"]; + toolJson["strict"] = false; // AI-TODO: enable after fixes + toolJson["name"] = Name; + toolJson["parameters"] = ParametersSchema; + toolJson["description"] = Description; + + return result; + } + + private: + TString Name; + NJson::TJsonValue ParametersSchema; + TString Description; + }; + class TConversationPart { public: enum class ERole { User, AI, + Tool, }; TConversationPart(const TString& content, ERole role) @@ -27,18 +55,33 @@ class TModelOpenAi final : public IModel { , Role(role) {} + TConversationPart(const TString& content, const TString& toolCallId) + : Content(content) + , Role(ERole::Tool) + , ToolCallId(toolCallId) + {} + NJson::TJsonValue ToJson() const { NJson::TJsonValue result; result["content"] = Content; auto& roleJson = result["role"]; switch (Role) { - case ERole::User: + case ERole::User: { roleJson = "user"; break; - case ERole::AI: + } + case ERole::AI: { roleJson = "assistant"; break; + } + case ERole::Tool: { + roleJson = "tool"; + + Y_ENSURE(ToolCallId); + result["tool_call_id"] = *ToolCallId; + break; + } } return result; @@ -46,7 +89,8 @@ class TModelOpenAi final : public IModel { private: TString Content; - ERole Role; + ERole Role = ERole::User; + std::optional ToolCallId; }; public: @@ -66,35 +110,45 @@ class TModelOpenAi final : public IModel { Settings.BaseUrl = RemoveFinalSlash(sanitizedUrl); } - TString Chat(const TString& input) final { + TResponse HandleMessage(const TString& input, std::optional toolCallId) final { + if (toolCallId) { + Conversation.emplace_back(input, *toolCallId); + } else { + Conversation.emplace_back(input, TConversationPart::ERole::User); + } + NJson::TJsonValue bodyJson; bodyJson["model"] = Settings.ModelId; - Conversation.emplace_back(input, TConversationPart::ERole::User); - auto& inputJson = bodyJson["input"]; - inputJson.SetType(NJson::JSON_ARRAY); - - auto& conversationJson = inputJson.GetArraySafe(); + auto& conversationJson = bodyJson["messages"].SetType(NJson::JSON_ARRAY).GetArraySafe(); for (const auto& part : Conversation) { conversationJson.push_back(part.ToJson()); } + auto& toolsArray = bodyJson["tools"].SetType(NJson::JSON_ARRAY).GetArraySafe(); + for (const auto& tool : Tools) { + toolsArray.push_back(tool.ToJson()); + } + NJsonWriter::TBuf bodyWriter; bodyWriter.WriteJsonValue(&bodyJson); NYql::THttpHeader headers = {.Fields = {"Content-Type: application/json"}}; + // Cerr << "-------------------------- Request: " << bodyWriter.Str(); + if (Settings.ApiKey) { headers.Fields.emplace_back(TStringBuilder() << "Authorization: Bearer " << Settings.ApiKey); } auto answer = NThreading::NewPromise(); HttpGateway->Upload( - TStringBuilder() << Settings.BaseUrl << "/responses", + TStringBuilder() << Settings.BaseUrl << "/chat/completions", std::move(headers), bodyWriter.Str(), [&answer](NYql::IHTTPGateway::TResult result) { if (result.CurlResponseCode != CURLE_OK) { + // AI-TODO: proper error handling answer.SetException(TStringBuilder() << "Request model failed: " << result.Issues.ToOneLineString() << ", internal code: " << curl_easy_strerror(result.CurlResponseCode) << ", response: " << result.Content.Extract()); return; } @@ -110,6 +164,7 @@ class TModelOpenAi final : public IModel { ); const auto result = answer.GetFuture().ExtractValueSync(); + // Cerr << "-------------------------- Result: " << result; NJson::TJsonValue resultJson; if (!NJson::ReadJsonTree(result, &resultJson)) { throw yexception() << "Response of model is not JSON, got response: " << result; @@ -117,24 +172,68 @@ class TModelOpenAi final : public IModel { ValidateJsonType(resultJson, NJson::JSON_MAP); - const auto& output = ValidateJsonKey(resultJson, "output"); - ValidateJsonType(output, NJson::JSON_ARRAY, "output"); - ValidateJsonArraySize(output, 1, "output"); + const auto& choices = ValidateJsonKey(resultJson, "choices"); + ValidateJsonType(choices, NJson::JSON_ARRAY, "choices"); + ValidateJsonArraySize(choices, 1, "choices"); + + // AI-TODO: proper error description + const auto& choiceVal = choices.GetArray()[0]; + ValidateJsonType(choiceVal, NJson::JSON_MAP, "choices[0]"); - const auto& outputVal = output.GetArray()[0]; - ValidateJsonType(outputVal, NJson::JSON_MAP, "output[0]"); + const auto& message = ValidateJsonKey(choiceVal, "message", "choices[0]"); + ValidateJsonType(message, NJson::JSON_MAP, "choices[0].message"); - const auto& content = ValidateJsonKey(outputVal, "content", "output[0]"); - ValidateJsonType(content, NJson::JSON_ARRAY, "output[0].content"); - ValidateJsonArraySize(content, 1, "output[0].content"); + const auto& content = ValidateJsonKey(message, "content", "choices[0].message"); + const auto& tollsCalls = ValidateJsonKey(message, "tool_calls", "choices[0].message"); + + if (content.GetType() == NJson::JSON_NULL && tollsCalls.GetType() == NJson::JSON_NULL) { + throw yexception() << "Response of model does not contain 'choices[0].message.content' or 'choices[0].message.tool_calls' fields, got response: " << result; + } + + TResponse response; + + if (content.GetType() != NJson::JSON_NULL) { + ValidateJsonType(content, NJson::JSON_STRING, "choices[0].message.content"); + response.Text = content.GetString(); + } - const auto& contentVal = content.GetArray()[0]; - ValidateJsonType(contentVal, NJson::JSON_MAP, "output[0].content[0]"); + if (tollsCalls.GetType() != NJson::JSON_NULL) { + ValidateJsonType(tollsCalls, NJson::JSON_ARRAY, "choices[0].message.tool_calls"); + ValidateJsonArraySize(tollsCalls, 1, "choices[0].message.tool_calls"); - const auto& text = ValidateJsonKey(contentVal, "text", "output[0].content[0]"); - ValidateJsonType(text, NJson::JSON_STRING, "output[0].content[0].text"); + const auto& toolCall = tollsCalls.GetArray()[0]; + ValidateJsonType(toolCall, NJson::JSON_MAP, "choices[0].message.tool_calls[0]"); + + const auto& function = ValidateJsonKey(toolCall, "function", "choices[0].message.tool_calls[0]"); + ValidateJsonType(function, NJson::JSON_MAP, "choices[0].message.tool_calls[0].function"); + + const auto& name = ValidateJsonKey(function, "name", "choices[0].message.tool_calls[0].function"); + ValidateJsonType(name, NJson::JSON_STRING, "choices[0].message.tool_calls[0].function.name"); + + const auto& arguments = ValidateJsonKey(function, "arguments", "choices[0].message.tool_calls[0].function"); + ValidateJsonType(arguments, NJson::JSON_STRING, "choices[0].message.tool_calls[0].function.arguments"); + + const auto& callId = ValidateJsonKey(toolCall, "id", "choices[0].message.tool_calls[0]"); + ValidateJsonType(callId, NJson::JSON_STRING, "choices[0].message.tool_calls[0].id"); + + NJson::TJsonValue argumentsJson; + if (!NJson::ReadJsonTree(arguments.GetString(), &argumentsJson)) { + throw yexception() << "Tool call arguments is not valid JSON, got response: " << arguments.GetString(); + } + + response.ToolCall = { + .Id = callId.GetString(), + .Name = name.GetString(), + .Parameters = std::move(argumentsJson) + }; + } + + Y_ENSURE(response.Text || response.ToolCall); + return response; + } - return text.GetString(); + void RegisterTool(const TString& name, const NJson::TJsonValue& parametersSchema, const TString& description) final { + Tools.emplace_back(name, parametersSchema, description); } private: @@ -163,6 +262,7 @@ class TModelOpenAi final : public IModel { const NYql::IHTTPGateway::TPtr HttpGateway; TOpenAiModelSettings Settings; + std::vector Tools; std::vector Conversation; }; diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.cpp new file mode 100644 index 000000000000..afaea2c44257 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.cpp @@ -0,0 +1,99 @@ +#include "exec_query_tool.h" + +#include +#include + +#include + +namespace NYdb::NConsoleClient::NAi { + +namespace { + +class TExecQueryTool final : public ITool { +public: + explicit TExecQueryTool(TClientCommand::TConfig& config) + : Client(TDriver(config.CreateDriverConfig())) + { + NJson::TJsonValue sqlParam; + sqlParam["type"] = "string"; + + ParametersSchema["sql"] = sqlParam; + } + + TString GetName() const final { + return "execute_sql_query"; + } + + NJson::TJsonValue GetParametersSchema() const final { + return ParametersSchema; + } + + TString GetDescription() const final { + return "Execute SQL query"; // AI-TODO: proper description + } + + TString Execute(const NJson::TJsonValue& parameters) final { + ValidateJsonType(parameters, NJson::JSON_MAP); + + const auto& sql = ValidateJsonKey(parameters, "sql"); + ValidateJsonType(sql, NJson::JSON_STRING, "sql"); + + const auto& sqlString = sql.GetString(); + Cerr << "\n!! Execute SQL query: " << sqlString << Endl; // AI-TODO: proper query execution printing + + // AI-TODO: streaming execution + auto result = Client.ExecuteQuery(sqlString, NQuery::TTxControl::NoTx()).ExtractValueSync(); + + // AI-TODO: proper error printing + if (!result.IsSuccess()) { + return TStringBuilder() << "Error executing SQL query, status: " << result.GetStatus() << ", issues: " << result.GetIssues().ToString(); + } + + const auto& resultSets = result.GetResultSets(); + + // AI-TODO: proper result formating + TStringBuilder resultBuilder; + for (ui64 i = 0; i < resultSets.size(); ++i) { + resultBuilder << "Result set " << i << ":\n" << Endl; + + TResultSetParser parser(resultSets[i]); + while (parser.TryNextRow()) { + NJsonWriter::TBuf writer(NJsonWriter::HEM_UNSAFE, &resultBuilder.Out); + FormatResultRowJson(parser, resultSets[i].GetColumnsMeta(), writer, EBinaryStringEncoding::Unicode); + resultBuilder << "\n"; + } + } + + Cerr << "\n!! Execute query result: " << resultBuilder << Endl; // AI-TODO: proper query result printing + + return resultBuilder; + } + +private: + void ValidateJsonType(const NJson::TJsonValue& value, NJson::EJsonValueType expectedType, const std::optional& fieldName = std::nullopt) const { + if (const auto valueType = value.GetType(); valueType != expectedType) { + throw yexception() << "Tool request " << (fieldName ? " field '" + *fieldName + "'" : "") << " has unexpected type: " << valueType << ", expected type: " << expectedType; + } + } + + const NJson::TJsonValue& ValidateJsonKey(const NJson::TJsonValue& value, const TString& key, const std::optional& fieldName = std::nullopt) const { + const auto* output = value.GetMap().FindPtr(key); + if (!output) { + throw yexception() << "Tool request does not contain '" << key << "' field" << (fieldName ? " in '" + *fieldName + "'" : ""); + } + + return *output; + } + +private: + NQuery::TQueryClient Client; + NJson::TJsonValue ParametersSchema; +}; + +} // anonymous namespace + +ITool::TPtr CreateExecQueryTool(TClientCommand::TConfig& config) { + return std::make_shared(config); +} + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.h new file mode 100644 index 000000000000..a7e8407f5305 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.h @@ -0,0 +1,11 @@ +#pragma once + +#include "tool_interface.h" + +#include + +namespace NYdb::NConsoleClient::NAi { + +ITool::TPtr CreateExecQueryTool(TClientCommand::TConfig& config); + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/tool_interface.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/tool_interface.h new file mode 100644 index 000000000000..128fbd4b15ee --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/tool_interface.h @@ -0,0 +1,26 @@ +#pragma once + +#include + +#include + +#include + +namespace NYdb::NConsoleClient::NAi { + +class ITool { +public: + using TPtr = std::shared_ptr; + + virtual ~ITool() = default; + + virtual TString GetName() const = 0; + + virtual NJson::TJsonValue GetParametersSchema() const = 0; + + virtual TString GetDescription() const = 0; + + virtual TString Execute(const NJson::TJsonValue& parameters) = 0; +}; + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/ya.make b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/ya.make new file mode 100644 index 000000000000..9e3583c729ab --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/ya.make @@ -0,0 +1,14 @@ +LIBRARY() + +SRCS( + exec_query_tool.cpp +) + +PEERDIR( + library/cpp/json/writer + ydb/public/lib/json_value + ydb/public/lib/ydb_cli/common + ydb/public/sdk/cpp/src/client/query +) + +END() diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/ya.make b/ydb/public/lib/ydb_cli/commands/ydb_ai/ya.make index b452545e9473..78586dda5f78 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/ya.make +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/ya.make @@ -13,4 +13,5 @@ END() RECURSE( models + tools ) From 5593ac69ad3d3b22f7478f475d0d9de54048bf91 Mon Sep 17 00:00:00 2001 From: Pisarenko Grigoriy Date: Sun, 23 Nov 2025 14:19:46 +0300 Subject: [PATCH 4/7] Supported yandex GPT calling --- ydb/public/lib/ydb_cli/commands/ydb_ai.cpp | 12 +++++- .../commands/ydb_ai/models/model_openai.cpp | 38 +++++++++++-------- .../commands/ydb_ai/models/model_openai.h | 2 +- .../commands/ydb_ai/tools/exec_query_tool.cpp | 6 ++- 4 files changed, 39 insertions(+), 19 deletions(-) diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp index 7d07faf56efc..7170a851c628 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp @@ -34,9 +34,17 @@ int TCommandAi::Run(TConfig& config) { // AI-TODO: KIKIMR-24202 - robust file creation NAi::TLineReader lineReader("ydb-ai> ", (TFsPath(HomeDir) / ".ydb-ai/history").GetPath()); + + // DeepSeek + // const auto model = NAi::CreateOpenAiModel({ + // .BaseUrl = "https://api.eliza.yandex.net/raw/internal/deepseek/v1", // AI-TODO: KIKIMR-24214 -- configure it + // .ModelId = "deepseek-0324", // AI-TODO: KIKIMR-24214 -- configure it + // .ApiKey = GetEnv("MODEL_TOKEN"), // AI-TODO: KIKIMR-24214 -- configure it + // }); + + // YandexGPT Pro const auto model = NAi::CreateOpenAiModel({ - .BaseUrl = "https://api.eliza.yandex.net/raw/internal/deepseek/v1", // AI-TODO: KIKIMR-24214 -- configure it - .ModelId = "deepseek-0324", // AI-TODO: KIKIMR-24214 -- configure it + .BaseUrl = "https://api.eliza.yandex.net/internal/zeliboba/32b_aligned_quantized_202506/generative/v1", // AI-TODO: KIKIMR-24214 -- configure it .ApiKey = GetEnv("MODEL_TOKEN"), // AI-TODO: KIKIMR-24214 -- configure it }); diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp index 2f12c5a2578b..b0b37fbb1013 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp @@ -118,7 +118,10 @@ class TModelOpenAi final : public IModel { } NJson::TJsonValue bodyJson; - bodyJson["model"] = Settings.ModelId; + + if (Settings.ModelId) { + bodyJson["model"] = *Settings.ModelId; + } auto& conversationJson = bodyJson["messages"].SetType(NJson::JSON_ARRAY).GetArraySafe(); for (const auto& part : Conversation) { @@ -172,6 +175,11 @@ class TModelOpenAi final : public IModel { ValidateJsonType(resultJson, NJson::JSON_MAP); + const auto& resultMap = resultJson.GetMap(); + if (const auto it = resultMap.find("response"); it != resultMap.end()) { + resultJson = it->second; + } + const auto& choices = ValidateJsonKey(resultJson, "choices"); ValidateJsonType(choices, NJson::JSON_ARRAY, "choices"); ValidateJsonArraySize(choices, 1, "choices"); @@ -183,25 +191,25 @@ class TModelOpenAi final : public IModel { const auto& message = ValidateJsonKey(choiceVal, "message", "choices[0]"); ValidateJsonType(message, NJson::JSON_MAP, "choices[0].message"); - const auto& content = ValidateJsonKey(message, "content", "choices[0].message"); - const auto& tollsCalls = ValidateJsonKey(message, "tool_calls", "choices[0].message"); - - if (content.GetType() == NJson::JSON_NULL && tollsCalls.GetType() == NJson::JSON_NULL) { + const auto& messageMap = message.GetMap(); + const auto* content = messageMap.FindPtr("content"); + const auto* tollsCalls = messageMap.FindPtr("tool_calls"); + if ((!content || content->GetType() == NJson::JSON_NULL) && (!tollsCalls || tollsCalls->GetType() == NJson::JSON_NULL)) { throw yexception() << "Response of model does not contain 'choices[0].message.content' or 'choices[0].message.tool_calls' fields, got response: " << result; } TResponse response; - if (content.GetType() != NJson::JSON_NULL) { - ValidateJsonType(content, NJson::JSON_STRING, "choices[0].message.content"); - response.Text = content.GetString(); + if (content && content->GetType() != NJson::JSON_NULL) { + ValidateJsonType(*content, NJson::JSON_STRING, "choices[0].message.content"); + response.Text = content->GetString(); } - if (tollsCalls.GetType() != NJson::JSON_NULL) { - ValidateJsonType(tollsCalls, NJson::JSON_ARRAY, "choices[0].message.tool_calls"); - ValidateJsonArraySize(tollsCalls, 1, "choices[0].message.tool_calls"); + if (tollsCalls && tollsCalls->GetType() != NJson::JSON_NULL) { + ValidateJsonType(*tollsCalls, NJson::JSON_ARRAY, "choices[0].message.tool_calls"); + ValidateJsonArraySize(*tollsCalls, 1, "choices[0].message.tool_calls"); - const auto& toolCall = tollsCalls.GetArray()[0]; + const auto& toolCall = tollsCalls->GetArray()[0]; ValidateJsonType(toolCall, NJson::JSON_MAP, "choices[0].message.tool_calls[0]"); const auto& function = ValidateJsonKey(toolCall, "function", "choices[0].message.tool_calls[0]"); @@ -239,20 +247,20 @@ class TModelOpenAi final : public IModel { private: void ValidateJsonType(const NJson::TJsonValue& value, NJson::EJsonValueType expectedType, const std::optional& fieldName = std::nullopt) const { if (const auto valueType = value.GetType(); valueType != expectedType) { - throw yexception() << "Response" << (fieldName ? " field '" + *fieldName + "'" : "") << " of model has unexpected type: " << valueType << ", expected type: " << expectedType; + throw yexception() << "Response" << (fieldName ? " field '" + *fieldName + "'" : "") << " of model has unexpected type: " << valueType << ", expected type: " << expectedType << ", got response: " << value; } } void ValidateJsonArraySize(const NJson::TJsonValue& value, size_t expectedSize, const std::optional& fieldName = std::nullopt) const { if (const auto valueSize = value.GetArray().size(); valueSize != expectedSize) { - throw yexception() << "Response" << (fieldName ? " field '" + *fieldName + "'" : "") << " of model has unexpected size: " << valueSize << ", expected size: " << expectedSize; + throw yexception() << "Response" << (fieldName ? " field '" + *fieldName + "'" : "") << " of model has unexpected size: " << valueSize << ", expected size: " << expectedSize << ", got response: " << value; } } const NJson::TJsonValue& ValidateJsonKey(const NJson::TJsonValue& value, const TString& key, const std::optional& fieldName = std::nullopt) const { const auto* output = value.GetMap().FindPtr(key); if (!output) { - throw yexception() << "Response of model does not contain '" << key << "' field" << (fieldName ? " in '" + *fieldName + "'" : ""); + throw yexception() << "Response of model does not contain '" << key << "' field" << (fieldName ? " in '" + *fieldName + "'" : "") << ", got response: " << value; } return *output; diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.h index 9c09109e9b86..07570f69092b 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.h +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.h @@ -8,7 +8,7 @@ namespace NYdb::NConsoleClient::NAi { struct TOpenAiModelSettings { TString BaseUrl; // AI-TODO KIKIMR-24211 add default value - TString ModelId; + std::optional ModelId; TString ApiKey; }; diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.cpp index afaea2c44257..25b3f05750c0 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.cpp +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.cpp @@ -16,8 +16,11 @@ class TExecQueryTool final : public ITool { { NJson::TJsonValue sqlParam; sqlParam["type"] = "string"; + sqlParam["description"] = "SQL query"; // AI-TODO: proper description - ParametersSchema["sql"] = sqlParam; + ParametersSchema["properties"]["sql"] = sqlParam; + ParametersSchema["type"] = "object"; + ParametersSchema["required"][0] = "sql"; } TString GetName() const final { @@ -46,6 +49,7 @@ class TExecQueryTool final : public ITool { // AI-TODO: proper error printing if (!result.IsSuccess()) { + Cerr << "\n!! Execute query error [" << result.GetStatus() << "]: " << result.GetIssues().ToString() << Endl; return TStringBuilder() << "Error executing SQL query, status: " << result.GetStatus() << ", issues: " << result.GetIssues().ToString(); } From c621a8269f363067778cd706a1807ada5dd1d57b Mon Sep 17 00:00:00 2001 From: Pisarenko Grigoriy Date: Sun, 23 Nov 2025 15:42:57 +0300 Subject: [PATCH 5/7] Supported directory listing --- ydb/public/lib/ydb_cli/commands/ydb_ai.cpp | 53 +++++---- .../commands/ydb_ai/models/model_interface.h | 9 +- .../commands/ydb_ai/models/model_openai.cpp | 54 +++++----- .../commands/ydb_ai/tools/exec_query_tool.cpp | 1 + .../ydb_ai/tools/list_directory_tool.cpp | 101 ++++++++++++++++++ .../ydb_ai/tools/list_directory_tool.h | 11 ++ .../lib/ydb_cli/commands/ydb_ai/tools/ya.make | 3 + 7 files changed, 184 insertions(+), 48 deletions(-) create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai/tools/list_directory_tool.cpp create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai/tools/list_directory_tool.h diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp index 7170a851c628..b09f041ac42d 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -48,12 +49,22 @@ int TCommandAi::Run(TConfig& config) { .ApiKey = GetEnv("MODEL_TOKEN"), // AI-TODO: KIKIMR-24214 -- configure it }); - const auto sqlTool = NAi::CreateExecQueryTool(config); // AI-TODO: more generic tools registration - model->RegisterTool(sqlTool->GetName(), sqlTool->GetParametersSchema(), sqlTool->GetDescription()); + std::unordered_map tools; + + const auto sqlTool = NAi::CreateExecQueryTool(config); + tools.emplace(sqlTool->GetName(), sqlTool); + + const auto lsTool = NAi::CreateListDirectoryTool(config); + tools.emplace(lsTool->GetName(), lsTool); + + for (const auto& [name, tool] : tools) { + model->RegisterTool(name, tool->GetParametersSchema(), tool->GetDescription()); + } // AI-TODO: there is strange highlighting of brackets + std::vector requests; while (const auto& maybeLine = lineReader.ReadLine()) { - TString input = *maybeLine; + const auto& input = *maybeLine; if (input.empty()) { continue; } @@ -64,31 +75,33 @@ int TCommandAi::Run(TConfig& config) { } // AI-TODO: limit interaction number - std::optional toolCallId; - while (input) { + requests.push_back({.Text = input}); + while (!requests.empty()) { // AI-TODO: progress visualization - auto output = model->HandleMessage(input); - Y_ENSURE(output.Text || output.ToolCall); + auto output = model->HandleMessages(requests); + requests.clear(); + Y_ENSURE(output.Text || !output.ToolCalls.empty()); if (output.Text) { // AI-TODO: proper answer format + // AI-TODO: how can I render markdown? Cout << "Model answer:\n" << *output.Text << Endl; } - if (!output.ToolCall) { - break; + for (const auto& toolCall : output.ToolCalls) { + const auto it = tools.find(toolCall.Name); + if (it == tools.end()) { + // AI-TODO: proper wrong tool handling + Cout << "Unsupported tool: " << toolCall.Name << Endl; + return EXIT_FAILURE; + } + + // AI-TODO: ask permission before call and show progress + requests.push_back({ + .Text = it->second->Execute(toolCall.Parameters), + .ToolCallId = toolCall.Id + }); } - - const auto& toolCall = *output.ToolCall; - if (toolCall.Name != sqlTool->GetName()) { - // AI-TODO: proper wrong tool handling - Cout << "Unsupported tool: " << toolCall.Name << Endl; - return EXIT_FAILURE; - } - - // AI-TODO: ask permission before call and show progress - toolCallId = toolCall.Id; - input = sqlTool->Execute(toolCall.Parameters); } } diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_interface.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_interface.h index cf7846c6fa82..bc830ab9181b 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_interface.h +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_interface.h @@ -14,6 +14,11 @@ class IModel { virtual ~IModel() = default; + struct TRequest { + TString Text; + std::optional ToolCallId; + }; + struct TResponse { struct TToolCall { TString Id; @@ -22,10 +27,10 @@ class IModel { }; std::optional Text; - std::optional ToolCall; + std::vector ToolCalls; }; - virtual TResponse HandleMessage(const TString& input, std::optional toolCallId = std::nullopt) = 0; + virtual TResponse HandleMessages(const std::vector& requests) = 0; virtual void RegisterTool(const TString& name, const NJson::TJsonValue& parametersSchema, const TString& description) = 0; }; diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp index b0b37fbb1013..53fd4292b736 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp @@ -110,11 +110,13 @@ class TModelOpenAi final : public IModel { Settings.BaseUrl = RemoveFinalSlash(sanitizedUrl); } - TResponse HandleMessage(const TString& input, std::optional toolCallId) final { - if (toolCallId) { - Conversation.emplace_back(input, *toolCallId); - } else { - Conversation.emplace_back(input, TConversationPart::ERole::User); + TResponse HandleMessages(const std::vector& requests) final { + for (const auto& request : requests) { + if (request.ToolCallId) { + Conversation.emplace_back(request.Text, *request.ToolCallId); + } else { + Conversation.emplace_back(request.Text, TConversationPart::ERole::User); + } } NJson::TJsonValue bodyJson; @@ -207,36 +209,36 @@ class TModelOpenAi final : public IModel { if (tollsCalls && tollsCalls->GetType() != NJson::JSON_NULL) { ValidateJsonType(*tollsCalls, NJson::JSON_ARRAY, "choices[0].message.tool_calls"); - ValidateJsonArraySize(*tollsCalls, 1, "choices[0].message.tool_calls"); - const auto& toolCall = tollsCalls->GetArray()[0]; - ValidateJsonType(toolCall, NJson::JSON_MAP, "choices[0].message.tool_calls[0]"); + for (const auto& toolCall : tollsCalls->GetArray()) { + ValidateJsonType(toolCall, NJson::JSON_MAP, "choices[0].message.tool_calls[0]"); - const auto& function = ValidateJsonKey(toolCall, "function", "choices[0].message.tool_calls[0]"); - ValidateJsonType(function, NJson::JSON_MAP, "choices[0].message.tool_calls[0].function"); + const auto& function = ValidateJsonKey(toolCall, "function", "choices[0].message.tool_calls[0]"); + ValidateJsonType(function, NJson::JSON_MAP, "choices[0].message.tool_calls[0].function"); - const auto& name = ValidateJsonKey(function, "name", "choices[0].message.tool_calls[0].function"); - ValidateJsonType(name, NJson::JSON_STRING, "choices[0].message.tool_calls[0].function.name"); + const auto& name = ValidateJsonKey(function, "name", "choices[0].message.tool_calls[0].function"); + ValidateJsonType(name, NJson::JSON_STRING, "choices[0].message.tool_calls[0].function.name"); - const auto& arguments = ValidateJsonKey(function, "arguments", "choices[0].message.tool_calls[0].function"); - ValidateJsonType(arguments, NJson::JSON_STRING, "choices[0].message.tool_calls[0].function.arguments"); + const auto& arguments = ValidateJsonKey(function, "arguments", "choices[0].message.tool_calls[0].function"); + ValidateJsonType(arguments, NJson::JSON_STRING, "choices[0].message.tool_calls[0].function.arguments"); - const auto& callId = ValidateJsonKey(toolCall, "id", "choices[0].message.tool_calls[0]"); - ValidateJsonType(callId, NJson::JSON_STRING, "choices[0].message.tool_calls[0].id"); + const auto& callId = ValidateJsonKey(toolCall, "id", "choices[0].message.tool_calls[0]"); + ValidateJsonType(callId, NJson::JSON_STRING, "choices[0].message.tool_calls[0].id"); - NJson::TJsonValue argumentsJson; - if (!NJson::ReadJsonTree(arguments.GetString(), &argumentsJson)) { - throw yexception() << "Tool call arguments is not valid JSON, got response: " << arguments.GetString(); - } + NJson::TJsonValue argumentsJson; + if (!NJson::ReadJsonTree(arguments.GetString(), &argumentsJson)) { + throw yexception() << "Tool call arguments is not valid JSON, got response: " << arguments.GetString(); + } - response.ToolCall = { - .Id = callId.GetString(), - .Name = name.GetString(), - .Parameters = std::move(argumentsJson) - }; + response.ToolCalls.push_back({ + .Id = callId.GetString(), + .Name = name.GetString(), + .Parameters = std::move(argumentsJson) + }); + } } - Y_ENSURE(response.Text || response.ToolCall); + Y_ENSURE(response.Text || !response.ToolCalls.empty()); return response; } diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.cpp index 25b3f05750c0..d34daaeca641 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.cpp +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.cpp @@ -74,6 +74,7 @@ class TExecQueryTool final : public ITool { } private: + // AI-TODO: reduce copypaste void ValidateJsonType(const NJson::TJsonValue& value, NJson::EJsonValueType expectedType, const std::optional& fieldName = std::nullopt) const { if (const auto valueType = value.GetType(); valueType != expectedType) { throw yexception() << "Tool request " << (fieldName ? " field '" + *fieldName + "'" : "") << " has unexpected type: " << valueType << ", expected type: " << expectedType; diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/list_directory_tool.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/list_directory_tool.cpp new file mode 100644 index 000000000000..714a3eae0e5e --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/list_directory_tool.cpp @@ -0,0 +1,101 @@ +#include "list_directory_tool.h" + +#include +#include + +namespace NYdb::NConsoleClient::NAi { + +namespace { + +class TListDirectoryTool final : public ITool { +public: + explicit TListDirectoryTool(TClientCommand::TConfig& config) + : Database(NKikimr::CanonizePath(config.Database)) + , Client(TDriver(config.CreateDriverConfig())) + { + NJson::TJsonValue dirParam; + dirParam["type"] = "string"; + dirParam["description"] = "Directory path to list (use empty to list root directory)"; // AI-TODO: proper description + + ParametersSchema["properties"]["directory"] = dirParam; + ParametersSchema["type"] = "object"; + ParametersSchema["required"][0] = "directory"; + } + + TString GetName() const final { + return "list_directory"; + } + + NJson::TJsonValue GetParametersSchema() const final { + return ParametersSchema; + } + + TString GetDescription() const final { + return "List directory"; // AI-TODO: proper description + } + + TString Execute(const NJson::TJsonValue& parameters) final { + ValidateJsonType(parameters, NJson::JSON_MAP); + + const auto& dir = ValidateJsonKey(parameters, "directory"); + ValidateJsonType(dir, NJson::JSON_STRING, "directory"); + + TString dirString = dir.GetString(); + Cerr << "\n!! List directory: " << dirString << Endl; // AI-TODO: proper list directory printing + + if (!dirString.StartsWith('/')) { + dirString = NKikimr::JoinPath({Database, dirString}); + } + + // AI-TODO: progress printing + auto result = Client.ListDirectory(dirString).ExtractValueSync(); + + // AI-TODO: proper error printing + if (!result.IsSuccess()) { + Cerr << "\n!! List directory error [" << result.GetStatus() << "]: " << result.GetIssues().ToString() << Endl; + return TStringBuilder() << "Error listing directory, status: " << result.GetStatus() << ", issues: " << result.GetIssues().ToString(); + } + + const auto& children = result.GetChildren(); + + // AI-TODO: proper result formating + TStringBuilder resultBuilder; + for (const auto& child : children) { + resultBuilder << child.Name << " (" << child.Type << ")" << "\n"; + } + + Cerr << "\n!! List directory result: " << resultBuilder << Endl; // AI-TODO: proper query result printing + + return resultBuilder; + } + +private: + // AI-TODO: reduce copypaste + void ValidateJsonType(const NJson::TJsonValue& value, NJson::EJsonValueType expectedType, const std::optional& fieldName = std::nullopt) const { + if (const auto valueType = value.GetType(); valueType != expectedType) { + throw yexception() << "Tool request " << (fieldName ? " field '" + *fieldName + "'" : "") << " has unexpected type: " << valueType << ", expected type: " << expectedType; + } + } + + const NJson::TJsonValue& ValidateJsonKey(const NJson::TJsonValue& value, const TString& key, const std::optional& fieldName = std::nullopt) const { + const auto* output = value.GetMap().FindPtr(key); + if (!output) { + throw yexception() << "Tool request does not contain '" << key << "' field" << (fieldName ? " in '" + *fieldName + "'" : ""); + } + + return *output; + } + +private: + const TString Database; + NScheme::TSchemeClient Client; + NJson::TJsonValue ParametersSchema; +}; + +} // anonymous namespace + +ITool::TPtr CreateListDirectoryTool(TClientCommand::TConfig& config) { + return std::make_shared(config); +} + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/list_directory_tool.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/list_directory_tool.h new file mode 100644 index 000000000000..175c5cdf0545 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/list_directory_tool.h @@ -0,0 +1,11 @@ +#pragma once + +#include "tool_interface.h" + +#include + +namespace NYdb::NConsoleClient::NAi { + +ITool::TPtr CreateListDirectoryTool(TClientCommand::TConfig& config); + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/ya.make b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/ya.make index 9e3583c729ab..22f7db175210 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/ya.make +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/ya.make @@ -2,13 +2,16 @@ LIBRARY() SRCS( exec_query_tool.cpp + list_directory_tool.cpp ) PEERDIR( library/cpp/json/writer + ydb/core/base ydb/public/lib/json_value ydb/public/lib/ydb_cli/common ydb/public/sdk/cpp/src/client/query + ydb/public/sdk/cpp/src/client/scheme ) END() From 20584e7565771d61ee2862e24429e6db35a59745 Mon Sep 17 00:00:00 2001 From: Pisarenko Grigoriy Date: Sun, 23 Nov 2025 16:59:56 +0300 Subject: [PATCH 6/7] Supported Anthropic models connect --- ydb/public/lib/ydb_cli/commands/ydb_ai.cpp | 18 +- .../ydb_ai/models/model_anthropic.cpp | 272 ++++++++++++++++++ .../commands/ydb_ai/models/model_anthropic.h | 16 ++ .../commands/ydb_ai/models/model_openai.cpp | 3 +- .../ydb_cli/commands/ydb_ai/models/ya.make | 1 + 5 files changed, 304 insertions(+), 6 deletions(-) create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.cpp create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.h diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp index b09f041ac42d..639820857e1a 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp @@ -1,6 +1,7 @@ #include "ydb_ai.h" #include +#include #include #include #include @@ -35,20 +36,27 @@ int TCommandAi::Run(TConfig& config) { // AI-TODO: KIKIMR-24202 - robust file creation NAi::TLineReader lineReader("ydb-ai> ", (TFsPath(HomeDir) / ".ydb-ai/history").GetPath()); - + // DeepSeek // const auto model = NAi::CreateOpenAiModel({ - // .BaseUrl = "https://api.eliza.yandex.net/raw/internal/deepseek/v1", // AI-TODO: KIKIMR-24214 -- configure it + // .BaseUrl = "https://api.eliza.yandex.net/raw/internal/deepseek", // AI-TODO: KIKIMR-24214 -- configure it // .ModelId = "deepseek-0324", // AI-TODO: KIKIMR-24214 -- configure it // .ApiKey = GetEnv("MODEL_TOKEN"), // AI-TODO: KIKIMR-24214 -- configure it // }); - // YandexGPT Pro - const auto model = NAi::CreateOpenAiModel({ - .BaseUrl = "https://api.eliza.yandex.net/internal/zeliboba/32b_aligned_quantized_202506/generative/v1", // AI-TODO: KIKIMR-24214 -- configure it + const auto model = NAi::CreateAnthropicModel({ + .BaseUrl = "https://api.eliza.yandex.net/anthropic", // AI-TODO: KIKIMR-24214 -- configure it + .ModelId = "claude-3-5-haiku-20241022", .ApiKey = GetEnv("MODEL_TOKEN"), // AI-TODO: KIKIMR-24214 -- configure it + .MaxTokens = 2048, // AI-TODO configure it }); + // YandexGPT Pro + // const auto model = NAi::CreateOpenAiModel({ + // .BaseUrl = "https://api.eliza.yandex.net/internal/zeliboba/32b_aligned_quantized_202506/generative", // AI-TODO: KIKIMR-24214 -- configure it + // .ApiKey = GetEnv("MODEL_TOKEN"), // AI-TODO: KIKIMR-24214 -- configure it + // }); + std::unordered_map tools; const auto sqlTool = NAi::CreateExecQueryTool(config); diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.cpp new file mode 100644 index 000000000000..b2d23d50f3a9 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.cpp @@ -0,0 +1,272 @@ +#include "model_anthropic.h" + +#include + +#include +#include +#include +#include +#include + +#include + +namespace NYdb::NConsoleClient::NAi { + +namespace { + +class TModelAnthropic final : public IModel { + class TToolInfo { + public: + TToolInfo(const TString& name, const NJson::TJsonValue& parametersSchema, const TString& description) + : Name(name) + , ParametersSchema(parametersSchema) + , Description(description) + {} + + NJson::TJsonValue ToJson() const { + NJson::TJsonValue result; + result["name"] = Name; + result["description"] = Description; + result["input_schema"] = ParametersSchema; + + return result; + } + + private: + TString Name; + NJson::TJsonValue ParametersSchema; + TString Description; + }; + + class TConversationPart { + public: + enum class ERole { + User, + AI, + Tool, + }; + + TConversationPart(const TString& content, ERole role) + : Content(content) + , Role(role) + {} + + TConversationPart(const TString& content, const TString& toolCallId) + : Content(content) + , Role(ERole::Tool) + , ToolCallId(toolCallId) + {} + + NJson::TJsonValue ToJson() const { + NJson::TJsonValue result; + switch (Role) { + case ERole::User: { + result["content"] = Content; + result["role"] = "user"; + break; + } + case ERole::AI: { + Y_ENSURE(AssistantPart); + result["content"] = *AssistantPart; + result["role"] = "assistant"; + break; + } + case ERole::Tool: { + result["role"] = "user"; + + auto& content = result["content"][0]; + content["type"] = "tool_result"; + content["content"] = Content; + + Y_ENSURE(ToolCallId); + content["tool_use_id"] = *ToolCallId; + break; + } + } + + return result; + } + + public: + std::optional AssistantPart; + TString Content; + + private: + ERole Role = ERole::User; + std::optional ToolCallId; + }; + +public: + TModelAnthropic(NYql::IHTTPGateway::TPtr httpGateway, const TAnthropicModelSettings& settings) + : HttpGateway(httpGateway) + , Settings(settings) + { + TStringBuf sanitizedUrl; + TStringBuf query; + TStringBuf fragment; + SeparateUrlFromQueryAndFragment(Settings.BaseUrl, sanitizedUrl, query, fragment); + + if (query || fragment) { + throw yexception() << "BaseUrl must not contain query or fragment, got url: '" << Settings.BaseUrl << "' with query: '" << query << "' or fragment: '" << fragment << "'"; + } + + Settings.BaseUrl = RemoveFinalSlash(sanitizedUrl); + } + + TResponse HandleMessages(const std::vector& requests) final { + for (const auto& request : requests) { + if (request.ToolCallId) { + Conversation.emplace_back(request.Text, *request.ToolCallId); + } else { + Conversation.emplace_back(request.Text, TConversationPart::ERole::User); + } + } + + NJson::TJsonValue bodyJson; + bodyJson["model"] = Settings.ModelId; + bodyJson["max_tokens"] = Settings.MaxTokens; + + auto& messagesArray = bodyJson["messages"].SetType(NJson::JSON_ARRAY).GetArraySafe(); + for (const auto& part : Conversation) { + messagesArray.push_back(part.ToJson()); + } + + if (!Tools.empty()) { + auto& toolsArray = bodyJson["tools"].SetType(NJson::JSON_ARRAY).GetArraySafe(); + for (const auto& tool : Tools) { + toolsArray.push_back(tool.ToJson()); + } + } + + NJsonWriter::TBuf bodyWriter; + bodyWriter.WriteJsonValue(&bodyJson); + + NYql::THttpHeader headers = {.Fields = { + "Content-Type: application/json", + "anthropic-version: 2023-06-01" + }}; + + if (Settings.ApiKey) { + headers.Fields.emplace_back(TStringBuilder() << "x-api-key: " << Settings.ApiKey); + } + + auto answer = NThreading::NewPromise(); + HttpGateway->Upload( + TStringBuilder() << Settings.BaseUrl << "/v1/messages", + std::move(headers), + bodyWriter.Str(), + [&answer](NYql::IHTTPGateway::TResult result) { + if (result.CurlResponseCode != CURLE_OK) { + answer.SetException(TStringBuilder() << "Request model failed: " << result.Issues.ToOneLineString() << ", internal code: " << curl_easy_strerror(result.CurlResponseCode) << ", response: " << result.Content.Extract()); + return; + } + + auto& content = result.Content; + if (content.HttpResponseCode < 200 || content.HttpResponseCode >= 300) { + answer.SetException(TStringBuilder() << "Request model failed, internal code: " << content.HttpResponseCode << ", response: " << result.Content.Extract()); + return; + } + + answer.SetValue(content.Extract()); + } + ); + + const auto result = answer.GetFuture().ExtractValueSync(); + NJson::TJsonValue resultJson; + if (!NJson::ReadJsonTree(result, &resultJson)) { + throw yexception() << "Response of model is not JSON, got response: " << result; + } + + ValidateJsonType(resultJson, NJson::JSON_MAP); + + const auto& resultMap = resultJson.GetMap(); + if (const auto it = resultMap.find("response"); it != resultMap.end()) { + resultJson = it->second; + } + + TResponse response; + + // Extract content from Anthropic response + const auto& content = ValidateJsonKey(resultJson, "content"); + ValidateJsonType(content, NJson::JSON_ARRAY, "content"); + + auto& conversationPart = Conversation.emplace_back(TString(), TConversationPart::ERole::AI); + conversationPart.AssistantPart = content; + for (const auto& contentItem : content.GetArray()) { + ValidateJsonType(contentItem, NJson::JSON_MAP, "content[i]"); + + const auto& type = ValidateJsonKey(contentItem, "type", "content[i]"); + ValidateJsonType(type, NJson::JSON_STRING, "content[i].type"); + + if (type.GetString() == "text") { + if (response.Text) { + throw yexception() << "Response of model contains multiple text responses, got response: " << result; + } + + const auto& text = ValidateJsonKey(contentItem, "text", "content[i]"); + ValidateJsonType(text, NJson::JSON_STRING, "content[i].text"); + response.Text = text.GetString(); + } else if (type.GetString() == "tool_use") { + const auto& id = ValidateJsonKey(contentItem, "id", "content[i]"); + ValidateJsonType(id, NJson::JSON_STRING, "content[i].id"); + + const auto& name = ValidateJsonKey(contentItem, "name", "content[i]"); + ValidateJsonType(name, NJson::JSON_STRING, "content[i].name"); + + const auto& input = ValidateJsonKey(contentItem, "input", "content[i]"); + + response.ToolCalls.push_back({ + .Id = id.GetString(), + .Name = name.GetString(), + .Parameters = input + }); + } else { + throw yexception() << "Response of model contains unknown type: " << type.GetString() << ", got response: " << result; + } + } + + Y_ENSURE(response.Text || !response.ToolCalls.empty()); + return response; + } + + void RegisterTool(const TString& name, const NJson::TJsonValue& parametersSchema, const TString& description) final { + Tools.emplace_back(name, parametersSchema, description); + } + +private: + void ValidateJsonType(const NJson::TJsonValue& value, NJson::EJsonValueType expectedType, const std::optional& fieldName = std::nullopt) const { + if (const auto valueType = value.GetType(); valueType != expectedType) { + throw yexception() << "Response" << (fieldName ? " field '" + *fieldName + "'" : "") << " of model has unexpected type: " << valueType << ", expected type: " << expectedType << ", got response: " << value; + } + } + + void ValidateJsonArraySize(const NJson::TJsonValue& value, size_t expectedSize, const std::optional& fieldName = std::nullopt) const { + if (const auto valueSize = value.GetArray().size(); valueSize != expectedSize) { + throw yexception() << "Response" << (fieldName ? " field '" + *fieldName + "'" : "") << " of model has unexpected size: " << valueSize << ", expected size: " << expectedSize << ", got response: " << value; + } + } + + const NJson::TJsonValue& ValidateJsonKey(const NJson::TJsonValue& value, const TString& key, const std::optional& fieldName = std::nullopt) const { + const auto* output = value.GetMap().FindPtr(key); + if (!output) { + throw yexception() << "Response of model does not contain '" << key << "' field" << (fieldName ? " in '" + *fieldName + "'" : "") << ", got response: " << value; + } + + return *output; + } + +private: + const NYql::IHTTPGateway::TPtr HttpGateway; + TAnthropicModelSettings Settings; + + std::vector Tools; + std::vector Conversation; +}; + +} // anonymous namespace + +IModel::TPtr CreateAnthropicModel(const TAnthropicModelSettings& settings) { + return std::make_shared(NYql::IHTTPGateway::Make(), settings); +} + +} // namespace NYdb::NConsoleClient::NAi \ No newline at end of file diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.h new file mode 100644 index 000000000000..152549d051d1 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.h @@ -0,0 +1,16 @@ +#pragma once + +#include "model_interface.h" + +namespace NYdb::NConsoleClient::NAi { + +struct TAnthropicModelSettings { + TString BaseUrl; // AI-TODO KIKIMR-24211 add default value + TString ModelId; + TString ApiKey; + ui64 MaxTokens; +}; + +IModel::TPtr CreateAnthropicModel(const TAnthropicModelSettings& settings); + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp index 53fd4292b736..41532aa37aab 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp @@ -148,7 +148,7 @@ class TModelOpenAi final : public IModel { auto answer = NThreading::NewPromise(); HttpGateway->Upload( - TStringBuilder() << Settings.BaseUrl << "/chat/completions", + TStringBuilder() << Settings.BaseUrl << "/v1/chat/completions", std::move(headers), bodyWriter.Str(), [&answer](NYql::IHTTPGateway::TResult result) { @@ -205,6 +205,7 @@ class TModelOpenAi final : public IModel { if (content && content->GetType() != NJson::JSON_NULL) { ValidateJsonType(*content, NJson::JSON_STRING, "choices[0].message.content"); response.Text = content->GetString(); + Conversation.emplace_back(*response.Text, TConversationPart::ERole::AI); } if (tollsCalls && tollsCalls->GetType() != NJson::JSON_NULL) { diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/ya.make b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/ya.make index 1420896a56df..5bef635b4afe 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/ya.make +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/ya.make @@ -1,6 +1,7 @@ LIBRARY() SRCS( + model_anthropic.cpp model_openai.cpp ) From a2c7388844430bdd9d47e4e82cc9152f6d99c43c Mon Sep 17 00:00:00 2001 From: Pisarenko Grigoriy Date: Mon, 24 Nov 2025 23:41:06 +0300 Subject: [PATCH 7/7] Refactored tools calling and models API --- ydb/core/base/validation.h | 7 + ydb/public/lib/ydb_cli/commands/ya.make | 1 + ydb/public/lib/ydb_cli/commands/ydb_ai.cpp | 92 +++-- .../commands/ydb_ai/common/json_utils.cpp | 223 ++++++++++++ .../commands/ydb_ai/common/json_utils.h | 87 +++++ .../ydb_cli/commands/ydb_ai/common/ya.make | 13 + .../ydb_ai/models/model_anthropic.cpp | 303 +++++----------- .../commands/ydb_ai/models/model_anthropic.h | 9 +- .../commands/ydb_ai/models/model_base.cpp | 145 ++++++++ .../commands/ydb_ai/models/model_base.h | 41 +++ .../commands/ydb_ai/models/model_interface.h | 15 +- .../commands/ydb_ai/models/model_openai.cpp | 323 ++++++------------ .../commands/ydb_ai/models/model_openai.h | 8 +- .../ydb_cli/commands/ydb_ai/models/ya.make | 3 + .../commands/ydb_ai/tools/exec_query_tool.cpp | 142 ++++---- .../ydb_ai/tools/list_directory_tool.cpp | 108 +++--- .../commands/ydb_ai/tools/tool_interface.cpp | 17 + .../commands/ydb_ai/tools/tool_interface.h | 17 +- .../lib/ydb_cli/commands/ydb_ai/tools/ya.make | 2 + 19 files changed, 942 insertions(+), 614 deletions(-) create mode 100644 ydb/core/base/validation.h create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai/common/json_utils.cpp create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai/common/json_utils.h create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai/common/ya.make create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_base.cpp create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_base.h create mode 100644 ydb/public/lib/ydb_cli/commands/ydb_ai/tools/tool_interface.cpp diff --git a/ydb/core/base/validation.h b/ydb/core/base/validation.h new file mode 100644 index 000000000000..5c708ea0ae65 --- /dev/null +++ b/ydb/core/base/validation.h @@ -0,0 +1,7 @@ +#pragma once + +#ifndef NDEBUG + #define Y_DEBUG_VERIFY Y_ABORT_UNLESS +#else + #define Y_DEBUG_VERIFY Y_VERIFY +#endif diff --git a/ydb/public/lib/ydb_cli/commands/ya.make b/ydb/public/lib/ydb_cli/commands/ya.make index c060b79fe6fa..0b41caec8060 100644 --- a/ydb/public/lib/ydb_cli/commands/ya.make +++ b/ydb/public/lib/ydb_cli/commands/ya.make @@ -59,6 +59,7 @@ PEERDIR( ydb/public/lib/ydb_cli/commands/topic_workload ydb/public/lib/ydb_cli/commands/transfer_workload ydb/public/lib/ydb_cli/commands/ydb_ai + ydb/public/lib/ydb_cli/commands/ydb_ai/common ydb/public/lib/ydb_cli/commands/ydb_ai/models ydb/public/lib/ydb_cli/commands/ydb_ai/tools ydb/public/lib/ydb_cli/commands/ydb_discovery diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp index 639820857e1a..0955b80af9a3 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp @@ -1,5 +1,6 @@ #include "ydb_ai.h" +#include #include #include #include @@ -9,6 +10,22 @@ #include #include +/* + +FEATURES-TODO: + +- Streamable model response printing +- Streamable results printing +- Adjusting errors, progress and response printing +- Approving before tool use +- Integration into common interactive mode +- Think about helps +- Think about robust +- Provide system promt +- Somehow render markdown + +*/ + namespace NYdb::NConsoleClient { namespace { @@ -30,8 +47,6 @@ void TCommandAi::Config(TConfig& config) { } int TCommandAi::Run(TConfig& config) { - Y_UNUSED(config); - Cout << "AI-TODO: KIKIMR-24198 -- welcome message" << Endl; // AI-TODO: KIKIMR-24202 - robust file creation @@ -42,35 +57,31 @@ int TCommandAi::Run(TConfig& config) { // .BaseUrl = "https://api.eliza.yandex.net/raw/internal/deepseek", // AI-TODO: KIKIMR-24214 -- configure it // .ModelId = "deepseek-0324", // AI-TODO: KIKIMR-24214 -- configure it // .ApiKey = GetEnv("MODEL_TOKEN"), // AI-TODO: KIKIMR-24214 -- configure it - // }); + // }, config); - const auto model = NAi::CreateAnthropicModel({ - .BaseUrl = "https://api.eliza.yandex.net/anthropic", // AI-TODO: KIKIMR-24214 -- configure it - .ModelId = "claude-3-5-haiku-20241022", - .ApiKey = GetEnv("MODEL_TOKEN"), // AI-TODO: KIKIMR-24214 -- configure it - .MaxTokens = 2048, // AI-TODO configure it - }); - - // YandexGPT Pro - // const auto model = NAi::CreateOpenAiModel({ - // .BaseUrl = "https://api.eliza.yandex.net/internal/zeliboba/32b_aligned_quantized_202506/generative", // AI-TODO: KIKIMR-24214 -- configure it + // Claude 3.5 haiku + // const auto model = NAi::CreateAnthropicModel({ + // .BaseUrl = "https://api.eliza.yandex.net/anthropic", // AI-TODO: KIKIMR-24214 -- configure it + // .ModelId = "claude-3-5-haiku-20241022", // .ApiKey = GetEnv("MODEL_TOKEN"), // AI-TODO: KIKIMR-24214 -- configure it - // }); - - std::unordered_map tools; + // }, config); - const auto sqlTool = NAi::CreateExecQueryTool(config); - tools.emplace(sqlTool->GetName(), sqlTool); - - const auto lsTool = NAi::CreateListDirectoryTool(config); - tools.emplace(lsTool->GetName(), lsTool); + // YandexGPT Pro + const auto model = NAi::CreateOpenAiModel({ + .BaseUrl = "https://api.eliza.yandex.net/internal/zeliboba/32b_aligned_quantized_202506/generative", // AI-TODO: KIKIMR-24214 -- configure it + .ApiKey = GetEnv("MODEL_TOKEN"), // AI-TODO: KIKIMR-24214 -- configure it + }, config); + std::unordered_map tools = { + {"execute_query", NAi::CreateExecQueryTool(config)}, + {"list_directory", NAi::CreateListDirectoryTool(config)}, + }; for (const auto& [name, tool] : tools) { model->RegisterTool(name, tool->GetParametersSchema(), tool->GetDescription()); } // AI-TODO: there is strange highlighting of brackets - std::vector requests; + std::vector messages; while (const auto& maybeLine = lineReader.ReadLine()) { const auto& input = *maybeLine; if (input.empty()) { @@ -83,17 +94,21 @@ int TCommandAi::Run(TConfig& config) { } // AI-TODO: limit interaction number - requests.push_back({.Text = input}); - while (!requests.empty()) { + messages.emplace_back(NAi::IModel::TUserMessage{.Text = input}); + while (!messages.empty()) { // AI-TODO: progress visualization - auto output = model->HandleMessages(requests); - requests.clear(); - Y_ENSURE(output.Text || !output.ToolCalls.empty()); + const auto output = model->HandleMessages(messages); + messages.clear(); + + if (!output.Text && output.ToolCalls.empty()) { + // AI-TODO: proper answer format + Cout << "Model answer is empty(" << Endl; + break; + } if (output.Text) { // AI-TODO: proper answer format - // AI-TODO: how can I render markdown? - Cout << "Model answer:\n" << *output.Text << Endl; + Cout << "Model answer:\n" << output.Text << Endl; } for (const auto& toolCall : output.ToolCalls) { @@ -104,10 +119,21 @@ int TCommandAi::Run(TConfig& config) { return EXIT_FAILURE; } - // AI-TODO: ask permission before call and show progress - requests.push_back({ - .Text = it->second->Execute(toolCall.Parameters), - .ToolCallId = toolCall.Id + // AI-TODO: proper tool call printing + Cout << "Calling tool: " << toolCall.Name << " with params:\n" << NAi::FormatJsonValue(toolCall.Parameters) << Endl; + + // AI-TODO: add approving + const auto& result = it->second->Execute(toolCall.Parameters); + if (!result.IsSuccess) { + // AI-TODO: proper error handling + Cout << result.Text << Endl; + } + // AI-TODO: show progress + + messages.push_back(NAi::IModel::TToolResponse{ + .Text = result.Text, + .ToolCallId = toolCall.Id, + .IsSuccess = result.IsSuccess, }); } } diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/common/json_utils.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/common/json_utils.cpp new file mode 100644 index 000000000000..90e1618415e1 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/common/json_utils.cpp @@ -0,0 +1,223 @@ +#include "json_utils.h" + +#include + +#include +#include + +#include + +namespace NYdb::NConsoleClient::NAi { + +//// TJsonParser + +TJsonParser::TJsonParser() + : TJsonParser(NJson::TJsonValue()) +{} + +TJsonParser::TJsonParser(const NJson::TJsonValue& value) + : JsonHolder(std::make_shared(value)) + , State(JsonHolder.get()) +{} + +TJsonParser::TJsonParser(std::shared_ptr jsonHolder, const NJson::TJsonValue* state, const TString& fieldName) + : JsonHolder(jsonHolder) + , State(state) + , FieldName(fieldName) +{ + Y_DEBUG_VERIFY(JsonHolder, "Internal error. JsonHolder should not be null"); + Y_DEBUG_VERIFY(State, "Internal error. State should not be null"); +} + +bool TJsonParser::Parse(const TString& value) { + Y_DEBUG_VERIFY(JsonHolder.get() == State, "Internal error. Parse should be called on root JSON"); + return NJson::ReadJsonTree(value, JsonHolder.get()); +} + +TString TJsonParser::ToString() const { + if (State->IsString()) { + return State->GetString(); + } + + NJsonWriter::TBuf valueWriter; + valueWriter.SetIndentSpaces(2); + valueWriter.WriteJsonValue(State); + return valueWriter.Str(); +} + +const NJson::TJsonValue& TJsonParser::GetValue() const { + return *State; +} + +TString TJsonParser::GetFieldName() const { + return FieldName.value_or("$"); +} + +TJsonParser TJsonParser::GetKey(const TString& key) const { + ValidateType(NJson::JSON_MAP); + + const auto* child = State->GetMapSafe().FindPtr(key); + if (!child) { + Fail(TStringBuilder() << "does not contain " << key << " field"); + } + + return TJsonParser(JsonHolder, child, AdvancePath(key)); +} + +std::optional TJsonParser::MaybeKey(const TString& key) const { + if (!State->IsMap()) { + return std::nullopt; + } + + if (const auto* child = State->GetMapSafe().FindPtr(key)) { + return TJsonParser(JsonHolder, child, AdvancePath(key)); + } + + return std::nullopt; +} + +TJsonParser TJsonParser::GetElement(ui64 index) const { + ValidateType(NJson::JSON_ARRAY); + + const auto& array = State->GetArraySafe(); + if (const auto size = array.size(); size <= index) { + Fail(TStringBuilder() << "does not contain element with index " << index << ", actual array size: " << size); + } + + return TJsonParser(JsonHolder, &array[index], AdvancePath(index)); +} + +void TJsonParser::Iterate(std::function handler) const { + ValidateType(NJson::JSON_ARRAY); + + const auto& array = State->GetArraySafe(); + for (ui64 i = 0; i < array.size(); ++i) { + handler(TJsonParser(JsonHolder, &array[i], AdvancePath(i))); + } +} + +TString TJsonParser::GetString() const { + ValidateType(NJson::JSON_STRING); + return State->GetString(); +} + +bool TJsonParser::IsNull() const { + return State->IsNull(); +} + +void TJsonParser::ValidateType(NJson::EJsonValueType expectedType) const { + if (const auto type = State->GetType(); type != expectedType) { + Fail(TStringBuilder() << "has unexpected type " << type << ", expected type: " << expectedType); + } +} + +TString TJsonParser::AdvancePath(const TString& key) const { + return TStringBuilder() << GetFieldName() << "." << key; +} + +TString TJsonParser::AdvancePath(const ui64& index) const { + return TStringBuilder() << GetFieldName() << "[" << index << "]"; +} + +void TJsonParser::Fail(const TString& message) const { + auto error = yexception() << "JSON"; + if (FieldName) { + error << " field " << *FieldName; + } + throw error << " " << message; +} + +//// TJsonSchemaBuilder + +TJsonSchemaBuilder::TJsonSchemaBuilder(TJsonSchemaBuilder* parent) + : Parent(parent) +{ + Y_DEBUG_VERIFY(Parent, "Internal error. Parent should not be null"); +} + +TJsonSchemaBuilder& TJsonSchemaBuilder::Type(EType type) { + Y_DEBUG_VERIFY(TypeValue == EType::Undefined, "Internal error. Type should be defined only once"); + Y_DEBUG_VERIFY(type != EType::Undefined, "Internal error. Type should not be undefined"); + TypeValue = type; + return *this; +} + +TJsonSchemaBuilder& TJsonSchemaBuilder::Description(const TString& description) { + Y_DEBUG_VERIFY(!DescriptionValue, "Internal error. Description should be defined only once"); + DescriptionValue = description; + return *this; +} + +TJsonSchemaBuilder& TJsonSchemaBuilder::Done() const { + Y_DEBUG_VERIFY(Parent, "Internal error. Done should be called only on child objects"); + return *Parent; +} + +NJson::TJsonValue TJsonSchemaBuilder::Build() const { + NJson::TJsonValue result; + + if (DescriptionValue) { + result["description"] = *DescriptionValue; + } + + auto& type = result["type"]; + switch (TypeValue) { + case EType::Undefined: { + Y_DEBUG_VERIFY(false, "Internal error. Type should not be defined before building"); + break; + } + case EType::String: { + type = "string"; + break; + } + case EType::Object: { + type = "object"; + + if (!Properties.empty()) { + auto& properties = result["properties"].SetType(NJson::JSON_MAP).GetMapSafe(); + properties.reserve(Properties.size()); + for (const auto& [name, builder] : Properties) { + properties.emplace(name, builder->Build()); + } + } + + if (!RequiredProperties.empty()) { + auto& required = result["required"].SetType(NJson::JSON_ARRAY).GetArraySafe(); + for (const auto& name : RequiredProperties) { + required.emplace_back(name); + } + } + + break; + } + } + + return result; +} + +TJsonSchemaBuilder& TJsonSchemaBuilder::Property(const TString& name, bool required) { + const auto [it, inserted] = Properties.emplace(name, std::make_shared(this)); + Y_DEBUG_VERIFY(inserted, "Internal error. Property should not be defined twice"); + + if (required) { + RequiredProperties.emplace_back(name); + } + + return *it->second; +} + +//// Utils + +TString FormatJsonValue(const NJson::TJsonValue& value) { + return TJsonParser(value).ToString(); +} + +TString FormatJsonValue(const TString& value) { + TJsonParser parser; + if (!parser.Parse(value)) { + return value; + } + return parser.ToString(); +} + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/common/json_utils.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/common/json_utils.h new file mode 100644 index 000000000000..0dc5fcbc08dd --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/common/json_utils.h @@ -0,0 +1,87 @@ +#pragma once + +#include + +#include + +namespace NYdb::NConsoleClient::NAi { + +class TJsonParser { +public: + TJsonParser(); + + explicit TJsonParser(const NJson::TJsonValue& value); + + bool Parse(const TString& value); + + TString ToString() const; + + const NJson::TJsonValue& GetValue() const; + + TString GetFieldName() const; + + TJsonParser GetKey(const TString& key) const; + + std::optional MaybeKey(const TString& key) const; + + TJsonParser GetElement(ui64 index) const; + + void Iterate(std::function handler) const; + + TString GetString() const; + + bool IsNull() const; + + void ValidateType(NJson::EJsonValueType expectedType) const; + +private: + TJsonParser(std::shared_ptr jsonHolder, const NJson::TJsonValue* state, const TString& fieldName); + + TString AdvancePath(const TString& key) const; + + TString AdvancePath(const ui64& index) const; + + void Fail(const TString& message) const; + +private: + std::shared_ptr JsonHolder; + const NJson::TJsonValue* State = nullptr; + std::optional FieldName; +}; + +class TJsonSchemaBuilder { +public: + enum class EType { + Undefined, + String, + Object, + }; + + TJsonSchemaBuilder() = default; + + explicit TJsonSchemaBuilder(TJsonSchemaBuilder* parent); + + TJsonSchemaBuilder& Type(EType type); + + TJsonSchemaBuilder& Description(const TString& description); + + TJsonSchemaBuilder& Property(const TString& name, bool required = true); + + TJsonSchemaBuilder& Done() const; + + NJson::TJsonValue Build() const; + +private: + TJsonSchemaBuilder* Parent = nullptr; + + EType TypeValue = EType::Undefined; + std::unordered_map> Properties; + std::vector RequiredProperties; + std::optional DescriptionValue; +}; + +TString FormatJsonValue(const NJson::TJsonValue& value); + +TString FormatJsonValue(const TString& value); + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/common/ya.make b/ydb/public/lib/ydb_cli/commands/ydb_ai/common/ya.make new file mode 100644 index 000000000000..ae7d960872a8 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/common/ya.make @@ -0,0 +1,13 @@ +LIBRARY() + +SRCS( + json_utils.cpp +) + +PEERDIR( + library/cpp/json + library/cpp/json/writer + ydb/core/base +) + +END() diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.cpp index b2d23d50f3a9..589e1ecedc75 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.cpp +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.cpp @@ -1,272 +1,133 @@ #include "model_anthropic.h" +#include "model_base.h" -#include +#include +#include #include -#include -#include -#include -#include #include +#include namespace NYdb::NConsoleClient::NAi { namespace { -class TModelAnthropic final : public IModel { - class TToolInfo { - public: - TToolInfo(const TString& name, const NJson::TJsonValue& parametersSchema, const TString& description) - : Name(name) - , ParametersSchema(parametersSchema) - , Description(description) - {} +class TModelAnthropic final : public TModelBase { + using TBlase = TModelBase; - NJson::TJsonValue ToJson() const { - NJson::TJsonValue result; - result["name"] = Name; - result["description"] = Description; - result["input_schema"] = ParametersSchema; - - return result; - } - - private: - TString Name; - NJson::TJsonValue ParametersSchema; - TString Description; - }; - - class TConversationPart { - public: - enum class ERole { - User, - AI, - Tool, - }; - - TConversationPart(const TString& content, ERole role) - : Content(content) - , Role(role) - {} - - TConversationPart(const TString& content, const TString& toolCallId) - : Content(content) - , Role(ERole::Tool) - , ToolCallId(toolCallId) - {} - - NJson::TJsonValue ToJson() const { - NJson::TJsonValue result; - switch (Role) { - case ERole::User: { - result["content"] = Content; - result["role"] = "user"; - break; - } - case ERole::AI: { - Y_ENSURE(AssistantPart); - result["content"] = *AssistantPart; - result["role"] = "assistant"; - break; - } - case ERole::Tool: { - result["role"] = "user"; - - auto& content = result["content"][0]; - content["type"] = "tool_result"; - content["content"] = Content; - - Y_ENSURE(ToolCallId); - content["tool_use_id"] = *ToolCallId; - break; - } - } - - return result; - } - - public: - std::optional AssistantPart; - TString Content; - - private: - ERole Role = ERole::User; - std::optional ToolCallId; - }; + static constexpr ui64 MAX_COMPLETION_TOKENS = 1024; public: - TModelAnthropic(NYql::IHTTPGateway::TPtr httpGateway, const TAnthropicModelSettings& settings) - : HttpGateway(httpGateway) - , Settings(settings) + TModelAnthropic(const TAnthropicModelSettings& settings, const TClientCommand::TConfig& config) + : TBlase(CreateApiUrl(settings.BaseUrl, "/v1/messages"), settings.ApiKey, config) + , Tools(ChatCompletionRequest["tools"].SetType(NJson::JSON_ARRAY).GetArraySafe()) + , Conversation(ChatCompletionRequest["messages"].SetType(NJson::JSON_ARRAY).GetArraySafe()) { - TStringBuf sanitizedUrl; - TStringBuf query; - TStringBuf fragment; - SeparateUrlFromQueryAndFragment(Settings.BaseUrl, sanitizedUrl, query, fragment); - - if (query || fragment) { - throw yexception() << "BaseUrl must not contain query or fragment, got url: '" << Settings.BaseUrl << "' with query: '" << query << "' or fragment: '" << fragment << "'"; - } - - Settings.BaseUrl = RemoveFinalSlash(sanitizedUrl); + ChatCompletionRequest["model"] = settings.ModelId; + ChatCompletionRequest["max_tokens"] = MAX_COMPLETION_TOKENS; } - TResponse HandleMessages(const std::vector& requests) final { - for (const auto& request : requests) { - if (request.ToolCallId) { - Conversation.emplace_back(request.Text, *request.ToolCallId); - } else { - Conversation.emplace_back(request.Text, TConversationPart::ERole::User); - } - } - - NJson::TJsonValue bodyJson; - bodyJson["model"] = Settings.ModelId; - bodyJson["max_tokens"] = Settings.MaxTokens; - - auto& messagesArray = bodyJson["messages"].SetType(NJson::JSON_ARRAY).GetArraySafe(); - for (const auto& part : Conversation) { - messagesArray.push_back(part.ToJson()); - } - - if (!Tools.empty()) { - auto& toolsArray = bodyJson["tools"].SetType(NJson::JSON_ARRAY).GetArraySafe(); - for (const auto& tool : Tools) { - toolsArray.push_back(tool.ToJson()); - } - } - - NJsonWriter::TBuf bodyWriter; - bodyWriter.WriteJsonValue(&bodyJson); - - NYql::THttpHeader headers = {.Fields = { - "Content-Type: application/json", - "anthropic-version: 2023-06-01" - }}; + void RegisterTool(const TString& name, const NJson::TJsonValue& parametersSchema, const TString& description) final { + Y_DEBUG_VERIFY(ValidateToolName(name), "Internal error. Invalid tool name: %s", name.c_str()); - if (Settings.ApiKey) { - headers.Fields.emplace_back(TStringBuilder() << "x-api-key: " << Settings.ApiKey); - } + auto& tool = Tools.emplace_back(); + tool["name"] = name; + tool["input_schema"] = parametersSchema; + tool["description"] = description; + } - auto answer = NThreading::NewPromise(); - HttpGateway->Upload( - TStringBuilder() << Settings.BaseUrl << "/v1/messages", - std::move(headers), - bodyWriter.Str(), - [&answer](NYql::IHTTPGateway::TResult result) { - if (result.CurlResponseCode != CURLE_OK) { - answer.SetException(TStringBuilder() << "Request model failed: " << result.Issues.ToOneLineString() << ", internal code: " << curl_easy_strerror(result.CurlResponseCode) << ", response: " << result.Content.Extract()); - return; - } +protected: + void AdvanceConversation(const std::vector& messages) final { + auto& conversationItem = Conversation.emplace_back(); + conversationItem["role"] = "user"; - auto& content = result.Content; - if (content.HttpResponseCode < 200 || content.HttpResponseCode >= 300) { - answer.SetException(TStringBuilder() << "Request model failed, internal code: " << content.HttpResponseCode << ", response: " << result.Content.Extract()); - return; - } + auto& content = conversationItem["content"].SetType(NJson::JSON_ARRAY).GetArraySafe(); + for (const auto& message : messages) { + auto& item = content.emplace_back(); + auto& type = item["type"]; - answer.SetValue(content.Extract()); + if (std::holds_alternative(message)) { + item["text"] = std::get(message).Text; + type = "text"; + } else { + const auto& toolResponse = std::get(message); + item["content"] = toolResponse.Text; + item["tool_use_id"] = toolResponse.ToolCallId; + item["is_error"] = !toolResponse.IsSuccess; + type = "tool_result"; } - ); - - const auto result = answer.GetFuture().ExtractValueSync(); - NJson::TJsonValue resultJson; - if (!NJson::ReadJsonTree(result, &resultJson)) { - throw yexception() << "Response of model is not JSON, got response: " << result; } + } - ValidateJsonType(resultJson, NJson::JSON_MAP); + TResponse HandleModelResponse(const NJson::TJsonValue& response) final { + TResponse result; - const auto& resultMap = resultJson.GetMap(); - if (const auto it = resultMap.find("response"); it != resultMap.end()) { - resultJson = it->second; + TJsonParser parser(response); + if (auto child = parser.MaybeKey("response")) { + parser = std::move(*child); } - TResponse response; - - // Extract content from Anthropic response - const auto& content = ValidateJsonKey(resultJson, "content"); - ValidateJsonType(content, NJson::JSON_ARRAY, "content"); + parser = parser.GetKey("content"); + auto& conversationItem = Conversation.emplace_back(); + conversationItem["role"] = "assistant"; + conversationItem["content"] = parser.GetValue(); - auto& conversationPart = Conversation.emplace_back(TString(), TConversationPart::ERole::AI); - conversationPart.AssistantPart = content; - for (const auto& contentItem : content.GetArray()) { - ValidateJsonType(contentItem, NJson::JSON_MAP, "content[i]"); - - const auto& type = ValidateJsonKey(contentItem, "type", "content[i]"); - ValidateJsonType(type, NJson::JSON_STRING, "content[i].type"); - - if (type.GetString() == "text") { - if (response.Text) { - throw yexception() << "Response of model contains multiple text responses, got response: " << result; + parser.Iterate([&](TJsonParser item) { + const auto& type = item.GetKey("type").GetString(); + if (type == "text") { + if (result.Text) { + throw yexception() << "Multiple conversation items contains text"; } - - const auto& text = ValidateJsonKey(contentItem, "text", "content[i]"); - ValidateJsonType(text, NJson::JSON_STRING, "content[i].text"); - response.Text = text.GetString(); - } else if (type.GetString() == "tool_use") { - const auto& id = ValidateJsonKey(contentItem, "id", "content[i]"); - ValidateJsonType(id, NJson::JSON_STRING, "content[i].id"); - - const auto& name = ValidateJsonKey(contentItem, "name", "content[i]"); - ValidateJsonType(name, NJson::JSON_STRING, "content[i].name"); - - const auto& input = ValidateJsonKey(contentItem, "input", "content[i]"); - - response.ToolCalls.push_back({ - .Id = id.GetString(), - .Name = name.GetString(), - .Parameters = input + result.Text = Strip(item.GetKey("text").GetString()); + } else if (type == "tool_use") { + result.ToolCalls.push_back({ + .Id = item.GetKey("id").GetString(), + .Name = item.GetKey("name").GetString(), + .Parameters = item.GetKey("input").GetValue(), }); } else { - throw yexception() << "Response of model contains unknown type: " << type.GetString() << ", got response: " << result; + throw yexception() << "Unknown conversation item type: " << type << ", expected text or tool_use"; } - } + }); - Y_ENSURE(response.Text || !response.ToolCalls.empty()); - return response; - } - - void RegisterTool(const TString& name, const NJson::TJsonValue& parametersSchema, const TString& description) final { - Tools.emplace_back(name, parametersSchema, description); + return result; } -private: - void ValidateJsonType(const NJson::TJsonValue& value, NJson::EJsonValueType expectedType, const std::optional& fieldName = std::nullopt) const { - if (const auto valueType = value.GetType(); valueType != expectedType) { - throw yexception() << "Response" << (fieldName ? " field '" + *fieldName + "'" : "") << " of model has unexpected type: " << valueType << ", expected type: " << expectedType << ", got response: " << value; + TString HandleErrorResponse(ui64 httpCode, const TString& response) final { + TJsonParser parser; + if (!parser.Parse(response)) { + return TBlase::HandleErrorResponse(httpCode, response); } - } - void ValidateJsonArraySize(const NJson::TJsonValue& value, size_t expectedSize, const std::optional& fieldName = std::nullopt) const { - if (const auto valueSize = value.GetArray().size(); valueSize != expectedSize) { - throw yexception() << "Response" << (fieldName ? " field '" + *fieldName + "'" : "") << " of model has unexpected size: " << valueSize << ", expected size: " << expectedSize << ", got response: " << value; + auto error = TStringBuilder() << "Request to model API failed:\n"; + if (const auto& info = parser.MaybeKey("error")) { + return error << info->ToString(); } - } - - const NJson::TJsonValue& ValidateJsonKey(const NJson::TJsonValue& value, const TString& key, const std::optional& fieldName = std::nullopt) const { - const auto* output = value.GetMap().FindPtr(key); - if (!output) { - throw yexception() << "Response of model does not contain '" << key << "' field" << (fieldName ? " in '" + *fieldName + "'" : "") << ", got response: " << value; + if (const auto& response = parser.MaybeKey("response")) { + if (const auto& info = parser.MaybeKey("error")) { + return error << info->ToString(); + } + return error << response->ToString(); } - return *output; + return TBlase::HandleErrorResponse(httpCode, response); } private: - const NYql::IHTTPGateway::TPtr HttpGateway; - TAnthropicModelSettings Settings; + static bool ValidateToolName(const TString& name) { + return 1 <= name.size() && name.size() <= 128; + } - std::vector Tools; - std::vector Conversation; +private: + NJson::TJsonValue::TArray& Tools; + NJson::TJsonValue::TArray& Conversation; }; } // anonymous namespace -IModel::TPtr CreateAnthropicModel(const TAnthropicModelSettings& settings) { - return std::make_shared(NYql::IHTTPGateway::Make(), settings); +IModel::TPtr CreateAnthropicModel(const TAnthropicModelSettings& settings, const TClientCommand::TConfig& config) { + return std::make_shared(settings, config); } } // namespace NYdb::NConsoleClient::NAi \ No newline at end of file diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.h index 152549d051d1..c95737eadb0c 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.h +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.h @@ -2,15 +2,16 @@ #include "model_interface.h" +#include + namespace NYdb::NConsoleClient::NAi { struct TAnthropicModelSettings { - TString BaseUrl; // AI-TODO KIKIMR-24211 add default value + TString BaseUrl; TString ModelId; - TString ApiKey; - ui64 MaxTokens; + std::optional ApiKey; }; -IModel::TPtr CreateAnthropicModel(const TAnthropicModelSettings& settings); +IModel::TPtr CreateAnthropicModel(const TAnthropicModelSettings& settings, const TClientCommand::TConfig& config); } // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_base.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_base.cpp new file mode 100644 index 000000000000..7e6f063d9c44 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_base.cpp @@ -0,0 +1,145 @@ +#include "model_base.h" + +#include +#include + +#include +#include +#include +#include + +#include + +namespace NYdb::NConsoleClient::NAi { + +namespace { + +NYql::THttpHeader CreateApiHeaders(const std::optional& authToken) { + TSmallVec headers = {"Content-Type: application/json"}; + + if (authToken) { + headers.emplace_back(TStringBuilder() << "Authorization: Bearer " << *authToken); + } + + return {.Fields = std::move(headers)}; +} + +struct THttpResponse { + THttpResponse(TString&& content, ui64 httpCode) + : Content(std::move(content)) + , HttpCode(httpCode) + {} + + bool IsSuccess() const { + return HttpCode >= 200 && HttpCode < 300; + } + + TString Content; + ui64 HttpCode = 0; +}; + +} // anonymous namespace + +TModelBase::TModelBase(const TString& apiUrl, const std::optional& authToken, const TClientCommand::TConfig& config) + : Verbosity(config.VerbosityLevel) + , ApiUrl(apiUrl) + , ApiHeaders(CreateApiHeaders(authToken)) + , HttpGateway(NYql::IHTTPGateway::Make()) +{ + Y_DEBUG_VERIFY(apiUrl, "Internal error. Url should not be empty for model API"); + + if (Verbosity >= VERB_INFO) { + Cerr << "Using model API url: " << apiUrl << " with " + << (authToken ? TStringBuilder() << "auth token " << TString(authToken->size(), '*') : TStringBuilder() << "anonymous access") << Endl; + } +} + +TModelBase::TResponse TModelBase::HandleMessages(const std::vector& messages) { + Y_DEBUG_VERIFY(!messages.empty(), "Internal error. Messages should not be empty for advance conversation"); + + AdvanceConversation(messages); + + if (Verbosity >= VERB_TRACE) { + Cerr << "Request to model API:\n" << FormatJsonValue(ChatCompletionRequest) << Endl; + } + + NJsonWriter::TBuf requestJsonWriter; + requestJsonWriter.WriteJsonValue(&ChatCompletionRequest); + auto request = requestJsonWriter.Str(); + + auto responsePromise = NThreading::NewPromise(); + auto httpCallback = [&responsePromise](NYql::IHTTPGateway::TResult result) -> void { + const auto curlCode = result.CurlResponseCode; + if (curlCode == CURLE_OK) { + auto& content = result.Content; + responsePromise.SetValue(THttpResponse(content.Extract(), content.HttpResponseCode)); + return; + } + + auto error = TStringBuilder() << "Failed to connect to API server or process response, internal code: " << static_cast(curlCode); + if (result.Issues) { + error << ". Reason:\n" << result.Issues.ToString(); + } + responsePromise.SetException(error); + }; + + HttpGateway->Upload(ApiUrl, ApiHeaders, std::move(request), std::move(httpCallback)); + const auto response = responsePromise.GetFuture().ExtractValueSync(); + + if (Verbosity >= VERB_TRACE) { + Cerr << "Model API response http code: " << response.HttpCode; + if (response.Content) { + Cerr << ". Response data:\n" << FormatJsonValue(response.Content); + } + Cerr << Endl; + } + + if (!response.IsSuccess()) { + throw yexception() << HandleErrorResponse(response.HttpCode, response.Content); + } + + NJson::TJsonValue responseJson; + try { + NJson::ReadJsonTree(response.Content, &responseJson, /* throwOnError */ true); + } catch (const std::exception& e) { + throw yexception() << "Model API response is not valid JSON, reason: " << e.what(); + } + + try { + return HandleModelResponse(responseJson); + } catch (const std::exception& e) { + throw yexception() << "Processing model response error. " << e.what(); + } +} + +TString TModelBase::HandleErrorResponse(ui64 httpCode, const TString& response) { + auto error = TStringBuilder() << "Request to model API failed with code: " << httpCode; + if (response) { + error << ". Response:\n" << FormatJsonValue(response); + } + return error; +} + +TString TModelBase::CreateApiUrl(const TString& baseUrl, const TString& uri) { + Y_DEBUG_VERIFY(uri, "Internal error. Uri should not be empty for model API"); + + TStringBuf sanitizedUrl; + TStringBuf query; + TStringBuf fragment; + SeparateUrlFromQueryAndFragment(baseUrl, sanitizedUrl, query, fragment); + + if (query || fragment) { + auto error = yexception() << "Invalid model API base url: '" << baseUrl << "'"; + if (query) { + error << ". Query part should be empty, but got: '" << query << "'"; + } + if (fragment) { + error << ". Fragment part should be empty, but got: '" << fragment << "'"; + } + throw error; + } + + return TStringBuilder() << RemoveFinalSlash(sanitizedUrl) << uri; +} + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_base.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_base.h new file mode 100644 index 000000000000..8beef05c0a6d --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_base.h @@ -0,0 +1,41 @@ +#pragma once + +#include "model_interface.h" + +#include +#include + +namespace NYdb::NConsoleClient::NAi { + +class TModelBase : public IModel { +protected: + enum EVerboseLevel { + VERB_INFO = 1, + VERB_TRACE = 2, + }; + +public: + TModelBase(const TString& apiUrl, const std::optional& authToken, const TClientCommand::TConfig& config); + + TResponse HandleMessages(const std::vector& messages) final; + +protected: + virtual void AdvanceConversation(const std::vector& messages) = 0; + + virtual TResponse HandleModelResponse(const NJson::TJsonValue& response) = 0; + + virtual TString HandleErrorResponse(ui64 httpCode, const TString& response); + + static TString CreateApiUrl(const TString& baseUrl, const TString& uri); + +protected: + NJson::TJsonValue ChatCompletionRequest; + +private: + const ui64 Verbosity = 0; + const TString ApiUrl; + const NYql::THttpHeader ApiHeaders; + const NYql::IHTTPGateway::TPtr HttpGateway; +}; + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_interface.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_interface.h index bc830ab9181b..8677d9525b3f 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_interface.h +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_interface.h @@ -14,11 +14,18 @@ class IModel { virtual ~IModel() = default; - struct TRequest { + struct TUserMessage { TString Text; - std::optional ToolCallId; }; + struct TToolResponse { + TString Text; + TString ToolCallId; + bool IsSuccess = true; + }; + + using TMessage = std::variant; + struct TResponse { struct TToolCall { TString Id; @@ -26,11 +33,11 @@ class IModel { NJson::TJsonValue Parameters; }; - std::optional Text; + TString Text; std::vector ToolCalls; }; - virtual TResponse HandleMessages(const std::vector& requests) = 0; + virtual TResponse HandleMessages(const std::vector& messages) = 0; virtual void RegisterTool(const TString& name, const NJson::TJsonValue& parametersSchema, const TString& description) = 0; }; diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp index 41532aa37aab..27d78571e273 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp @@ -1,286 +1,157 @@ #include "model_openai.h" +#include "model_base.h" -#include +#include +#include #include -#include -#include -#include -#include #include +#include namespace NYdb::NConsoleClient::NAi { namespace { -class TModelOpenAi final : public IModel { - class TToolInfo { - public: - TToolInfo(const TString& name, const NJson::TJsonValue& parametersSchema, const TString& description) - : Name(name) - , ParametersSchema(parametersSchema) - , Description(description) - {} - - NJson::TJsonValue ToJson() const { - NJson::TJsonValue result; - result["type"] = "function"; - - auto& toolJson = result["function"]; - toolJson["strict"] = false; // AI-TODO: enable after fixes - toolJson["name"] = Name; - toolJson["parameters"] = ParametersSchema; - toolJson["description"] = Description; - - return result; - } - - private: - TString Name; - NJson::TJsonValue ParametersSchema; - TString Description; - }; - - class TConversationPart { - public: - enum class ERole { - User, - AI, - Tool, - }; - - TConversationPart(const TString& content, ERole role) - : Content(content) - , Role(role) - {} - - TConversationPart(const TString& content, const TString& toolCallId) - : Content(content) - , Role(ERole::Tool) - , ToolCallId(toolCallId) - {} - - NJson::TJsonValue ToJson() const { - NJson::TJsonValue result; - result["content"] = Content; - - auto& roleJson = result["role"]; - switch (Role) { - case ERole::User: { - roleJson = "user"; - break; - } - case ERole::AI: { - roleJson = "assistant"; - break; - } - case ERole::Tool: { - roleJson = "tool"; - - Y_ENSURE(ToolCallId); - result["tool_call_id"] = *ToolCallId; - break; - } - } +class TModelOpenAi final : public TModelBase { + using TBlase = TModelBase; - return result; - } - - private: - TString Content; - ERole Role = ERole::User; - std::optional ToolCallId; - }; + static constexpr ui64 MAX_COMPLETION_TOKENS = 1024; public: - TModelOpenAi(NYql::IHTTPGateway::TPtr httpGateway, const TOpenAiModelSettings& settings) - : HttpGateway(httpGateway) - , Settings(settings) + TModelOpenAi(const TOpenAiModelSettings& settings, const TClientCommand::TConfig& config) + : TBlase(CreateApiUrl(settings.BaseUrl, "/v1/chat/completions"), settings.ApiKey, config) + , Tools(ChatCompletionRequest["tools"].SetType(NJson::JSON_ARRAY).GetArraySafe()) + , Conversation(ChatCompletionRequest["messages"].SetType(NJson::JSON_ARRAY).GetArraySafe()) { - TStringBuf sanitizedUrl; - TStringBuf query; - TStringBuf fragment; - SeparateUrlFromQueryAndFragment(Settings.BaseUrl, sanitizedUrl, query, fragment); - - if (query || fragment) { - throw yexception() << "BaseUrl must not contain query or fragment, got url: '" << Settings.BaseUrl << "' with query: '" << query << "' or fragment: '" << fragment << "'"; + if (settings.ModelId) { + ChatCompletionRequest["model"] = *settings.ModelId; } - Settings.BaseUrl = RemoveFinalSlash(sanitizedUrl); + ChatCompletionRequest["max_completion_tokens"] = MAX_COMPLETION_TOKENS; } - TResponse HandleMessages(const std::vector& requests) final { - for (const auto& request : requests) { - if (request.ToolCallId) { - Conversation.emplace_back(request.Text, *request.ToolCallId); - } else { - Conversation.emplace_back(request.Text, TConversationPart::ERole::User); - } - } + void RegisterTool(const TString& name, const NJson::TJsonValue& parametersSchema, const TString& description) final { + Y_DEBUG_VERIFY(ValidateToolName(name), "Internal error. Invalid tool name: %s", name.c_str()); - NJson::TJsonValue bodyJson; + auto& tool = Tools.emplace_back(); + tool["type"] = "function"; - if (Settings.ModelId) { - bodyJson["model"] = *Settings.ModelId; - } + auto& toolInfo = tool["function"]; + toolInfo["name"] = name; + toolInfo["parameters"] = parametersSchema; + toolInfo["description"] = description; + } - auto& conversationJson = bodyJson["messages"].SetType(NJson::JSON_ARRAY).GetArraySafe(); - for (const auto& part : Conversation) { - conversationJson.push_back(part.ToJson()); - } +protected: + void AdvanceConversation(const std::vector& messages) final { + for (const auto& message : messages) { + auto& item = Conversation.emplace_back(); + auto& content = item["content"]; + auto& role = item["role"]; - auto& toolsArray = bodyJson["tools"].SetType(NJson::JSON_ARRAY).GetArraySafe(); - for (const auto& tool : Tools) { - toolsArray.push_back(tool.ToJson()); + if (std::holds_alternative(message)) { + content = std::get(message).Text; + role = "user"; + } else { + const auto& toolResponse = std::get(message); + item["tool_call_id"] = toolResponse.ToolCallId; + content = toolResponse.Text; + role = "tool"; + } } + } - NJsonWriter::TBuf bodyWriter; - bodyWriter.WriteJsonValue(&bodyJson); + TResponse HandleModelResponse(const NJson::TJsonValue& response) final { + TResponse result; - NYql::THttpHeader headers = {.Fields = {"Content-Type: application/json"}}; + TJsonParser parser(response); + if (auto child = parser.MaybeKey("response")) { + parser = std::move(*child); + } - // Cerr << "-------------------------- Request: " << bodyWriter.Str(); + parser = parser.GetKey("choices").GetElement(0).GetKey("message"); + Conversation.emplace_back(parser.GetValue()); - if (Settings.ApiKey) { - headers.Fields.emplace_back(TStringBuilder() << "Authorization: Bearer " << Settings.ApiKey); + const auto& content = parser.MaybeKey("content"); + const bool hasContent = content && !content->IsNull(); + if (hasContent) { + result.Text = Strip(content->GetString()); } - auto answer = NThreading::NewPromise(); - HttpGateway->Upload( - TStringBuilder() << Settings.BaseUrl << "/v1/chat/completions", - std::move(headers), - bodyWriter.Str(), - [&answer](NYql::IHTTPGateway::TResult result) { - if (result.CurlResponseCode != CURLE_OK) { - // AI-TODO: proper error handling - answer.SetException(TStringBuilder() << "Request model failed: " << result.Issues.ToOneLineString() << ", internal code: " << curl_easy_strerror(result.CurlResponseCode) << ", response: " << result.Content.Extract()); - return; - } + const auto& tollCalls = parser.MaybeKey("tool_calls"); + const bool hasToolsCalls = tollCalls && !tollCalls->IsNull(); + if (hasToolsCalls) { + tollCalls->Iterate([&](TJsonParser toolCall) { + auto function = toolCall.GetKey("function"); - auto& content = result.Content; - if (content.HttpResponseCode < 200 || content.HttpResponseCode >= 300) { - answer.SetException(TStringBuilder() << "Request model failed, internal code: " << content.HttpResponseCode << ", response: " << result.Content.Extract()); - return; + NJson::TJsonValue argumentsJson; + try { + NJson::ReadJsonTree(function.GetKey("arguments").GetString(), &argumentsJson, /* throwOnError */ true); + } catch (const std::exception& e) { + throw yexception() << "Tool call arguments is not valid JSON, reason: " << e.what(); } - answer.SetValue(content.Extract()); - } - ); - - const auto result = answer.GetFuture().ExtractValueSync(); - // Cerr << "-------------------------- Result: " << result; - NJson::TJsonValue resultJson; - if (!NJson::ReadJsonTree(result, &resultJson)) { - throw yexception() << "Response of model is not JSON, got response: " << result; + result.ToolCalls.push_back({ + .Id = toolCall.GetKey("id").GetString(), + .Name = function.GetKey("name").GetString(), + .Parameters = std::move(argumentsJson), + }); + }); } - ValidateJsonType(resultJson, NJson::JSON_MAP); - - const auto& resultMap = resultJson.GetMap(); - if (const auto it = resultMap.find("response"); it != resultMap.end()) { - resultJson = it->second; + if (!hasContent && !hasToolsCalls) { + throw yexception() << "Not found either content or tool_calls keys in field " << parser.GetFieldName(); } - const auto& choices = ValidateJsonKey(resultJson, "choices"); - ValidateJsonType(choices, NJson::JSON_ARRAY, "choices"); - ValidateJsonArraySize(choices, 1, "choices"); - - // AI-TODO: proper error description - const auto& choiceVal = choices.GetArray()[0]; - ValidateJsonType(choiceVal, NJson::JSON_MAP, "choices[0]"); - - const auto& message = ValidateJsonKey(choiceVal, "message", "choices[0]"); - ValidateJsonType(message, NJson::JSON_MAP, "choices[0].message"); + return result; + } - const auto& messageMap = message.GetMap(); - const auto* content = messageMap.FindPtr("content"); - const auto* tollsCalls = messageMap.FindPtr("tool_calls"); - if ((!content || content->GetType() == NJson::JSON_NULL) && (!tollsCalls || tollsCalls->GetType() == NJson::JSON_NULL)) { - throw yexception() << "Response of model does not contain 'choices[0].message.content' or 'choices[0].message.tool_calls' fields, got response: " << result; + TString HandleErrorResponse(ui64 httpCode, const TString& response) final { + TJsonParser parser; + if (!parser.Parse(response)) { + return TBlase::HandleErrorResponse(httpCode, response); } - TResponse response; - - if (content && content->GetType() != NJson::JSON_NULL) { - ValidateJsonType(*content, NJson::JSON_STRING, "choices[0].message.content"); - response.Text = content->GetString(); - Conversation.emplace_back(*response.Text, TConversationPart::ERole::AI); + auto error = TStringBuilder() << "Request to model API failed:\n"; + if (const auto& info = parser.MaybeKey("message")) { + return error << info->ToString(); } - - if (tollsCalls && tollsCalls->GetType() != NJson::JSON_NULL) { - ValidateJsonType(*tollsCalls, NJson::JSON_ARRAY, "choices[0].message.tool_calls"); - - for (const auto& toolCall : tollsCalls->GetArray()) { - ValidateJsonType(toolCall, NJson::JSON_MAP, "choices[0].message.tool_calls[0]"); - - const auto& function = ValidateJsonKey(toolCall, "function", "choices[0].message.tool_calls[0]"); - ValidateJsonType(function, NJson::JSON_MAP, "choices[0].message.tool_calls[0].function"); - - const auto& name = ValidateJsonKey(function, "name", "choices[0].message.tool_calls[0].function"); - ValidateJsonType(name, NJson::JSON_STRING, "choices[0].message.tool_calls[0].function.name"); - - const auto& arguments = ValidateJsonKey(function, "arguments", "choices[0].message.tool_calls[0].function"); - ValidateJsonType(arguments, NJson::JSON_STRING, "choices[0].message.tool_calls[0].function.arguments"); - - const auto& callId = ValidateJsonKey(toolCall, "id", "choices[0].message.tool_calls[0]"); - ValidateJsonType(callId, NJson::JSON_STRING, "choices[0].message.tool_calls[0].id"); - - NJson::TJsonValue argumentsJson; - if (!NJson::ReadJsonTree(arguments.GetString(), &argumentsJson)) { - throw yexception() << "Tool call arguments is not valid JSON, got response: " << arguments.GetString(); - } - - response.ToolCalls.push_back({ - .Id = callId.GetString(), - .Name = name.GetString(), - .Parameters = std::move(argumentsJson) - }); - } + if (const auto& info = parser.MaybeKey("raw_response")) { + return error << info->ToString(); + } + if (const auto& info = parser.MaybeKey("error")) { + return error << info->ToString(); } - Y_ENSURE(response.Text || !response.ToolCalls.empty()); - return response; - } - - void RegisterTool(const TString& name, const NJson::TJsonValue& parametersSchema, const TString& description) final { - Tools.emplace_back(name, parametersSchema, description); + return TBlase::HandleErrorResponse(httpCode, response); } private: - void ValidateJsonType(const NJson::TJsonValue& value, NJson::EJsonValueType expectedType, const std::optional& fieldName = std::nullopt) const { - if (const auto valueType = value.GetType(); valueType != expectedType) { - throw yexception() << "Response" << (fieldName ? " field '" + *fieldName + "'" : "") << " of model has unexpected type: " << valueType << ", expected type: " << expectedType << ", got response: " << value; + static bool ValidateToolName(const TString& name) { + if (name.size() > 64) { + return false; } - } - - void ValidateJsonArraySize(const NJson::TJsonValue& value, size_t expectedSize, const std::optional& fieldName = std::nullopt) const { - if (const auto valueSize = value.GetArray().size(); valueSize != expectedSize) { - throw yexception() << "Response" << (fieldName ? " field '" + *fieldName + "'" : "") << " of model has unexpected size: " << valueSize << ", expected size: " << expectedSize << ", got response: " << value; - } - } - const NJson::TJsonValue& ValidateJsonKey(const NJson::TJsonValue& value, const TString& key, const std::optional& fieldName = std::nullopt) const { - const auto* output = value.GetMap().FindPtr(key); - if (!output) { - throw yexception() << "Response of model does not contain '" << key << "' field" << (fieldName ? " in '" + *fieldName + "'" : "") << ", got response: " << value; + for (const auto c : name) { + if (('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || ('0' <= c && c <= '9') || IsIn({'_', '-'}, c)) { + continue; + } + return false; } - return *output; + return true; } private: - const NYql::IHTTPGateway::TPtr HttpGateway; - TOpenAiModelSettings Settings; - - std::vector Tools; - std::vector Conversation; + NJson::TJsonValue::TArray& Tools; + NJson::TJsonValue::TArray& Conversation; }; } // anonymous namespace -IModel::TPtr CreateOpenAiModel(const TOpenAiModelSettings& settings) { - return std::make_shared(NYql::IHTTPGateway::Make(), settings); +IModel::TPtr CreateOpenAiModel(const TOpenAiModelSettings& settings, const TClientCommand::TConfig& config) { + return std::make_shared(settings, config); } } // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.h index 07570f69092b..150ed85d944a 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.h +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.h @@ -2,16 +2,16 @@ #include "model_interface.h" -#include +#include namespace NYdb::NConsoleClient::NAi { struct TOpenAiModelSettings { - TString BaseUrl; // AI-TODO KIKIMR-24211 add default value + TString BaseUrl; std::optional ModelId; - TString ApiKey; + std::optional ApiKey; }; -IModel::TPtr CreateOpenAiModel(const TOpenAiModelSettings& settings); +IModel::TPtr CreateOpenAiModel(const TOpenAiModelSettings& settings, const TClientCommand::TConfig& config); } // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/ya.make b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/ya.make index 5bef635b4afe..5e82e0375bfa 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/ya.make +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/ya.make @@ -2,6 +2,7 @@ LIBRARY() SRCS( model_anthropic.cpp + model_base.cpp model_openai.cpp ) @@ -11,6 +12,8 @@ PEERDIR( library/cpp/string_utils/url library/cpp/threading/future ydb/library/yql/providers/common/http_gateway + ydb/public/lib/ydb_cli/commands/ydb_ai/common + ydb/public/lib/ydb_cli/common ) END() diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.cpp index d34daaeca641..419f204370cd 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.cpp +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.cpp @@ -1,98 +1,118 @@ #include "exec_query_tool.h" +#include #include +#include +#include #include -#include +#include namespace NYdb::NConsoleClient::NAi { namespace { class TExecQueryTool final : public ITool { -public: - explicit TExecQueryTool(TClientCommand::TConfig& config) - : Client(TDriver(config.CreateDriverConfig())) + static constexpr char DESCRIPTION[] = R"( +Execute query in Yandex Data Base (YDB) on YQL (SQL dialect). Returns list of result sets for query, each contains list of rows and column metadata. +For example if there exists table 'my_table' with string column 'Data' and we execute query: +``` +$filtered = SELECT * FROM my_table WHERE Data IS NOT NULL; +SELECT Data || "-first" FROM $filtered; +SELECT Data || "-second" FROM $filtered; +``` +Tool will return: +[ { - NJson::TJsonValue sqlParam; - sqlParam["type"] = "string"; - sqlParam["description"] = "SQL query"; // AI-TODO: proper description - - ParametersSchema["properties"]["sql"] = sqlParam; - ParametersSchema["type"] = "object"; - ParametersSchema["required"][0] = "sql"; + "rows": [ + {"Data": "A-first"}, + {"Data": "B-first"} + ], + "columns": [ + {"name": "Data", "type": "string"} + ] + }, + { + "rows": [ + {"Data": "A-second"}, + {"Data": "B-second"} + ], + "columns": [ + {"name": "Data", "type": "string"} + ] } +])"; - TString GetName() const final { - return "execute_sql_query"; - } + static constexpr char QUERY_PROPERTY[] = "query"; - NJson::TJsonValue GetParametersSchema() const final { +public: + explicit TExecQueryTool(TClientCommand::TConfig& config) + : ParametersSchema(TJsonSchemaBuilder() + .Type(TJsonSchemaBuilder::EType::Object) + .Property(QUERY_PROPERTY) + .Type(TJsonSchemaBuilder::EType::String) + .Description("Query to execute on YQL (SQL dialect), for example 'SELECT * FROM my_table'") + .Done() + .Build() + ) + , Description(DESCRIPTION) + , Client(TDriver(config.CreateDriverConfig())) + {} + + const NJson::TJsonValue& GetParametersSchema() const final { return ParametersSchema; } - TString GetDescription() const final { - return "Execute SQL query"; // AI-TODO: proper description + const TString& GetDescription() const final { + return Description; } - TString Execute(const NJson::TJsonValue& parameters) final { - ValidateJsonType(parameters, NJson::JSON_MAP); - - const auto& sql = ValidateJsonKey(parameters, "sql"); - ValidateJsonType(sql, NJson::JSON_STRING, "sql"); - - const auto& sqlString = sql.GetString(); - Cerr << "\n!! Execute SQL query: " << sqlString << Endl; // AI-TODO: proper query execution printing - - // AI-TODO: streaming execution - auto result = Client.ExecuteQuery(sqlString, NQuery::TTxControl::NoTx()).ExtractValueSync(); - - // AI-TODO: proper error printing - if (!result.IsSuccess()) { - Cerr << "\n!! Execute query error [" << result.GetStatus() << "]: " << result.GetIssues().ToString() << Endl; - return TStringBuilder() << "Error executing SQL query, status: " << result.GetStatus() << ", issues: " << result.GetIssues().ToString(); + TResponse Execute(const NJson::TJsonValue& parameters) final try { + const TString& query = ParseParameters(parameters); + const auto response = Client.ExecuteQuery(query, NQuery::TTxControl::NoTx()).ExtractValueSync(); + if (!response.IsSuccess()) { + return TResponse(TStringBuilder() << "Query execution failed with status " << response.GetStatus() << ", reason:\n" << response.GetIssues().ToString()); } - const auto& resultSets = result.GetResultSets(); - - // AI-TODO: proper result formating - TStringBuilder resultBuilder; - for (ui64 i = 0; i < resultSets.size(); ++i) { - resultBuilder << "Result set " << i << ":\n" << Endl; + NJson::TJsonValue result; + auto& resultArray = result.SetType(NJson::JSON_ARRAY).GetArraySafe(); + for (const auto& resultSet : response.GetResultSets()) { + auto& item = resultArray.emplace_back(); + + const auto& columnMeta = resultSet.GetColumnsMeta(); + auto& columns = item["columns"].SetType(NJson::JSON_ARRAY).GetArraySafe(); + for (const auto& column : columnMeta) { + auto& item = columns.emplace_back(); + item["name"] = column.Name; + item["type"] = column.Type.ToString(); + } - TResultSetParser parser(resultSets[i]); + auto& rows = item["rows"].SetType(NJson::JSON_ARRAY).GetArraySafe(); + TResultSetParser parser(resultSet); while (parser.TryNextRow()) { - NJsonWriter::TBuf writer(NJsonWriter::HEM_UNSAFE, &resultBuilder.Out); - FormatResultRowJson(parser, resultSets[i].GetColumnsMeta(), writer, EBinaryStringEncoding::Unicode); - resultBuilder << "\n"; + TJsonParser row; + Y_DEBUG_VERIFY(row.Parse(FormatResultRowJson(parser, resultSet.GetColumnsMeta(), EBinaryStringEncoding::Unicode)), "Internal error. Invalid serialized JSON row value."); + rows.emplace_back(row.GetValue()); } - } - Cerr << "\n!! Execute query result: " << resultBuilder << Endl; // AI-TODO: proper query result printing + TResultSetPrinter(EDataFormat::Pretty).Print(resultSet); + } - return resultBuilder; + return TResponse(std::move(result)); + } catch (const std::exception& e) { + return TResponse(TStringBuilder() << "Query execution failed. " << e.what()); } private: - // AI-TODO: reduce copypaste - void ValidateJsonType(const NJson::TJsonValue& value, NJson::EJsonValueType expectedType, const std::optional& fieldName = std::nullopt) const { - if (const auto valueType = value.GetType(); valueType != expectedType) { - throw yexception() << "Tool request " << (fieldName ? " field '" + *fieldName + "'" : "") << " has unexpected type: " << valueType << ", expected type: " << expectedType; - } - } - - const NJson::TJsonValue& ValidateJsonKey(const NJson::TJsonValue& value, const TString& key, const std::optional& fieldName = std::nullopt) const { - const auto* output = value.GetMap().FindPtr(key); - if (!output) { - throw yexception() << "Tool request does not contain '" << key << "' field" << (fieldName ? " in '" + *fieldName + "'" : ""); - } - - return *output; + TString ParseParameters(const NJson::TJsonValue& parameters) const { + TJsonParser parser(parameters); + return Strip(parser.GetKey(QUERY_PROPERTY).GetString()); } private: + const NJson::TJsonValue ParametersSchema; + const TString Description; NQuery::TQueryClient Client; - NJson::TJsonValue ParametersSchema; }; } // anonymous namespace diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/list_directory_tool.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/list_directory_tool.cpp index 714a3eae0e5e..b971f399ab57 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/list_directory_tool.cpp +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/list_directory_tool.cpp @@ -1,95 +1,93 @@ #include "list_directory_tool.h" #include +#include +#include +#include #include +#include + namespace NYdb::NConsoleClient::NAi { namespace { class TListDirectoryTool final : public ITool { + static constexpr char DESCRIPTION[] = R"( +List directory in Yandex Data Base (YDB) scheme tree. Returns list of item names inside directory and their types. +For example if called on directory 'data/', which contains two tables 'my_table1', 'my_table2' and one topic 'my_topic', then tool will return: +[ + {"name": "my_table1", "type": "table"}, + {"name": "my_table2", "type": "table"}, + {"name": "my_topic", "type": "topic"} +])"; + + static constexpr char DIRECTORY_PROPERTY[] = "directory"; + public: explicit TListDirectoryTool(TClientCommand::TConfig& config) - : Database(NKikimr::CanonizePath(config.Database)) + : ParametersSchema(TJsonSchemaBuilder() + .Type(TJsonSchemaBuilder::EType::Object) + .Property(DIRECTORY_PROPERTY) + .Type(TJsonSchemaBuilder::EType::String) + .Description("Path to directory which should be listed (use empty string to list database root), for example 'data/cold/'") + .Done() + .Build() + ) + , Description(DESCRIPTION) + , Database(NKikimr::CanonizePath(config.Database)) , Client(TDriver(config.CreateDriverConfig())) - { - NJson::TJsonValue dirParam; - dirParam["type"] = "string"; - dirParam["description"] = "Directory path to list (use empty to list root directory)"; // AI-TODO: proper description - - ParametersSchema["properties"]["directory"] = dirParam; - ParametersSchema["type"] = "object"; - ParametersSchema["required"][0] = "directory"; - } + {} - TString GetName() const final { - return "list_directory"; - } - - NJson::TJsonValue GetParametersSchema() const final { + const NJson::TJsonValue& GetParametersSchema() const final { return ParametersSchema; } - TString GetDescription() const final { - return "List directory"; // AI-TODO: proper description + const TString& GetDescription() const final { + return Description; } - TString Execute(const NJson::TJsonValue& parameters) final { - ValidateJsonType(parameters, NJson::JSON_MAP); - - const auto& dir = ValidateJsonKey(parameters, "directory"); - ValidateJsonType(dir, NJson::JSON_STRING, "directory"); - - TString dirString = dir.GetString(); - Cerr << "\n!! List directory: " << dirString << Endl; // AI-TODO: proper list directory printing - - if (!dirString.StartsWith('/')) { - dirString = NKikimr::JoinPath({Database, dirString}); + TResponse Execute(const NJson::TJsonValue& parameters) final try { + const auto& directory = ParseParameters(parameters); + const auto response = Client.ListDirectory(directory).ExtractValueSync(); + if (!response.IsSuccess()) { + return TResponse(TStringBuilder() << "Listing directory failed with status " << response.GetStatus() << ", reason:\n" << response.GetIssues().ToString()); } - // AI-TODO: progress printing - auto result = Client.ListDirectory(dirString).ExtractValueSync(); - - // AI-TODO: proper error printing - if (!result.IsSuccess()) { - Cerr << "\n!! List directory error [" << result.GetStatus() << "]: " << result.GetIssues().ToString() << Endl; - return TStringBuilder() << "Error listing directory, status: " << result.GetStatus() << ", issues: " << result.GetIssues().ToString(); - } + const auto& children = response.GetChildren(); - const auto& children = result.GetChildren(); - - // AI-TODO: proper result formating - TStringBuilder resultBuilder; + NJson::TJsonValue result; + auto& resultArray = result.SetType(NJson::JSON_ARRAY).GetArraySafe(); for (const auto& child : children) { - resultBuilder << child.Name << " (" << child.Type << ")" << "\n"; + auto& item = resultArray.emplace_back(); + item["name"] = child.Name; + item["type"] = EntryTypeToString(child.Type); } - Cerr << "\n!! List directory result: " << resultBuilder << Endl; // AI-TODO: proper query result printing + Cout << TAdaptiveTabbedTable(children); - return resultBuilder; + return TResponse(std::move(result)); + } catch (const std::exception& e) { + return TResponse(TStringBuilder() << "Listing directory failed. " << e.what()); } private: - // AI-TODO: reduce copypaste - void ValidateJsonType(const NJson::TJsonValue& value, NJson::EJsonValueType expectedType, const std::optional& fieldName = std::nullopt) const { - if (const auto valueType = value.GetType(); valueType != expectedType) { - throw yexception() << "Tool request " << (fieldName ? " field '" + *fieldName + "'" : "") << " has unexpected type: " << valueType << ", expected type: " << expectedType; - } - } + TString ParseParameters(const NJson::TJsonValue& parameters) const { + TJsonParser parser(parameters); - const NJson::TJsonValue& ValidateJsonKey(const NJson::TJsonValue& value, const TString& key, const std::optional& fieldName = std::nullopt) const { - const auto* output = value.GetMap().FindPtr(key); - if (!output) { - throw yexception() << "Tool request does not contain '" << key << "' field" << (fieldName ? " in '" + *fieldName + "'" : ""); + TString directory = Strip(parser.GetKey(DIRECTORY_PROPERTY).GetString()); + if (!directory.StartsWith('/')) { + directory = NKikimr::JoinPath({Database, directory}); } - return *output; + return NKikimr::CanonizePath(directory); } private: + const NJson::TJsonValue ParametersSchema; + const TString Description; const TString Database; NScheme::TSchemeClient Client; - NJson::TJsonValue ParametersSchema; }; } // anonymous namespace diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/tool_interface.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/tool_interface.cpp new file mode 100644 index 000000000000..690579520d7c --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/tool_interface.cpp @@ -0,0 +1,17 @@ +#include "tool_interface.h" + +#include + +namespace NYdb::NConsoleClient::NAi { + +ITool::TResponse::TResponse(const TString& error) + : Text(error) + , IsSuccess(false) +{} + +ITool::TResponse::TResponse(const NJson::TJsonValue& result) + : Text(FormatJsonValue(result)) + , IsSuccess(true) +{} + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/tool_interface.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/tool_interface.h index 128fbd4b15ee..c43beb40f9ad 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/tool_interface.h +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/tool_interface.h @@ -2,8 +2,6 @@ #include -#include - #include namespace NYdb::NConsoleClient::NAi { @@ -14,13 +12,20 @@ class ITool { virtual ~ITool() = default; - virtual TString GetName() const = 0; + struct TResponse { + TString Text; + bool IsSuccess = true; + + explicit TResponse(const TString& error); + + explicit TResponse(const NJson::TJsonValue& result); + }; - virtual NJson::TJsonValue GetParametersSchema() const = 0; + virtual const NJson::TJsonValue& GetParametersSchema() const = 0; - virtual TString GetDescription() const = 0; + virtual const TString& GetDescription() const = 0; - virtual TString Execute(const NJson::TJsonValue& parameters) = 0; + virtual TResponse Execute(const NJson::TJsonValue& parameters) = 0; }; } // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/ya.make b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/ya.make index 22f7db175210..6398e7700d38 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/ya.make +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/ya.make @@ -3,12 +3,14 @@ LIBRARY() SRCS( exec_query_tool.cpp list_directory_tool.cpp + tool_interface.cpp ) PEERDIR( library/cpp/json/writer ydb/core/base ydb/public/lib/json_value + ydb/public/lib/ydb_cli/commands/ydb_ai/common ydb/public/lib/ydb_cli/common ydb/public/sdk/cpp/src/client/query ydb/public/sdk/cpp/src/client/scheme