diff --git a/ydb/core/base/validation.h b/ydb/core/base/validation.h new file mode 100644 index 000000000000..5c708ea0ae65 --- /dev/null +++ b/ydb/core/base/validation.h @@ -0,0 +1,7 @@ +#pragma once + +#ifndef NDEBUG + #define Y_DEBUG_VERIFY Y_ABORT_UNLESS +#else + #define Y_DEBUG_VERIFY Y_VERIFY +#endif diff --git a/ydb/library/yql/providers/common/http_gateway/yql_http_gateway.h b/ydb/library/yql/providers/common/http_gateway/yql_http_gateway.h index 2095419a6270..9ad44b19ef81 100644 --- a/ydb/library/yql/providers/common/http_gateway/yql_http_gateway.h +++ b/ydb/library/yql/providers/common/http_gateway/yql_http_gateway.h @@ -2,21 +2,21 @@ #include "yql_http_header.h" -#include - #include + +#include + #include #include #include -#include - #include -#include #include namespace NYql { +class THttpGatewayConfig; + class IHTTPGateway { public: using TPtr = std::shared_ptr; @@ -138,4 +138,4 @@ class IHTTPGateway { const TString& awsSigV4 = {}); }; -} +} // namespace NYql diff --git a/ydb/public/lib/ydb_cli/commands/ya.make b/ydb/public/lib/ydb_cli/commands/ya.make index 86e60e67e7f7..0b41caec8060 100644 --- a/ydb/public/lib/ydb_cli/commands/ya.make +++ b/ydb/public/lib/ydb_cli/commands/ya.make @@ -11,6 +11,7 @@ SRCS( topic_write_scenario.cpp topic_readwrite_scenario.cpp ydb_admin.cpp + ydb_ai.cpp ydb_benchmark.cpp ydb_bridge.cpp ydb_cluster.cpp @@ -57,6 +58,10 @@ PEERDIR( ydb/public/lib/ydb_cli/commands/sdk_core_access ydb/public/lib/ydb_cli/commands/topic_workload ydb/public/lib/ydb_cli/commands/transfer_workload + ydb/public/lib/ydb_cli/commands/ydb_ai + ydb/public/lib/ydb_cli/commands/ydb_ai/common + ydb/public/lib/ydb_cli/commands/ydb_ai/models + ydb/public/lib/ydb_cli/commands/ydb_ai/tools ydb/public/lib/ydb_cli/commands/ydb_discovery ydb/public/lib/ydb_cli/common ydb/public/lib/ydb_cli/dump @@ -92,5 +97,6 @@ RECURSE( sdk_core_access topic_workload transfer_workload + ydb_ai ydb_discovery ) diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp new file mode 100644 index 000000000000..0955b80af9a3 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai.cpp @@ -0,0 +1,146 @@ +#include "ydb_ai.h" + +#include +#include +#include +#include +#include +#include + +#include +#include + +/* + +FEATURES-TODO: + +- Streamable model response printing +- Streamable results printing +- Adjusting errors, progress and response printing +- Approving before tool use +- Integration into common interactive mode +- Think about helps +- Think about robust +- Provide system promt +- Somehow render markdown + +*/ + +namespace NYdb::NConsoleClient { + +namespace { + +void PrintExitMessage() { + Cout << "\nBye" << Endl; +} + +} // anonymous namespace + +TCommandAi::TCommandAi() + : TBase("ai", {}, "AI-TODO: KIKIMR-24198 -- description") +{} + +void TCommandAi::Config(TConfig& config) { + TBase::Config(config); + config.Opts->SetTitle("AI-TODO: KIKIMR-24198 -- title"); + config.Opts->SetFreeArgsNum(0); +} + +int TCommandAi::Run(TConfig& config) { + Cout << "AI-TODO: KIKIMR-24198 -- welcome message" << Endl; + + // AI-TODO: KIKIMR-24202 - robust file creation + NAi::TLineReader lineReader("ydb-ai> ", (TFsPath(HomeDir) / ".ydb-ai/history").GetPath()); + + // DeepSeek + // const auto model = NAi::CreateOpenAiModel({ + // .BaseUrl = "https://api.eliza.yandex.net/raw/internal/deepseek", // AI-TODO: KIKIMR-24214 -- configure it + // .ModelId = "deepseek-0324", // AI-TODO: KIKIMR-24214 -- configure it + // .ApiKey = GetEnv("MODEL_TOKEN"), // AI-TODO: KIKIMR-24214 -- configure it + // }, config); + + // Claude 3.5 haiku + // const auto model = NAi::CreateAnthropicModel({ + // .BaseUrl = "https://api.eliza.yandex.net/anthropic", // AI-TODO: KIKIMR-24214 -- configure it + // .ModelId = "claude-3-5-haiku-20241022", + // .ApiKey = GetEnv("MODEL_TOKEN"), // AI-TODO: KIKIMR-24214 -- configure it + // }, config); + + // YandexGPT Pro + const auto model = NAi::CreateOpenAiModel({ + .BaseUrl = "https://api.eliza.yandex.net/internal/zeliboba/32b_aligned_quantized_202506/generative", // AI-TODO: KIKIMR-24214 -- configure it + .ApiKey = GetEnv("MODEL_TOKEN"), // AI-TODO: KIKIMR-24214 -- configure it + }, config); + + std::unordered_map tools = { + {"execute_query", NAi::CreateExecQueryTool(config)}, + {"list_directory", NAi::CreateListDirectoryTool(config)}, + }; + for (const auto& [name, tool] : tools) { + model->RegisterTool(name, tool->GetParametersSchema(), tool->GetDescription()); + } + + // AI-TODO: there is strange highlighting of brackets + std::vector messages; + while (const auto& maybeLine = lineReader.ReadLine()) { + const auto& input = *maybeLine; + if (input.empty()) { + continue; + } + + if (IsIn({"quit", "exit"}, to_lower(input))) { + PrintExitMessage(); + return EXIT_SUCCESS; + } + + // AI-TODO: limit interaction number + messages.emplace_back(NAi::IModel::TUserMessage{.Text = input}); + while (!messages.empty()) { + // AI-TODO: progress visualization + const auto output = model->HandleMessages(messages); + messages.clear(); + + if (!output.Text && output.ToolCalls.empty()) { + // AI-TODO: proper answer format + Cout << "Model answer is empty(" << Endl; + break; + } + + if (output.Text) { + // AI-TODO: proper answer format + Cout << "Model answer:\n" << output.Text << Endl; + } + + for (const auto& toolCall : output.ToolCalls) { + const auto it = tools.find(toolCall.Name); + if (it == tools.end()) { + // AI-TODO: proper wrong tool handling + Cout << "Unsupported tool: " << toolCall.Name << Endl; + return EXIT_FAILURE; + } + + // AI-TODO: proper tool call printing + Cout << "Calling tool: " << toolCall.Name << " with params:\n" << NAi::FormatJsonValue(toolCall.Parameters) << Endl; + + // AI-TODO: add approving + const auto& result = it->second->Execute(toolCall.Parameters); + if (!result.IsSuccess) { + // AI-TODO: proper error handling + Cout << result.Text << Endl; + } + // AI-TODO: show progress + + messages.push_back(NAi::IModel::TToolResponse{ + .Text = result.Text, + .ToolCallId = toolCall.Id, + .IsSuccess = result.IsSuccess, + }); + } + } + } + + PrintExitMessage(); + return EXIT_SUCCESS; +} + +} // namespace NYdb::NConsoleClient \ No newline at end of file diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai.h b/ydb/public/lib/ydb_cli/commands/ydb_ai.h new file mode 100644 index 000000000000..1385cea31908 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai.h @@ -0,0 +1,18 @@ +#pragma once + +#include "ydb_command.h" + +namespace NYdb::NConsoleClient { + +class TCommandAi final : public TYdbCommand { + using TBase = TYdbCommand; + +public: + TCommandAi(); + + void Config(TConfig& config) final; + + int Run(TConfig& config) final; +}; + +} // namespace NYdb::NConsoleClient diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/common/json_utils.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/common/json_utils.cpp new file mode 100644 index 000000000000..90e1618415e1 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/common/json_utils.cpp @@ -0,0 +1,223 @@ +#include "json_utils.h" + +#include + +#include +#include + +#include + +namespace NYdb::NConsoleClient::NAi { + +//// TJsonParser + +TJsonParser::TJsonParser() + : TJsonParser(NJson::TJsonValue()) +{} + +TJsonParser::TJsonParser(const NJson::TJsonValue& value) + : JsonHolder(std::make_shared(value)) + , State(JsonHolder.get()) +{} + +TJsonParser::TJsonParser(std::shared_ptr jsonHolder, const NJson::TJsonValue* state, const TString& fieldName) + : JsonHolder(jsonHolder) + , State(state) + , FieldName(fieldName) +{ + Y_DEBUG_VERIFY(JsonHolder, "Internal error. JsonHolder should not be null"); + Y_DEBUG_VERIFY(State, "Internal error. State should not be null"); +} + +bool TJsonParser::Parse(const TString& value) { + Y_DEBUG_VERIFY(JsonHolder.get() == State, "Internal error. Parse should be called on root JSON"); + return NJson::ReadJsonTree(value, JsonHolder.get()); +} + +TString TJsonParser::ToString() const { + if (State->IsString()) { + return State->GetString(); + } + + NJsonWriter::TBuf valueWriter; + valueWriter.SetIndentSpaces(2); + valueWriter.WriteJsonValue(State); + return valueWriter.Str(); +} + +const NJson::TJsonValue& TJsonParser::GetValue() const { + return *State; +} + +TString TJsonParser::GetFieldName() const { + return FieldName.value_or("$"); +} + +TJsonParser TJsonParser::GetKey(const TString& key) const { + ValidateType(NJson::JSON_MAP); + + const auto* child = State->GetMapSafe().FindPtr(key); + if (!child) { + Fail(TStringBuilder() << "does not contain " << key << " field"); + } + + return TJsonParser(JsonHolder, child, AdvancePath(key)); +} + +std::optional TJsonParser::MaybeKey(const TString& key) const { + if (!State->IsMap()) { + return std::nullopt; + } + + if (const auto* child = State->GetMapSafe().FindPtr(key)) { + return TJsonParser(JsonHolder, child, AdvancePath(key)); + } + + return std::nullopt; +} + +TJsonParser TJsonParser::GetElement(ui64 index) const { + ValidateType(NJson::JSON_ARRAY); + + const auto& array = State->GetArraySafe(); + if (const auto size = array.size(); size <= index) { + Fail(TStringBuilder() << "does not contain element with index " << index << ", actual array size: " << size); + } + + return TJsonParser(JsonHolder, &array[index], AdvancePath(index)); +} + +void TJsonParser::Iterate(std::function handler) const { + ValidateType(NJson::JSON_ARRAY); + + const auto& array = State->GetArraySafe(); + for (ui64 i = 0; i < array.size(); ++i) { + handler(TJsonParser(JsonHolder, &array[i], AdvancePath(i))); + } +} + +TString TJsonParser::GetString() const { + ValidateType(NJson::JSON_STRING); + return State->GetString(); +} + +bool TJsonParser::IsNull() const { + return State->IsNull(); +} + +void TJsonParser::ValidateType(NJson::EJsonValueType expectedType) const { + if (const auto type = State->GetType(); type != expectedType) { + Fail(TStringBuilder() << "has unexpected type " << type << ", expected type: " << expectedType); + } +} + +TString TJsonParser::AdvancePath(const TString& key) const { + return TStringBuilder() << GetFieldName() << "." << key; +} + +TString TJsonParser::AdvancePath(const ui64& index) const { + return TStringBuilder() << GetFieldName() << "[" << index << "]"; +} + +void TJsonParser::Fail(const TString& message) const { + auto error = yexception() << "JSON"; + if (FieldName) { + error << " field " << *FieldName; + } + throw error << " " << message; +} + +//// TJsonSchemaBuilder + +TJsonSchemaBuilder::TJsonSchemaBuilder(TJsonSchemaBuilder* parent) + : Parent(parent) +{ + Y_DEBUG_VERIFY(Parent, "Internal error. Parent should not be null"); +} + +TJsonSchemaBuilder& TJsonSchemaBuilder::Type(EType type) { + Y_DEBUG_VERIFY(TypeValue == EType::Undefined, "Internal error. Type should be defined only once"); + Y_DEBUG_VERIFY(type != EType::Undefined, "Internal error. Type should not be undefined"); + TypeValue = type; + return *this; +} + +TJsonSchemaBuilder& TJsonSchemaBuilder::Description(const TString& description) { + Y_DEBUG_VERIFY(!DescriptionValue, "Internal error. Description should be defined only once"); + DescriptionValue = description; + return *this; +} + +TJsonSchemaBuilder& TJsonSchemaBuilder::Done() const { + Y_DEBUG_VERIFY(Parent, "Internal error. Done should be called only on child objects"); + return *Parent; +} + +NJson::TJsonValue TJsonSchemaBuilder::Build() const { + NJson::TJsonValue result; + + if (DescriptionValue) { + result["description"] = *DescriptionValue; + } + + auto& type = result["type"]; + switch (TypeValue) { + case EType::Undefined: { + Y_DEBUG_VERIFY(false, "Internal error. Type should not be defined before building"); + break; + } + case EType::String: { + type = "string"; + break; + } + case EType::Object: { + type = "object"; + + if (!Properties.empty()) { + auto& properties = result["properties"].SetType(NJson::JSON_MAP).GetMapSafe(); + properties.reserve(Properties.size()); + for (const auto& [name, builder] : Properties) { + properties.emplace(name, builder->Build()); + } + } + + if (!RequiredProperties.empty()) { + auto& required = result["required"].SetType(NJson::JSON_ARRAY).GetArraySafe(); + for (const auto& name : RequiredProperties) { + required.emplace_back(name); + } + } + + break; + } + } + + return result; +} + +TJsonSchemaBuilder& TJsonSchemaBuilder::Property(const TString& name, bool required) { + const auto [it, inserted] = Properties.emplace(name, std::make_shared(this)); + Y_DEBUG_VERIFY(inserted, "Internal error. Property should not be defined twice"); + + if (required) { + RequiredProperties.emplace_back(name); + } + + return *it->second; +} + +//// Utils + +TString FormatJsonValue(const NJson::TJsonValue& value) { + return TJsonParser(value).ToString(); +} + +TString FormatJsonValue(const TString& value) { + TJsonParser parser; + if (!parser.Parse(value)) { + return value; + } + return parser.ToString(); +} + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/common/json_utils.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/common/json_utils.h new file mode 100644 index 000000000000..0dc5fcbc08dd --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/common/json_utils.h @@ -0,0 +1,87 @@ +#pragma once + +#include + +#include + +namespace NYdb::NConsoleClient::NAi { + +class TJsonParser { +public: + TJsonParser(); + + explicit TJsonParser(const NJson::TJsonValue& value); + + bool Parse(const TString& value); + + TString ToString() const; + + const NJson::TJsonValue& GetValue() const; + + TString GetFieldName() const; + + TJsonParser GetKey(const TString& key) const; + + std::optional MaybeKey(const TString& key) const; + + TJsonParser GetElement(ui64 index) const; + + void Iterate(std::function handler) const; + + TString GetString() const; + + bool IsNull() const; + + void ValidateType(NJson::EJsonValueType expectedType) const; + +private: + TJsonParser(std::shared_ptr jsonHolder, const NJson::TJsonValue* state, const TString& fieldName); + + TString AdvancePath(const TString& key) const; + + TString AdvancePath(const ui64& index) const; + + void Fail(const TString& message) const; + +private: + std::shared_ptr JsonHolder; + const NJson::TJsonValue* State = nullptr; + std::optional FieldName; +}; + +class TJsonSchemaBuilder { +public: + enum class EType { + Undefined, + String, + Object, + }; + + TJsonSchemaBuilder() = default; + + explicit TJsonSchemaBuilder(TJsonSchemaBuilder* parent); + + TJsonSchemaBuilder& Type(EType type); + + TJsonSchemaBuilder& Description(const TString& description); + + TJsonSchemaBuilder& Property(const TString& name, bool required = true); + + TJsonSchemaBuilder& Done() const; + + NJson::TJsonValue Build() const; + +private: + TJsonSchemaBuilder* Parent = nullptr; + + EType TypeValue = EType::Undefined; + std::unordered_map> Properties; + std::vector RequiredProperties; + std::optional DescriptionValue; +}; + +TString FormatJsonValue(const NJson::TJsonValue& value); + +TString FormatJsonValue(const TString& value); + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/common/ya.make b/ydb/public/lib/ydb_cli/commands/ydb_ai/common/ya.make new file mode 100644 index 000000000000..ae7d960872a8 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/common/ya.make @@ -0,0 +1,13 @@ +LIBRARY() + +SRCS( + json_utils.cpp +) + +PEERDIR( + library/cpp/json + library/cpp/json/writer + ydb/core/base +) + +END() diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/line_reader.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/line_reader.cpp new file mode 100644 index 000000000000..67d5c6b6109b --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/line_reader.cpp @@ -0,0 +1,66 @@ +#include "line_reader.h" + +#include + +namespace NYdb::NConsoleClient::NAi { + +TLineReader::TLineReader(TString prompt, TString historyFilePath) + : Prompt(std::move(prompt)) + , HistoryFilePath(historyFilePath) + , HistoryFileLock(historyFilePath) +{ + Rx.install_window_change_handler(); + + Rx.bind_key(replxx::Replxx::KEY::control('N'), [this](char32_t code) { + return Rx.invoke(replxx::Replxx::ACTION::HISTORY_NEXT, code); + }); + Rx.bind_key(replxx::Replxx::KEY::control('P'), [this](char32_t code) { + return Rx.invoke(replxx::Replxx::ACTION::HISTORY_PREVIOUS, code); + }); + Rx.bind_key(replxx::Replxx::KEY::control('D'), [](char32_t) { + return replxx::Replxx::ACTION_RESULT::BAIL; + }); + Rx.bind_key(replxx::Replxx::KEY::control('J'), [this](char32_t code) { + return Rx.invoke(replxx::Replxx::ACTION::COMMIT_LINE, code); + }); + + Rx.enable_bracketed_paste(); + Rx.set_unique_history(true); + + if (const auto guard = TryLockHistory(); guard && !Rx.history_load(HistoryFilePath)) { + Rx.print("Loading history failed: %s\n", strerror(errno)); + } +} + +std::optional TLineReader::ReadLine() { + do { + if (const char* input = Rx.input(Prompt.c_str())) { + auto result = Strip(input); + AddToHistory(result); + return std::move(result); + } + } while (errno == EAGAIN); + + return std::nullopt; +} + +TTryGuard TLineReader::TryLockHistory() { + // AI-TODO: KIKIMR-24202 - robust file creation and handling + TTryGuard guard(HistoryFileLock); + + if (!guard) { + Rx.print("Lock of history file failed: %s\n", strerror(errno)); + } + + return guard; +} + +void TLineReader::AddToHistory(const TString& line) { + Rx.history_add(line); + + if (const auto guard = TryLockHistory(); guard && !Rx.history_save(HistoryFilePath)) { + Rx.print("Save history failed: %s\n", strerror(errno)); + } +} + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/line_reader.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/line_reader.h new file mode 100644 index 000000000000..f156a0f70a14 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/line_reader.h @@ -0,0 +1,30 @@ +#pragma once + +#include + +#include +#include + +#include + +namespace NYdb::NConsoleClient::NAi { + +class TLineReader { +public: + TLineReader(TString prompt, TString historyFilePath); + + std::optional ReadLine(); + +private: + TTryGuard TryLockHistory(); + + void AddToHistory(const TString& line); + +private: + TString Prompt; + TString HistoryFilePath; + TFileLock HistoryFileLock; + replxx::Replxx Rx; +}; + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.cpp new file mode 100644 index 000000000000..589e1ecedc75 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.cpp @@ -0,0 +1,133 @@ +#include "model_anthropic.h" +#include "model_base.h" + +#include +#include + +#include + +#include +#include + +namespace NYdb::NConsoleClient::NAi { + +namespace { + +class TModelAnthropic final : public TModelBase { + using TBlase = TModelBase; + + static constexpr ui64 MAX_COMPLETION_TOKENS = 1024; + +public: + TModelAnthropic(const TAnthropicModelSettings& settings, const TClientCommand::TConfig& config) + : TBlase(CreateApiUrl(settings.BaseUrl, "/v1/messages"), settings.ApiKey, config) + , Tools(ChatCompletionRequest["tools"].SetType(NJson::JSON_ARRAY).GetArraySafe()) + , Conversation(ChatCompletionRequest["messages"].SetType(NJson::JSON_ARRAY).GetArraySafe()) + { + ChatCompletionRequest["model"] = settings.ModelId; + ChatCompletionRequest["max_tokens"] = MAX_COMPLETION_TOKENS; + } + + void RegisterTool(const TString& name, const NJson::TJsonValue& parametersSchema, const TString& description) final { + Y_DEBUG_VERIFY(ValidateToolName(name), "Internal error. Invalid tool name: %s", name.c_str()); + + auto& tool = Tools.emplace_back(); + tool["name"] = name; + tool["input_schema"] = parametersSchema; + tool["description"] = description; + } + +protected: + void AdvanceConversation(const std::vector& messages) final { + auto& conversationItem = Conversation.emplace_back(); + conversationItem["role"] = "user"; + + auto& content = conversationItem["content"].SetType(NJson::JSON_ARRAY).GetArraySafe(); + for (const auto& message : messages) { + auto& item = content.emplace_back(); + auto& type = item["type"]; + + if (std::holds_alternative(message)) { + item["text"] = std::get(message).Text; + type = "text"; + } else { + const auto& toolResponse = std::get(message); + item["content"] = toolResponse.Text; + item["tool_use_id"] = toolResponse.ToolCallId; + item["is_error"] = !toolResponse.IsSuccess; + type = "tool_result"; + } + } + } + + TResponse HandleModelResponse(const NJson::TJsonValue& response) final { + TResponse result; + + TJsonParser parser(response); + if (auto child = parser.MaybeKey("response")) { + parser = std::move(*child); + } + + parser = parser.GetKey("content"); + auto& conversationItem = Conversation.emplace_back(); + conversationItem["role"] = "assistant"; + conversationItem["content"] = parser.GetValue(); + + parser.Iterate([&](TJsonParser item) { + const auto& type = item.GetKey("type").GetString(); + if (type == "text") { + if (result.Text) { + throw yexception() << "Multiple conversation items contains text"; + } + result.Text = Strip(item.GetKey("text").GetString()); + } else if (type == "tool_use") { + result.ToolCalls.push_back({ + .Id = item.GetKey("id").GetString(), + .Name = item.GetKey("name").GetString(), + .Parameters = item.GetKey("input").GetValue(), + }); + } else { + throw yexception() << "Unknown conversation item type: " << type << ", expected text or tool_use"; + } + }); + + return result; + } + + TString HandleErrorResponse(ui64 httpCode, const TString& response) final { + TJsonParser parser; + if (!parser.Parse(response)) { + return TBlase::HandleErrorResponse(httpCode, response); + } + + auto error = TStringBuilder() << "Request to model API failed:\n"; + if (const auto& info = parser.MaybeKey("error")) { + return error << info->ToString(); + } + if (const auto& response = parser.MaybeKey("response")) { + if (const auto& info = parser.MaybeKey("error")) { + return error << info->ToString(); + } + return error << response->ToString(); + } + + return TBlase::HandleErrorResponse(httpCode, response); + } + +private: + static bool ValidateToolName(const TString& name) { + return 1 <= name.size() && name.size() <= 128; + } + +private: + NJson::TJsonValue::TArray& Tools; + NJson::TJsonValue::TArray& Conversation; +}; + +} // anonymous namespace + +IModel::TPtr CreateAnthropicModel(const TAnthropicModelSettings& settings, const TClientCommand::TConfig& config) { + return std::make_shared(settings, config); +} + +} // namespace NYdb::NConsoleClient::NAi \ No newline at end of file diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.h new file mode 100644 index 000000000000..c95737eadb0c --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_anthropic.h @@ -0,0 +1,17 @@ +#pragma once + +#include "model_interface.h" + +#include + +namespace NYdb::NConsoleClient::NAi { + +struct TAnthropicModelSettings { + TString BaseUrl; + TString ModelId; + std::optional ApiKey; +}; + +IModel::TPtr CreateAnthropicModel(const TAnthropicModelSettings& settings, const TClientCommand::TConfig& config); + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_base.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_base.cpp new file mode 100644 index 000000000000..7e6f063d9c44 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_base.cpp @@ -0,0 +1,145 @@ +#include "model_base.h" + +#include +#include + +#include +#include +#include +#include + +#include + +namespace NYdb::NConsoleClient::NAi { + +namespace { + +NYql::THttpHeader CreateApiHeaders(const std::optional& authToken) { + TSmallVec headers = {"Content-Type: application/json"}; + + if (authToken) { + headers.emplace_back(TStringBuilder() << "Authorization: Bearer " << *authToken); + } + + return {.Fields = std::move(headers)}; +} + +struct THttpResponse { + THttpResponse(TString&& content, ui64 httpCode) + : Content(std::move(content)) + , HttpCode(httpCode) + {} + + bool IsSuccess() const { + return HttpCode >= 200 && HttpCode < 300; + } + + TString Content; + ui64 HttpCode = 0; +}; + +} // anonymous namespace + +TModelBase::TModelBase(const TString& apiUrl, const std::optional& authToken, const TClientCommand::TConfig& config) + : Verbosity(config.VerbosityLevel) + , ApiUrl(apiUrl) + , ApiHeaders(CreateApiHeaders(authToken)) + , HttpGateway(NYql::IHTTPGateway::Make()) +{ + Y_DEBUG_VERIFY(apiUrl, "Internal error. Url should not be empty for model API"); + + if (Verbosity >= VERB_INFO) { + Cerr << "Using model API url: " << apiUrl << " with " + << (authToken ? TStringBuilder() << "auth token " << TString(authToken->size(), '*') : TStringBuilder() << "anonymous access") << Endl; + } +} + +TModelBase::TResponse TModelBase::HandleMessages(const std::vector& messages) { + Y_DEBUG_VERIFY(!messages.empty(), "Internal error. Messages should not be empty for advance conversation"); + + AdvanceConversation(messages); + + if (Verbosity >= VERB_TRACE) { + Cerr << "Request to model API:\n" << FormatJsonValue(ChatCompletionRequest) << Endl; + } + + NJsonWriter::TBuf requestJsonWriter; + requestJsonWriter.WriteJsonValue(&ChatCompletionRequest); + auto request = requestJsonWriter.Str(); + + auto responsePromise = NThreading::NewPromise(); + auto httpCallback = [&responsePromise](NYql::IHTTPGateway::TResult result) -> void { + const auto curlCode = result.CurlResponseCode; + if (curlCode == CURLE_OK) { + auto& content = result.Content; + responsePromise.SetValue(THttpResponse(content.Extract(), content.HttpResponseCode)); + return; + } + + auto error = TStringBuilder() << "Failed to connect to API server or process response, internal code: " << static_cast(curlCode); + if (result.Issues) { + error << ". Reason:\n" << result.Issues.ToString(); + } + responsePromise.SetException(error); + }; + + HttpGateway->Upload(ApiUrl, ApiHeaders, std::move(request), std::move(httpCallback)); + const auto response = responsePromise.GetFuture().ExtractValueSync(); + + if (Verbosity >= VERB_TRACE) { + Cerr << "Model API response http code: " << response.HttpCode; + if (response.Content) { + Cerr << ". Response data:\n" << FormatJsonValue(response.Content); + } + Cerr << Endl; + } + + if (!response.IsSuccess()) { + throw yexception() << HandleErrorResponse(response.HttpCode, response.Content); + } + + NJson::TJsonValue responseJson; + try { + NJson::ReadJsonTree(response.Content, &responseJson, /* throwOnError */ true); + } catch (const std::exception& e) { + throw yexception() << "Model API response is not valid JSON, reason: " << e.what(); + } + + try { + return HandleModelResponse(responseJson); + } catch (const std::exception& e) { + throw yexception() << "Processing model response error. " << e.what(); + } +} + +TString TModelBase::HandleErrorResponse(ui64 httpCode, const TString& response) { + auto error = TStringBuilder() << "Request to model API failed with code: " << httpCode; + if (response) { + error << ". Response:\n" << FormatJsonValue(response); + } + return error; +} + +TString TModelBase::CreateApiUrl(const TString& baseUrl, const TString& uri) { + Y_DEBUG_VERIFY(uri, "Internal error. Uri should not be empty for model API"); + + TStringBuf sanitizedUrl; + TStringBuf query; + TStringBuf fragment; + SeparateUrlFromQueryAndFragment(baseUrl, sanitizedUrl, query, fragment); + + if (query || fragment) { + auto error = yexception() << "Invalid model API base url: '" << baseUrl << "'"; + if (query) { + error << ". Query part should be empty, but got: '" << query << "'"; + } + if (fragment) { + error << ". Fragment part should be empty, but got: '" << fragment << "'"; + } + throw error; + } + + return TStringBuilder() << RemoveFinalSlash(sanitizedUrl) << uri; +} + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_base.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_base.h new file mode 100644 index 000000000000..8beef05c0a6d --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_base.h @@ -0,0 +1,41 @@ +#pragma once + +#include "model_interface.h" + +#include +#include + +namespace NYdb::NConsoleClient::NAi { + +class TModelBase : public IModel { +protected: + enum EVerboseLevel { + VERB_INFO = 1, + VERB_TRACE = 2, + }; + +public: + TModelBase(const TString& apiUrl, const std::optional& authToken, const TClientCommand::TConfig& config); + + TResponse HandleMessages(const std::vector& messages) final; + +protected: + virtual void AdvanceConversation(const std::vector& messages) = 0; + + virtual TResponse HandleModelResponse(const NJson::TJsonValue& response) = 0; + + virtual TString HandleErrorResponse(ui64 httpCode, const TString& response); + + static TString CreateApiUrl(const TString& baseUrl, const TString& uri); + +protected: + NJson::TJsonValue ChatCompletionRequest; + +private: + const ui64 Verbosity = 0; + const TString ApiUrl; + const NYql::THttpHeader ApiHeaders; + const NYql::IHTTPGateway::TPtr HttpGateway; +}; + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_interface.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_interface.h new file mode 100644 index 000000000000..8677d9525b3f --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_interface.h @@ -0,0 +1,45 @@ +#pragma once + +#include + +#include + +#include + +namespace NYdb::NConsoleClient::NAi { + +class IModel { +public: + using TPtr = std::shared_ptr; + + virtual ~IModel() = default; + + struct TUserMessage { + TString Text; + }; + + struct TToolResponse { + TString Text; + TString ToolCallId; + bool IsSuccess = true; + }; + + using TMessage = std::variant; + + struct TResponse { + struct TToolCall { + TString Id; + TString Name; + NJson::TJsonValue Parameters; + }; + + TString Text; + std::vector ToolCalls; + }; + + virtual TResponse HandleMessages(const std::vector& messages) = 0; + + virtual void RegisterTool(const TString& name, const NJson::TJsonValue& parametersSchema, const TString& description) = 0; +}; + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp new file mode 100644 index 000000000000..27d78571e273 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.cpp @@ -0,0 +1,157 @@ +#include "model_openai.h" +#include "model_base.h" + +#include +#include + +#include + +#include +#include + +namespace NYdb::NConsoleClient::NAi { + +namespace { + +class TModelOpenAi final : public TModelBase { + using TBlase = TModelBase; + + static constexpr ui64 MAX_COMPLETION_TOKENS = 1024; + +public: + TModelOpenAi(const TOpenAiModelSettings& settings, const TClientCommand::TConfig& config) + : TBlase(CreateApiUrl(settings.BaseUrl, "/v1/chat/completions"), settings.ApiKey, config) + , Tools(ChatCompletionRequest["tools"].SetType(NJson::JSON_ARRAY).GetArraySafe()) + , Conversation(ChatCompletionRequest["messages"].SetType(NJson::JSON_ARRAY).GetArraySafe()) + { + if (settings.ModelId) { + ChatCompletionRequest["model"] = *settings.ModelId; + } + + ChatCompletionRequest["max_completion_tokens"] = MAX_COMPLETION_TOKENS; + } + + void RegisterTool(const TString& name, const NJson::TJsonValue& parametersSchema, const TString& description) final { + Y_DEBUG_VERIFY(ValidateToolName(name), "Internal error. Invalid tool name: %s", name.c_str()); + + auto& tool = Tools.emplace_back(); + tool["type"] = "function"; + + auto& toolInfo = tool["function"]; + toolInfo["name"] = name; + toolInfo["parameters"] = parametersSchema; + toolInfo["description"] = description; + } + +protected: + void AdvanceConversation(const std::vector& messages) final { + for (const auto& message : messages) { + auto& item = Conversation.emplace_back(); + auto& content = item["content"]; + auto& role = item["role"]; + + if (std::holds_alternative(message)) { + content = std::get(message).Text; + role = "user"; + } else { + const auto& toolResponse = std::get(message); + item["tool_call_id"] = toolResponse.ToolCallId; + content = toolResponse.Text; + role = "tool"; + } + } + } + + TResponse HandleModelResponse(const NJson::TJsonValue& response) final { + TResponse result; + + TJsonParser parser(response); + if (auto child = parser.MaybeKey("response")) { + parser = std::move(*child); + } + + parser = parser.GetKey("choices").GetElement(0).GetKey("message"); + Conversation.emplace_back(parser.GetValue()); + + const auto& content = parser.MaybeKey("content"); + const bool hasContent = content && !content->IsNull(); + if (hasContent) { + result.Text = Strip(content->GetString()); + } + + const auto& tollCalls = parser.MaybeKey("tool_calls"); + const bool hasToolsCalls = tollCalls && !tollCalls->IsNull(); + if (hasToolsCalls) { + tollCalls->Iterate([&](TJsonParser toolCall) { + auto function = toolCall.GetKey("function"); + + NJson::TJsonValue argumentsJson; + try { + NJson::ReadJsonTree(function.GetKey("arguments").GetString(), &argumentsJson, /* throwOnError */ true); + } catch (const std::exception& e) { + throw yexception() << "Tool call arguments is not valid JSON, reason: " << e.what(); + } + + result.ToolCalls.push_back({ + .Id = toolCall.GetKey("id").GetString(), + .Name = function.GetKey("name").GetString(), + .Parameters = std::move(argumentsJson), + }); + }); + } + + if (!hasContent && !hasToolsCalls) { + throw yexception() << "Not found either content or tool_calls keys in field " << parser.GetFieldName(); + } + + return result; + } + + TString HandleErrorResponse(ui64 httpCode, const TString& response) final { + TJsonParser parser; + if (!parser.Parse(response)) { + return TBlase::HandleErrorResponse(httpCode, response); + } + + auto error = TStringBuilder() << "Request to model API failed:\n"; + if (const auto& info = parser.MaybeKey("message")) { + return error << info->ToString(); + } + if (const auto& info = parser.MaybeKey("raw_response")) { + return error << info->ToString(); + } + if (const auto& info = parser.MaybeKey("error")) { + return error << info->ToString(); + } + + return TBlase::HandleErrorResponse(httpCode, response); + } + +private: + static bool ValidateToolName(const TString& name) { + if (name.size() > 64) { + return false; + } + + for (const auto c : name) { + if (('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || ('0' <= c && c <= '9') || IsIn({'_', '-'}, c)) { + continue; + } + return false; + } + + return true; + } + +private: + NJson::TJsonValue::TArray& Tools; + NJson::TJsonValue::TArray& Conversation; +}; + +} // anonymous namespace + +IModel::TPtr CreateOpenAiModel(const TOpenAiModelSettings& settings, const TClientCommand::TConfig& config) { + return std::make_shared(settings, config); +} + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.h new file mode 100644 index 000000000000..150ed85d944a --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/model_openai.h @@ -0,0 +1,17 @@ +#pragma once + +#include "model_interface.h" + +#include + +namespace NYdb::NConsoleClient::NAi { + +struct TOpenAiModelSettings { + TString BaseUrl; + std::optional ModelId; + std::optional ApiKey; +}; + +IModel::TPtr CreateOpenAiModel(const TOpenAiModelSettings& settings, const TClientCommand::TConfig& config); + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/models/ya.make b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/ya.make new file mode 100644 index 000000000000..5e82e0375bfa --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/models/ya.make @@ -0,0 +1,19 @@ +LIBRARY() + +SRCS( + model_anthropic.cpp + model_base.cpp + model_openai.cpp +) + +PEERDIR( + library/cpp/json + library/cpp/json/writer + library/cpp/string_utils/url + library/cpp/threading/future + ydb/library/yql/providers/common/http_gateway + ydb/public/lib/ydb_cli/commands/ydb_ai/common + ydb/public/lib/ydb_cli/common +) + +END() diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.cpp new file mode 100644 index 000000000000..419f204370cd --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.cpp @@ -0,0 +1,124 @@ +#include "exec_query_tool.h" + +#include +#include +#include +#include +#include + +#include + +namespace NYdb::NConsoleClient::NAi { + +namespace { + +class TExecQueryTool final : public ITool { + static constexpr char DESCRIPTION[] = R"( +Execute query in Yandex Data Base (YDB) on YQL (SQL dialect). Returns list of result sets for query, each contains list of rows and column metadata. +For example if there exists table 'my_table' with string column 'Data' and we execute query: +``` +$filtered = SELECT * FROM my_table WHERE Data IS NOT NULL; +SELECT Data || "-first" FROM $filtered; +SELECT Data || "-second" FROM $filtered; +``` +Tool will return: +[ + { + "rows": [ + {"Data": "A-first"}, + {"Data": "B-first"} + ], + "columns": [ + {"name": "Data", "type": "string"} + ] + }, + { + "rows": [ + {"Data": "A-second"}, + {"Data": "B-second"} + ], + "columns": [ + {"name": "Data", "type": "string"} + ] + } +])"; + + static constexpr char QUERY_PROPERTY[] = "query"; + +public: + explicit TExecQueryTool(TClientCommand::TConfig& config) + : ParametersSchema(TJsonSchemaBuilder() + .Type(TJsonSchemaBuilder::EType::Object) + .Property(QUERY_PROPERTY) + .Type(TJsonSchemaBuilder::EType::String) + .Description("Query to execute on YQL (SQL dialect), for example 'SELECT * FROM my_table'") + .Done() + .Build() + ) + , Description(DESCRIPTION) + , Client(TDriver(config.CreateDriverConfig())) + {} + + const NJson::TJsonValue& GetParametersSchema() const final { + return ParametersSchema; + } + + const TString& GetDescription() const final { + return Description; + } + + TResponse Execute(const NJson::TJsonValue& parameters) final try { + const TString& query = ParseParameters(parameters); + const auto response = Client.ExecuteQuery(query, NQuery::TTxControl::NoTx()).ExtractValueSync(); + if (!response.IsSuccess()) { + return TResponse(TStringBuilder() << "Query execution failed with status " << response.GetStatus() << ", reason:\n" << response.GetIssues().ToString()); + } + + NJson::TJsonValue result; + auto& resultArray = result.SetType(NJson::JSON_ARRAY).GetArraySafe(); + for (const auto& resultSet : response.GetResultSets()) { + auto& item = resultArray.emplace_back(); + + const auto& columnMeta = resultSet.GetColumnsMeta(); + auto& columns = item["columns"].SetType(NJson::JSON_ARRAY).GetArraySafe(); + for (const auto& column : columnMeta) { + auto& item = columns.emplace_back(); + item["name"] = column.Name; + item["type"] = column.Type.ToString(); + } + + auto& rows = item["rows"].SetType(NJson::JSON_ARRAY).GetArraySafe(); + TResultSetParser parser(resultSet); + while (parser.TryNextRow()) { + TJsonParser row; + Y_DEBUG_VERIFY(row.Parse(FormatResultRowJson(parser, resultSet.GetColumnsMeta(), EBinaryStringEncoding::Unicode)), "Internal error. Invalid serialized JSON row value."); + rows.emplace_back(row.GetValue()); + } + + TResultSetPrinter(EDataFormat::Pretty).Print(resultSet); + } + + return TResponse(std::move(result)); + } catch (const std::exception& e) { + return TResponse(TStringBuilder() << "Query execution failed. " << e.what()); + } + +private: + TString ParseParameters(const NJson::TJsonValue& parameters) const { + TJsonParser parser(parameters); + return Strip(parser.GetKey(QUERY_PROPERTY).GetString()); + } + +private: + const NJson::TJsonValue ParametersSchema; + const TString Description; + NQuery::TQueryClient Client; +}; + +} // anonymous namespace + +ITool::TPtr CreateExecQueryTool(TClientCommand::TConfig& config) { + return std::make_shared(config); +} + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.h new file mode 100644 index 000000000000..a7e8407f5305 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/exec_query_tool.h @@ -0,0 +1,11 @@ +#pragma once + +#include "tool_interface.h" + +#include + +namespace NYdb::NConsoleClient::NAi { + +ITool::TPtr CreateExecQueryTool(TClientCommand::TConfig& config); + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/list_directory_tool.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/list_directory_tool.cpp new file mode 100644 index 000000000000..b971f399ab57 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/list_directory_tool.cpp @@ -0,0 +1,99 @@ +#include "list_directory_tool.h" + +#include +#include +#include +#include +#include + +#include + +namespace NYdb::NConsoleClient::NAi { + +namespace { + +class TListDirectoryTool final : public ITool { + static constexpr char DESCRIPTION[] = R"( +List directory in Yandex Data Base (YDB) scheme tree. Returns list of item names inside directory and their types. +For example if called on directory 'data/', which contains two tables 'my_table1', 'my_table2' and one topic 'my_topic', then tool will return: +[ + {"name": "my_table1", "type": "table"}, + {"name": "my_table2", "type": "table"}, + {"name": "my_topic", "type": "topic"} +])"; + + static constexpr char DIRECTORY_PROPERTY[] = "directory"; + +public: + explicit TListDirectoryTool(TClientCommand::TConfig& config) + : ParametersSchema(TJsonSchemaBuilder() + .Type(TJsonSchemaBuilder::EType::Object) + .Property(DIRECTORY_PROPERTY) + .Type(TJsonSchemaBuilder::EType::String) + .Description("Path to directory which should be listed (use empty string to list database root), for example 'data/cold/'") + .Done() + .Build() + ) + , Description(DESCRIPTION) + , Database(NKikimr::CanonizePath(config.Database)) + , Client(TDriver(config.CreateDriverConfig())) + {} + + const NJson::TJsonValue& GetParametersSchema() const final { + return ParametersSchema; + } + + const TString& GetDescription() const final { + return Description; + } + + TResponse Execute(const NJson::TJsonValue& parameters) final try { + const auto& directory = ParseParameters(parameters); + const auto response = Client.ListDirectory(directory).ExtractValueSync(); + if (!response.IsSuccess()) { + return TResponse(TStringBuilder() << "Listing directory failed with status " << response.GetStatus() << ", reason:\n" << response.GetIssues().ToString()); + } + + const auto& children = response.GetChildren(); + + NJson::TJsonValue result; + auto& resultArray = result.SetType(NJson::JSON_ARRAY).GetArraySafe(); + for (const auto& child : children) { + auto& item = resultArray.emplace_back(); + item["name"] = child.Name; + item["type"] = EntryTypeToString(child.Type); + } + + Cout << TAdaptiveTabbedTable(children); + + return TResponse(std::move(result)); + } catch (const std::exception& e) { + return TResponse(TStringBuilder() << "Listing directory failed. " << e.what()); + } + +private: + TString ParseParameters(const NJson::TJsonValue& parameters) const { + TJsonParser parser(parameters); + + TString directory = Strip(parser.GetKey(DIRECTORY_PROPERTY).GetString()); + if (!directory.StartsWith('/')) { + directory = NKikimr::JoinPath({Database, directory}); + } + + return NKikimr::CanonizePath(directory); + } + +private: + const NJson::TJsonValue ParametersSchema; + const TString Description; + const TString Database; + NScheme::TSchemeClient Client; +}; + +} // anonymous namespace + +ITool::TPtr CreateListDirectoryTool(TClientCommand::TConfig& config) { + return std::make_shared(config); +} + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/list_directory_tool.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/list_directory_tool.h new file mode 100644 index 000000000000..175c5cdf0545 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/list_directory_tool.h @@ -0,0 +1,11 @@ +#pragma once + +#include "tool_interface.h" + +#include + +namespace NYdb::NConsoleClient::NAi { + +ITool::TPtr CreateListDirectoryTool(TClientCommand::TConfig& config); + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/tool_interface.cpp b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/tool_interface.cpp new file mode 100644 index 000000000000..690579520d7c --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/tool_interface.cpp @@ -0,0 +1,17 @@ +#include "tool_interface.h" + +#include + +namespace NYdb::NConsoleClient::NAi { + +ITool::TResponse::TResponse(const TString& error) + : Text(error) + , IsSuccess(false) +{} + +ITool::TResponse::TResponse(const NJson::TJsonValue& result) + : Text(FormatJsonValue(result)) + , IsSuccess(true) +{} + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/tool_interface.h b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/tool_interface.h new file mode 100644 index 000000000000..c43beb40f9ad --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/tool_interface.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +#include + +namespace NYdb::NConsoleClient::NAi { + +class ITool { +public: + using TPtr = std::shared_ptr; + + virtual ~ITool() = default; + + struct TResponse { + TString Text; + bool IsSuccess = true; + + explicit TResponse(const TString& error); + + explicit TResponse(const NJson::TJsonValue& result); + }; + + virtual const NJson::TJsonValue& GetParametersSchema() const = 0; + + virtual const TString& GetDescription() const = 0; + + virtual TResponse Execute(const NJson::TJsonValue& parameters) = 0; +}; + +} // namespace NYdb::NConsoleClient::NAi diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/ya.make b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/ya.make new file mode 100644 index 000000000000..6398e7700d38 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/tools/ya.make @@ -0,0 +1,19 @@ +LIBRARY() + +SRCS( + exec_query_tool.cpp + list_directory_tool.cpp + tool_interface.cpp +) + +PEERDIR( + library/cpp/json/writer + ydb/core/base + ydb/public/lib/json_value + ydb/public/lib/ydb_cli/commands/ydb_ai/common + ydb/public/lib/ydb_cli/common + ydb/public/sdk/cpp/src/client/query + ydb/public/sdk/cpp/src/client/scheme +) + +END() diff --git a/ydb/public/lib/ydb_cli/commands/ydb_ai/ya.make b/ydb/public/lib/ydb_cli/commands/ydb_ai/ya.make new file mode 100644 index 000000000000..78586dda5f78 --- /dev/null +++ b/ydb/public/lib/ydb_cli/commands/ydb_ai/ya.make @@ -0,0 +1,17 @@ +LIBRARY() + +SRCS( + line_reader.cpp +) + +PEERDIR( + contrib/restricted/patched/replxx + util +) + +END() + +RECURSE( + models + tools +) diff --git a/ydb/public/lib/ydb_cli/commands/ydb_root_common.cpp b/ydb/public/lib/ydb_cli/commands/ydb_root_common.cpp index fdf4bf86a8f1..e48e854404cf 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_root_common.cpp +++ b/ydb/public/lib/ydb_cli/commands/ydb_root_common.cpp @@ -1,6 +1,7 @@ #include "ydb_root_common.h" #include "ydb_profile.h" #include "ydb_admin.h" +#include "ydb_ai.h" #include "ydb_debug.h" #include "ydb_service_auth.h" #include "ydb_service_discovery.h" @@ -38,6 +39,7 @@ TClientCommandRootCommon::TClientCommandRootCommon(const TString& name, const TC { ValidateSettings(); AddDangerousCommand(std::make_unique()); + AddCommand(std::make_unique()); AddCommand(std::make_unique()); AddCommand(std::make_unique()); AddCommand(std::make_unique());