diff --git a/src/cpp/include/openvino/genai/chat_history.hpp b/src/cpp/include/openvino/genai/chat_history.hpp new file mode 100644 index 0000000000..f2baf2ec03 --- /dev/null +++ b/src/cpp/include/openvino/genai/chat_history.hpp @@ -0,0 +1,74 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "openvino/genai/visibility.hpp" +#include "openvino/genai/json_container.hpp" + +namespace ov { +namespace genai { + +/** + * @brief ChatHistory stores conversation messages and optional metadata for chat templates. + * + * Manages: + * - Message history (array of message objects) + * - Optional tools definitions array (for function calling) + * - Optional extra context object (for custom template variables) + */ +class OPENVINO_GENAI_EXPORTS ChatHistory { +public: + ChatHistory(); + + explicit ChatHistory(const JsonContainer& messages); + + explicit ChatHistory(const std::vector& messages); + + /** + * @brief Construct from initializer list for convenient inline creation. 
+ * + * Example: + * ChatHistory history({ + * {{"role", "system"}, {"content", "You are helpful assistant."}}, + * {{"role", "user"}, {"content", "Hello"}} + * }); + */ + ChatHistory(std::initializer_list>> messages); + + ~ChatHistory(); + + ChatHistory& push_back(const JsonContainer& message); + ChatHistory& push_back(const ov::AnyMap& message); + ChatHistory& push_back(std::initializer_list> message); + + void pop_back(); + + const JsonContainer& get_messages() const; + JsonContainer& get_messages(); + + JsonContainer operator[](size_t index) const; + JsonContainer operator[](int index) const; + + JsonContainer first() const; + JsonContainer last() const; + + void clear(); + + size_t size() const; + bool empty() const; + + ChatHistory& set_tools(const JsonContainer& tools); + const JsonContainer& get_tools() const; + + ChatHistory& set_extra_context(const JsonContainer& extra_context); + const JsonContainer& get_extra_context() const; + +private: + JsonContainer m_messages = JsonContainer::array(); + JsonContainer m_tools = JsonContainer::array(); + JsonContainer m_extra_context = JsonContainer::object(); +}; + +} // namespace genai +} // namespace ov diff --git a/src/cpp/include/openvino/genai/json_container.hpp b/src/cpp/include/openvino/genai/json_container.hpp index e669ecfdd8..d109241eef 100644 --- a/src/cpp/include/openvino/genai/json_container.hpp +++ b/src/cpp/include/openvino/genai/json_container.hpp @@ -8,8 +8,6 @@ #include #include -#include - #include "openvino/core/any.hpp" #include "openvino/genai/visibility.hpp" @@ -17,21 +15,6 @@ namespace ov { namespace genai { class OPENVINO_GENAI_EXPORTS JsonContainer { -private: - template - struct is_json_primitive { - using type = typename std::decay::type; - static constexpr bool value = - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value; - }; - public: /** * 
@brief Default constructor creates an empty JSON object. @@ -41,9 +24,14 @@ class OPENVINO_GENAI_EXPORTS JsonContainer { /** * @brief Construct from JSON primitive types (bool, int64_t, double, string, etc.). */ - template - JsonContainer(T&& value, typename std::enable_if::value, int>::type = 0) : - JsonContainer(nlohmann::ordered_json(std::forward(value))) {} + JsonContainer(bool value); + JsonContainer(int value); + JsonContainer(int64_t value); + JsonContainer(double value); + JsonContainer(float value); + JsonContainer(const std::string& value); + JsonContainer(const char* value); + JsonContainer(std::nullptr_t); /** * @brief Construct from initializer list of key-value pairs. @@ -105,13 +93,14 @@ class OPENVINO_GENAI_EXPORTS JsonContainer { /** * @brief Assignment operator for JSON primitive types (bool, int64_t, double, string, etc.). */ - template - typename std::enable_if::value, JsonContainer&>::type - operator=(T&& value) { - auto json_value_ptr = get_json_value_ptr(AccessMode::Write); - *json_value_ptr = nlohmann::ordered_json(std::forward(value)); - return *this; - } + JsonContainer& operator=(bool value); + JsonContainer& operator=(int value); + JsonContainer& operator=(int64_t value); + JsonContainer& operator=(double value); + JsonContainer& operator=(float value); + JsonContainer& operator=(const std::string& value); + JsonContainer& operator=(const char* value); + JsonContainer& operator=(std::nullptr_t); /** * @brief Copy assignment operator. @@ -156,7 +145,7 @@ class OPENVINO_GENAI_EXPORTS JsonContainer { * @param value JsonContainer to append * @return Reference to this container for chaining */ - JsonContainer& push_back(const JsonContainer& value); + JsonContainer& push_back(const JsonContainer& item); /** * @brief Add JSON primitive to end of array. @@ -164,16 +153,14 @@ class OPENVINO_GENAI_EXPORTS JsonContainer { * @param value JSON primitive to append (bool, int64_t, double, string, etc.) 
* @return Reference to this container for chaining */ - template - typename std::enable_if::value, JsonContainer&>::type - push_back(T&& value) { - auto json_value_ptr = get_json_value_ptr(AccessMode::Write); - if (!json_value_ptr->is_array()) { - *json_value_ptr = nlohmann::ordered_json::array(); - } - json_value_ptr->push_back(nlohmann::ordered_json(std::forward(value))); - return *this; - } + JsonContainer& push_back(bool value); + JsonContainer& push_back(int value); + JsonContainer& push_back(int64_t value); + JsonContainer& push_back(double value); + JsonContainer& push_back(float value); + JsonContainer& push_back(const std::string& value); + JsonContainer& push_back(const char* value); + JsonContainer& push_back(std::nullptr_t); /** * @brief Convert this container to an empty object. @@ -230,29 +217,27 @@ class OPENVINO_GENAI_EXPORTS JsonContainer { */ std::string to_json_string(int indent = -1) const; - /** - * @brief Convert to nlohmann::ordered_json for internal use. - * @return nlohmann::ordered_json representation - */ - nlohmann::ordered_json to_json() const; - /** * @brief Get string representation of the JSON type. * @return Type name: "null", "boolean", "number", "string", "array", "object" or "unknown" */ std::string type_name() const; + /** + * @internal + * @brief Internal use only - get pointer to underlying JSON for serialization. 
+ * @return Opaque pointer to internal JSON representation + */ + void* _get_json_value_ptr() const; + private: - JsonContainer(std::shared_ptr json_ptr, const std::string& path); - JsonContainer(nlohmann::ordered_json json); + class JsonContainerImpl; - std::shared_ptr m_json; + JsonContainer(std::shared_ptr impl, const std::string& path = ""); - std::string m_path = ""; + std::shared_ptr m_impl; - enum class AccessMode { Read, Write }; - - nlohmann::ordered_json* get_json_value_ptr(AccessMode mode) const; + std::string m_path = ""; }; } // namespace genai diff --git a/src/cpp/include/openvino/genai/tokenizer.hpp b/src/cpp/include/openvino/genai/tokenizer.hpp index 12231406e1..f8cf0cf93b 100644 --- a/src/cpp/include/openvino/genai/tokenizer.hpp +++ b/src/cpp/include/openvino/genai/tokenizer.hpp @@ -13,11 +13,13 @@ #include "openvino/genai/visibility.hpp" #include +#include "openvino/genai/chat_history.hpp" + namespace ov { namespace genai { -using ChatHistory = std::vector; -using ToolDefinitions = std::vector; +using ov::genai::JsonContainer; +using ov::genai::ChatHistory; using Vocab = std::unordered_map; // similar to huggingface .get_vocab() output format @@ -248,20 +250,20 @@ class OPENVINO_GENAI_EXPORTS Tokenizer { * For example, for Qwen family models, the prompt "1+1=" would be transformed into * <|im_start|>user\n1+1=<|im_end|>\n<|im_start|>assistant\n. * - * @param history A vector of chat messages, where each message is represented as a map, e.g. [{"role": "user", "content": "prompt"}, ...]. + * @param history Chat history containing the conversation messages and optional tools/extra_context. Each message is a JSON-like object, e.g. [{"role": "user", "content": "prompt"}, ...]. * @param add_generation_prompt Whether to add an ending that indicates the start of generation. * @param chat_template An optional custom chat template string, if not specified will be taken from the tokenizer. 
- * @param tools An optional vector of tool definitions to be used in the chat template. - * @param extra_context An optional map of additional variables to be used in the chat template. + * @param tools An optional JSON-like array of tool definitions to be used in the chat template. If provided, overrides tools from chat history. + * @param extra_context An optional JSON-like object with additional variables to be used in the chat template. If provided, overrides extra_context from chat history. * @return A string with the formatted and concatenated prompts from the chat history. * @throws Exception if the chat template was unable to parse the input history. */ std::string apply_chat_template( - ChatHistory history, + const ChatHistory& history, bool add_generation_prompt, const std::string& chat_template = {}, - const ToolDefinitions& tools = {}, - const ov::AnyMap& extra_context = {} + const std::optional& tools = std::nullopt, + const std::optional& extra_context = std::nullopt ) const; /// @brief Override a chat_template read from tokenizer_config.json. 
diff --git a/src/cpp/src/chat_history.cpp b/src/cpp/src/chat_history.cpp new file mode 100644 index 0000000000..adc9d8feb2 --- /dev/null +++ b/src/cpp/src/chat_history.cpp @@ -0,0 +1,124 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "openvino/genai/chat_history.hpp" + +namespace ov { +namespace genai { + +ChatHistory::ChatHistory() = default; + +ChatHistory::ChatHistory(const JsonContainer& messages) : m_messages(messages) { + if (!m_messages.is_array()) { + OPENVINO_THROW("Chat history must be initialized with a JSON array."); + } +} +ChatHistory::ChatHistory(const std::vector& messages) : + m_messages(JsonContainer::array()) { + for (const auto& message : messages) { + m_messages.push_back(JsonContainer(message)); + } +} + +ChatHistory::ChatHistory(std::initializer_list>> messages) : + m_messages(JsonContainer::array()) { + for (const auto& message : messages) { + m_messages.push_back(JsonContainer(message)); + } +} + +ChatHistory::~ChatHistory() = default; + +ChatHistory& ChatHistory::push_back(const JsonContainer& message) { + m_messages.push_back(message); + return *this; +} + +ChatHistory& ChatHistory::push_back(const ov::AnyMap& message) { + m_messages.push_back(JsonContainer(message)); + return *this; +} + +ChatHistory& ChatHistory::push_back(std::initializer_list> message) { + m_messages.push_back(JsonContainer(message)); + return *this; +} + +void ChatHistory::pop_back() { + if (m_messages.empty()) { + OPENVINO_THROW("Cannot pop_back from an empty chat history."); + } + m_messages.erase(m_messages.size() - 1); +} + +const JsonContainer& ChatHistory::get_messages() const { + return m_messages; +} + +JsonContainer& ChatHistory::get_messages() { + return m_messages; +} + +JsonContainer ChatHistory::operator[](size_t index) const { + if (index >= m_messages.size()) { + OPENVINO_THROW("Index ", index, " is out of bounds for chat history of size ", m_messages.size()); + } + return m_messages[index]; +} + 
+JsonContainer ChatHistory::operator[](int index) const { + return operator[](size_t(index)); +} + +JsonContainer ChatHistory::first() const { + if (m_messages.empty()) { + OPENVINO_THROW("Cannot access first message of an empty chat history."); + } + return m_messages[0]; +} + +JsonContainer ChatHistory::last() const { + if (m_messages.empty()) { + OPENVINO_THROW("Cannot access last message of an empty chat history."); + } + return m_messages[m_messages.size() - 1]; +} + +void ChatHistory::clear() { + m_messages.clear(); +} + +size_t ChatHistory::size() const { + return m_messages.size(); +} + +bool ChatHistory::empty() const { + return m_messages.empty(); +} + +ChatHistory& ChatHistory::set_tools(const JsonContainer& tools) { + if (!tools.is_array()) { + OPENVINO_THROW("Tools must be an array-like JsonContainer."); + } + m_tools = tools; + return *this; +} + +const JsonContainer& ChatHistory::get_tools() const { + return m_tools; +} + +ChatHistory& ChatHistory::set_extra_context(const JsonContainer& extra_context) { + if (!extra_context.is_object()) { + OPENVINO_THROW("Extra context must be an object-like JsonContainer."); + } + m_extra_context = extra_context; + return *this; +} + +const JsonContainer& ChatHistory::get_extra_context() const { + return m_extra_context; +} + +} // namespace genai +} // namespace ov diff --git a/src/cpp/src/json_container.cpp b/src/cpp/src/json_container.cpp index cc5e194fc3..792c32e3ab 100644 --- a/src/cpp/src/json_container.cpp +++ b/src/cpp/src/json_container.cpp @@ -1,10 +1,10 @@ // Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 -#include - #include "openvino/core/except.hpp" +#include + #include "openvino/genai/json_container.hpp" #include "json_utils.hpp" @@ -12,37 +12,74 @@ namespace ov { namespace genai { +enum class AccessMode { Read, Write }; + +class JsonContainer::JsonContainerImpl { +public: + JsonContainerImpl() : m_json(nlohmann::ordered_json::object()) {} + explicit 
JsonContainerImpl(nlohmann::ordered_json json) : m_json(std::move(json)) {} + + nlohmann::ordered_json* get_json_value_ptr(const std::string& path, AccessMode mode) { + auto json_pointer = nlohmann::ordered_json::json_pointer(path); + if (mode == AccessMode::Read && !m_json.contains(json_pointer)) { + OPENVINO_THROW("Path '", path, "' does not exist in the JsonContainer."); + } + return &m_json[json_pointer]; + } + + const nlohmann::ordered_json* get_json_value_ptr(const std::string& path, AccessMode mode) const { + return const_cast(this)->get_json_value_ptr(path, mode); + } + +private: + nlohmann::ordered_json m_json; +}; + JsonContainer::JsonContainer() : - JsonContainer(nlohmann::ordered_json::object()) {} + m_impl(std::make_shared()) {} + +JsonContainer::JsonContainer(bool value) : + m_impl(std::make_shared(nlohmann::ordered_json(value))) {} +JsonContainer::JsonContainer(int value) : + m_impl(std::make_shared(nlohmann::ordered_json(value))) {} +JsonContainer::JsonContainer(int64_t value) : + m_impl(std::make_shared(nlohmann::ordered_json(value))) {} +JsonContainer::JsonContainer(double value) : + m_impl(std::make_shared(nlohmann::ordered_json(value))) {} +JsonContainer::JsonContainer(float value) : + m_impl(std::make_shared(nlohmann::ordered_json(value))) {} +JsonContainer::JsonContainer(const std::string& value) : + m_impl(std::make_shared(nlohmann::ordered_json(value))) {} +JsonContainer::JsonContainer(const char* value) : + m_impl(std::make_shared(nlohmann::ordered_json(value))) {} +JsonContainer::JsonContainer(std::nullptr_t) : + m_impl(std::make_shared(nlohmann::ordered_json(nullptr))) {} JsonContainer::JsonContainer(std::initializer_list> init) : - JsonContainer(ov::genai::utils::any_map_to_json(ov::AnyMap{init.begin(), init.end()})) {} + m_impl(std::make_shared(ov::genai::utils::any_map_to_json(ov::AnyMap{init.begin(), init.end()}))) {} JsonContainer::JsonContainer(const ov::AnyMap& data) : - JsonContainer(ov::genai::utils::any_map_to_json(data)) {} + 
m_impl(std::make_shared(ov::genai::utils::any_map_to_json(data))) {} JsonContainer::JsonContainer(ov::AnyMap&& data) : - JsonContainer(ov::genai::utils::any_map_to_json(std::move(data))) {} + m_impl(std::make_shared(ov::genai::utils::any_map_to_json(std::move(data)))) {} -JsonContainer::JsonContainer(std::shared_ptr json_ptr, const std::string& path) : - m_json(json_ptr), +JsonContainer::JsonContainer(std::shared_ptr impl, const std::string& path) : + m_impl(std::move(impl)), m_path(path) {} -JsonContainer::JsonContainer(nlohmann::ordered_json json) : - m_json(std::make_shared(std::move(json))) {} - JsonContainer::JsonContainer(const JsonContainer& other) : - m_json(std::make_shared(*other.m_json)), + m_impl(std::make_shared(*other.m_impl->get_json_value_ptr(other.m_path, AccessMode::Read))), m_path(other.m_path) {} JsonContainer::JsonContainer(JsonContainer&& other) noexcept : - m_json(std::move(other.m_json)), + m_impl(std::move(other.m_impl)), m_path(std::move(other.m_path)) {} JsonContainer::~JsonContainer() = default; JsonContainer JsonContainer::share() const { - return JsonContainer(m_json, m_path); + return JsonContainer(m_impl, m_path); } JsonContainer JsonContainer::copy() const { @@ -51,50 +88,68 @@ JsonContainer JsonContainer::copy() const { JsonContainer JsonContainer::from_json_string(const std::string& json_str) { try { - return JsonContainer(nlohmann::ordered_json::parse(json_str)); + return JsonContainer(std::make_shared(nlohmann::ordered_json::parse(json_str))); } catch (const std::exception& e) { OPENVINO_THROW("Failed to construct JsonContainer from JSON string: ", e.what()); } } JsonContainer JsonContainer::object() { - return JsonContainer(nlohmann::ordered_json::object()); + return JsonContainer(std::make_shared(nlohmann::ordered_json::object())); } JsonContainer JsonContainer::array() { - return JsonContainer(nlohmann::ordered_json::array()); + return JsonContainer(std::make_shared(nlohmann::ordered_json::array())); } 
-nlohmann::ordered_json* JsonContainer::get_json_value_ptr(AccessMode mode) const { - auto json_pointer = nlohmann::ordered_json::json_pointer(m_path); - if (mode == AccessMode::Read && !m_json->contains(json_pointer)) { - OPENVINO_THROW("Path '", m_path, "' does not exist in the JsonContainer."); +#define JSON_CONTAINER_PRIMITIVE_ASSIGNMENT(type) \ + JsonContainer& JsonContainer::operator=(type value) { \ + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Write); \ + *json_value_ptr = nlohmann::ordered_json(value); \ + return *this; \ } - return &(*m_json)[json_pointer]; + +JSON_CONTAINER_PRIMITIVE_ASSIGNMENT(bool) +JSON_CONTAINER_PRIMITIVE_ASSIGNMENT(int) +JSON_CONTAINER_PRIMITIVE_ASSIGNMENT(int64_t) +JSON_CONTAINER_PRIMITIVE_ASSIGNMENT(double) +JSON_CONTAINER_PRIMITIVE_ASSIGNMENT(float) +JSON_CONTAINER_PRIMITIVE_ASSIGNMENT(const std::string&) +JSON_CONTAINER_PRIMITIVE_ASSIGNMENT(const char*) + +#undef JSON_CONTAINER_PRIMITIVE_ASSIGNMENT + +JsonContainer& JsonContainer::operator=(std::nullptr_t) { + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Write); + *json_value_ptr = nlohmann::ordered_json(nullptr); + return *this; } JsonContainer& JsonContainer::operator=(const JsonContainer& other) { if (this != &other) { - auto json_value_ptr = get_json_value_ptr(AccessMode::Write); - *json_value_ptr = other.to_json(); + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Write); + *json_value_ptr = *other.m_impl->get_json_value_ptr(other.m_path, AccessMode::Read); } return *this; } JsonContainer& JsonContainer::operator=(JsonContainer&& other) noexcept { if (this != &other) { - auto json_value_ptr = get_json_value_ptr(AccessMode::Write); - if (m_json == other.m_json) { - *json_value_ptr = std::move(*other.get_json_value_ptr(AccessMode::Read)); + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Write); + auto other_json_value_ptr = other.m_impl->get_json_value_ptr(other.m_path, 
AccessMode::Read); + if (m_impl == other.m_impl) { + *json_value_ptr = std::move(*other_json_value_ptr); } else { - *json_value_ptr = other.to_json(); + *json_value_ptr = *other_json_value_ptr; } } return *this; } bool JsonContainer::operator==(const JsonContainer& other) const { - return to_json() == other.to_json(); + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Read); + auto other_json_value_ptr = other.m_impl->get_json_value_ptr(other.m_path, AccessMode::Read); + return *json_value_ptr == *other_json_value_ptr; } bool JsonContainer::operator!=(const JsonContainer& other) const { @@ -102,7 +157,7 @@ bool JsonContainer::operator!=(const JsonContainer& other) const { } JsonContainer JsonContainer::operator[](const std::string& key) const { - return JsonContainer(m_json, m_path + "/" + key); + return JsonContainer(m_impl, m_path + "/" + key); } JsonContainer JsonContainer::operator[](const char* key) const { @@ -110,75 +165,71 @@ JsonContainer JsonContainer::operator[](const char* key) const { } JsonContainer JsonContainer::operator[](size_t index) const { - return JsonContainer(m_json, m_path + "/" + std::to_string(index)); + return JsonContainer(m_impl, m_path + "/" + std::to_string(index)); } JsonContainer JsonContainer::operator[](int index) const { return operator[](size_t(index)); } -nlohmann::ordered_json JsonContainer::to_json() const { - return *get_json_value_ptr(AccessMode::Read); -} - std::string JsonContainer::to_json_string(int indent) const { - return to_json().dump(indent); + return m_impl->get_json_value_ptr(m_path, AccessMode::Read)->dump(indent); } bool JsonContainer::is_null() const { - return get_json_value_ptr(AccessMode::Read)->is_null(); + return m_impl->get_json_value_ptr(m_path, AccessMode::Read)->is_null(); } bool JsonContainer::is_boolean() const { - return get_json_value_ptr(AccessMode::Read)->is_boolean(); + return m_impl->get_json_value_ptr(m_path, AccessMode::Read)->is_boolean(); } bool 
JsonContainer::is_number() const { - return get_json_value_ptr(AccessMode::Read)->is_number(); + return m_impl->get_json_value_ptr(m_path, AccessMode::Read)->is_number(); } bool JsonContainer::is_number_integer() const { - return get_json_value_ptr(AccessMode::Read)->is_number_integer(); + return m_impl->get_json_value_ptr(m_path, AccessMode::Read)->is_number_integer(); } bool JsonContainer::is_number_float() const { - return get_json_value_ptr(AccessMode::Read)->is_number_float(); + return m_impl->get_json_value_ptr(m_path, AccessMode::Read)->is_number_float(); } bool JsonContainer::is_string() const { - return get_json_value_ptr(AccessMode::Read)->is_string(); + return m_impl->get_json_value_ptr(m_path, AccessMode::Read)->is_string(); } bool JsonContainer::is_array() const { - return get_json_value_ptr(AccessMode::Read)->is_array(); + return m_impl->get_json_value_ptr(m_path, AccessMode::Read)->is_array(); } bool JsonContainer::is_object() const { - return get_json_value_ptr(AccessMode::Read)->is_object(); + return m_impl->get_json_value_ptr(m_path, AccessMode::Read)->is_object(); } std::optional JsonContainer::as_bool() const { - auto json_value_ptr = get_json_value_ptr(AccessMode::Read); + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Read); return json_value_ptr->is_boolean() ? std::make_optional(json_value_ptr->get()) : std::nullopt; } std::optional JsonContainer::as_int() const { - auto json_value_ptr = get_json_value_ptr(AccessMode::Read); + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Read); return json_value_ptr->is_number_integer() ? std::make_optional(json_value_ptr->get()) : std::nullopt; } std::optional JsonContainer::as_double() const { - auto json_value_ptr = get_json_value_ptr(AccessMode::Read); + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Read); return json_value_ptr->is_number() ? 
std::make_optional(json_value_ptr->get()) : std::nullopt; } std::optional JsonContainer::as_string() const { - auto json_value_ptr = get_json_value_ptr(AccessMode::Read); + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Read); return json_value_ptr->is_string() ? std::make_optional(json_value_ptr->get()) : std::nullopt; } bool JsonContainer::get_bool() const { - auto json_value_ptr = get_json_value_ptr(AccessMode::Read); + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Read); if (!json_value_ptr->is_boolean()) { OPENVINO_THROW("JsonContainer expected boolean at path '", m_path, "' but found ", json_value_ptr->type_name(), " with value: ", json_value_ptr->dump()); @@ -187,7 +238,7 @@ bool JsonContainer::get_bool() const { } int64_t JsonContainer::get_int() const { - auto json_value_ptr = get_json_value_ptr(AccessMode::Read); + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Read); if (!json_value_ptr->is_number_integer()) { OPENVINO_THROW("JsonContainer expected integer number at path '", m_path, "' but found ", json_value_ptr->type_name(), " with value: ", json_value_ptr->dump()); @@ -196,7 +247,7 @@ int64_t JsonContainer::get_int() const { } double JsonContainer::get_double() const { - auto json_value_ptr = get_json_value_ptr(AccessMode::Read); + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Read); if (!json_value_ptr->is_number_float()) { OPENVINO_THROW("JsonContainer expected floating-point number at path '", m_path, "' but found ", json_value_ptr->type_name(), " with value: ", json_value_ptr->dump()); @@ -205,7 +256,7 @@ double JsonContainer::get_double() const { } std::string JsonContainer::get_string() const { - auto json_value_ptr = get_json_value_ptr(AccessMode::Read); + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Read); if (!json_value_ptr->is_string()) { OPENVINO_THROW("JsonContainer expected string at path '", m_path, "' but found ", 
json_value_ptr->type_name(), " with value: ", json_value_ptr->dump()); @@ -214,28 +265,57 @@ std::string JsonContainer::get_string() const { } JsonContainer& JsonContainer::to_empty_object() { - auto json_value_ptr = get_json_value_ptr(AccessMode::Write); + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Write); *json_value_ptr = nlohmann::ordered_json::object(); return *this; } JsonContainer& JsonContainer::to_empty_array() { - auto json_value_ptr = get_json_value_ptr(AccessMode::Write); + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Write); *json_value_ptr = nlohmann::ordered_json::array(); return *this; } JsonContainer& JsonContainer::push_back(const JsonContainer& item) { - auto json_value_ptr = get_json_value_ptr(AccessMode::Write); + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Write); + if (!json_value_ptr->is_array()) { + *json_value_ptr = nlohmann::ordered_json::array(); + } + json_value_ptr->push_back(*item.m_impl->get_json_value_ptr(item.m_path, AccessMode::Read)); + return *this; +} + +#define JSON_CONTAINER_PRIMITIVE_PUSH_BACK(type) \ + JsonContainer& JsonContainer::push_back(type value) { \ + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Write); \ + if (!json_value_ptr->is_array()) { \ + *json_value_ptr = nlohmann::ordered_json::array(); \ + } \ + json_value_ptr->push_back(nlohmann::ordered_json(value)); \ + return *this; \ + } + +JSON_CONTAINER_PRIMITIVE_PUSH_BACK(bool) +JSON_CONTAINER_PRIMITIVE_PUSH_BACK(int) +JSON_CONTAINER_PRIMITIVE_PUSH_BACK(int64_t) +JSON_CONTAINER_PRIMITIVE_PUSH_BACK(double) +JSON_CONTAINER_PRIMITIVE_PUSH_BACK(float) +JSON_CONTAINER_PRIMITIVE_PUSH_BACK(const std::string&) +JSON_CONTAINER_PRIMITIVE_PUSH_BACK(const char*) + +#undef JSON_CONTAINER_PRIMITIVE_PUSH_BACK + +JsonContainer& JsonContainer::push_back(std::nullptr_t) { + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Write); if (!json_value_ptr->is_array()) { 
*json_value_ptr = nlohmann::ordered_json::array(); } - json_value_ptr->push_back(item.to_json()); + json_value_ptr->push_back(nlohmann::ordered_json(nullptr)); return *this; } bool JsonContainer::contains(const std::string& key) const { - auto json_value_ptr = get_json_value_ptr(AccessMode::Read); + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Read); if (!json_value_ptr->is_object()) { return false; } @@ -243,17 +323,17 @@ bool JsonContainer::contains(const std::string& key) const { } size_t JsonContainer::size() const { - auto json_value_ptr = get_json_value_ptr(AccessMode::Read); + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Read); return json_value_ptr->size(); } bool JsonContainer::empty() const { - auto json_value_ptr = get_json_value_ptr(AccessMode::Read); + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Read); return json_value_ptr->empty(); } void JsonContainer::erase(const std::string& key) const { - auto json_value_ptr = get_json_value_ptr(AccessMode::Read); + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Read); if (!json_value_ptr->is_object()) { OPENVINO_THROW("JsonContainer erase by key is only supported for objects, but found ", json_value_ptr->type_name(), " at path '", m_path, "'"); @@ -265,7 +345,7 @@ void JsonContainer::erase(const std::string& key) const { } void JsonContainer::erase(size_t index) const { - auto json_value_ptr = get_json_value_ptr(AccessMode::Read); + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Read); if (!json_value_ptr->is_array()) { OPENVINO_THROW("JsonContainer erase by index is only supported for arrays, but found ", json_value_ptr->type_name(), " at path '", m_path, "'"); @@ -278,7 +358,7 @@ void JsonContainer::erase(size_t index) const { } void JsonContainer::clear() const { - auto json_value_ptr = get_json_value_ptr(AccessMode::Read); + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, 
AccessMode::Read); if (!json_value_ptr->is_structured()) { OPENVINO_THROW("JsonContainer clear is only supported for objects and arrays, but found ", json_value_ptr->type_name(), " at path '", m_path, "'"); @@ -287,7 +367,7 @@ void JsonContainer::clear() const { } std::string JsonContainer::type_name() const { - auto json_value_ptr = get_json_value_ptr(AccessMode::Read); + auto json_value_ptr = m_impl->get_json_value_ptr(m_path, AccessMode::Read); if (json_value_ptr->is_null()) { return "null"; } else if (json_value_ptr->is_boolean()) { @@ -305,5 +385,9 @@ std::string JsonContainer::type_name() const { } } +void* JsonContainer::_get_json_value_ptr() const { + return m_impl->get_json_value_ptr(m_path, AccessMode::Read); +} + } // namespace genai } // namespace ov diff --git a/src/cpp/src/json_utils.hpp b/src/cpp/src/json_utils.hpp index ce5b7611db..e4bdd28b85 100644 --- a/src/cpp/src/json_utils.hpp +++ b/src/cpp/src/json_utils.hpp @@ -15,6 +15,22 @@ #include "openvino/genai/json_container.hpp" +namespace nlohmann { + +template<> +struct adl_serializer { + static void to_json(ordered_json& json, const ov::genai::JsonContainer& container) { + auto json_value_ptr = static_cast(container._get_json_value_ptr()); + json = *json_value_ptr; + } + + static ov::genai::JsonContainer from_json(const ordered_json& json) { + return ov::genai::JsonContainer::from_json_string(json.dump()); + } +}; + +} // namespace nlohmann + namespace ov { namespace genai { namespace utils { @@ -109,7 +125,7 @@ inline nlohmann::ordered_json any_to_json(const ov::Any& value) { } return array_json; } else if (value.is()) { - return value.as().to_json(); + return value.as(); } else { OPENVINO_THROW("Failed to convert Any to JSON, unsupported type: ", value.type_info().name()); } diff --git a/src/cpp/src/llm/pipeline_stateful.cpp b/src/cpp/src/llm/pipeline_stateful.cpp index f3bb6ff8bf..0ee8fd6e18 100644 --- a/src/cpp/src/llm/pipeline_stateful.cpp +++ b/src/cpp/src/llm/pipeline_stateful.cpp @@ -220,7 
+220,7 @@ EncodedResults StatefulLLMPipeline::generate( if (is_chat_conversation) // if chat was run in StringInputs mode, but it was called EncodedInputs generate, last m_history entry will be with assistant role - OPENVINO_ASSERT(m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS || m_history.back()["role"] == "user", + OPENVINO_ASSERT(m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS || m_history.last()["role"] == "user", "Chat doesn't support switching between input types. Please, continue using StringInputs or restart the chat."); if (!is_chat_conversation) { diff --git a/src/cpp/src/tokenizer/tokenizer.cpp b/src/cpp/src/tokenizer/tokenizer.cpp index bbcd415d94..0c46d7a21b 100644 --- a/src/cpp/src/tokenizer/tokenizer.cpp +++ b/src/cpp/src/tokenizer/tokenizer.cpp @@ -131,11 +131,11 @@ std::string Tokenizer::get_eos_token() const { return m_pimpl->m_eos_token; } -std::string Tokenizer::apply_chat_template(ChatHistory history, +std::string Tokenizer::apply_chat_template(const ChatHistory& history, bool add_generation_prompt, const std::string& chat_template, - const ToolDefinitions& tools, - const ov::AnyMap& extra_context) const { + const std::optional& tools, + const std::optional& extra_context) const { return m_pimpl->apply_chat_template(history, add_generation_prompt, chat_template, tools, extra_context); } diff --git a/src/cpp/src/tokenizer/tokenizer_impl.cpp b/src/cpp/src/tokenizer/tokenizer_impl.cpp index 12fd550011..bbc69b71f6 100644 --- a/src/cpp/src/tokenizer/tokenizer_impl.cpp +++ b/src/cpp/src/tokenizer/tokenizer_impl.cpp @@ -727,46 +727,40 @@ std::vector Tokenizer::TokenizerImpl::decode(const std::vector& tools, + const std::optional& extra_context ) const { std::string chat_tpl = chat_template.empty() ? m_chat_template : remap_template(chat_template); OPENVINO_ASSERT(!chat_tpl.empty(), "Chat template wasn't found. This may indicate that the model wasn't trained for chat scenario." 
" Please add 'chat_template' to tokenizer_config.json to use the model in chat scenario." " For more information see the section Troubleshooting in README.md"); - - nlohmann::ordered_json messages_json = nlohmann::ordered_json::array(); - for (const auto& message : history) { - nlohmann::ordered_json message_json = ov::genai::utils::any_map_to_json(message); - messages_json.push_back(message_json); - } - nlohmann::ordered_json tools_json = nlohmann::ordered_json::array(); - for (const auto& tool : tools) { - nlohmann::ordered_json tool_json = ov::genai::utils::any_map_to_json(tool); - tools_json.push_back(tool_json); - } + auto resolved_tools = tools.value_or(history.get_tools()); + auto resolved_extra_context = extra_context.value_or(history.get_extra_context()); + + OPENVINO_ASSERT(resolved_tools.is_array(), + "Tools should be an array-like JsonContainer, got: ", resolved_tools.type_name()); + OPENVINO_ASSERT(resolved_extra_context.is_object(), + "Extra context should be an object-like JsonContainer, got: ", resolved_extra_context.type_name()); minja::chat_template minja_template(chat_tpl, m_bos_token, m_eos_token); minja::chat_template_inputs minja_inputs; - minja_inputs.messages = messages_json; - if (!tools_json.empty()) { - minja_inputs.tools = tools_json; + minja_inputs.messages = history.get_messages(); + if (!resolved_tools.empty()) { + minja_inputs.tools = resolved_tools; } minja_inputs.add_generation_prompt = add_generation_prompt; minja_inputs.extra_context = nlohmann::ordered_json::object(); minja_inputs.extra_context["bos_token"] = m_bos_token; minja_inputs.extra_context["eos_token"] = m_eos_token; minja_inputs.extra_context["pad_token"] = m_pad_token; - - if (!extra_context.empty()) { - auto extra_context_json = ov::genai::utils::any_map_to_json(extra_context); - minja_inputs.extra_context.update(extra_context_json); + if (!resolved_extra_context.empty()) { + minja_inputs.extra_context.update(resolved_extra_context); } std::string result; diff --git 
a/src/cpp/src/tokenizer/tokenizer_impl.hpp b/src/cpp/src/tokenizer/tokenizer_impl.hpp index c92771ec9f..773a5fd447 100644 --- a/src/cpp/src/tokenizer/tokenizer_impl.hpp +++ b/src/cpp/src/tokenizer/tokenizer_impl.hpp @@ -72,11 +72,11 @@ class Tokenizer::TokenizerImpl { std::vector decode(const ov::Tensor& tokens, const ov::AnyMap& detokenization_params = {}); std::vector decode(const std::vector>& lines, const ov::AnyMap& detokenization_params = {}); - std::string apply_chat_template(ChatHistory history, + std::string apply_chat_template(const ChatHistory& history, bool add_generation_prompt, const std::string& chat_template, - const ToolDefinitions& tools, - const ov::AnyMap& extra_context) const; + const std::optional& tools, + const std::optional& extra_context) const; void set_chat_template(const std::string& chat_template); std::string get_chat_template(); diff --git a/src/cpp/src/visual_language/pipeline.cpp b/src/cpp/src/visual_language/pipeline.cpp index 8c78351425..b7239fcaaf 100644 --- a/src/cpp/src/visual_language/pipeline.cpp +++ b/src/cpp/src/visual_language/pipeline.cpp @@ -341,7 +341,8 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{ if (system_message.empty()) { return; } - m_history = {{{"role", "system"}, {"content", system_message}}}; + m_history.clear(); + m_history.push_back({{"role", "system"}, {"content", system_message}}); } void finish_chat() override { diff --git a/src/js/include/helper.hpp b/src/js/include/helper.hpp index a29ee3311c..e8863dc4db 100644 --- a/src/js/include/helper.hpp +++ b/src/js/include/helper.hpp @@ -100,3 +100,5 @@ Napi::Object cpp_map_to_js_object(const Napi::Env& env, const std::map[], - // addGenerationPrompt: boolean, - // chatTemplate?: string, - // tools?: Record[], - // extraContext?: Record, - // ): string; + // TODO Consider adding bindings for ChatHistory and JsonContainer classes applyChatTemplate( - chatHistory: { role: string; content: string }[], + chatHistory: Record[], addGenerationPrompt: 
boolean, chatTemplate?: string, + tools?: Record[], + extraContext?: Record, ): string; getBosToken(): string; getBosTokenId(): number; diff --git a/src/js/src/helper.cpp b/src/js/src/helper.cpp index 7c6f325929..8e321f5786 100644 --- a/src/js/src/helper.cpp +++ b/src/js/src/helper.cpp @@ -131,41 +131,24 @@ ov::genai::StringInputs js_to_cpp(const Napi::Env& env, template <> ov::genai::ChatHistory js_to_cpp(const Napi::Env& env, const Napi::Value& value) { - // TODO Update for new ChatHistory type: Record[] - auto incorrect_argument_message = "Chat history must be { role: string, content: string }[]"; - if (value.IsArray()) { - auto array = value.As(); - size_t arrayLength = array.Length(); - - std::vector nativeArray; - for (uint32_t i = 0; i < arrayLength; ++i) { - Napi::Value arrayItem = array[i]; - if (!arrayItem.IsObject()) { - OPENVINO_THROW(incorrect_argument_message); - } - auto obj = arrayItem.As(); - if (obj.Get("role").IsUndefined() || obj.Get("content").IsUndefined()) { - OPENVINO_THROW(incorrect_argument_message); - } - ov::AnyMap result; - Napi::Array keys = obj.GetPropertyNames(); - - for (uint32_t i = 0; i < keys.Length(); ++i) { - Napi::Value key = keys[i]; - Napi::Value value = obj.Get(key); + auto incorrect_argument_message = "Chat history must be an array of JS objects"; + if (!value.IsArray()) { + OPENVINO_THROW(incorrect_argument_message); + } - std::string keyStr = key.ToString().Utf8Value(); - std::string valueStr = value.ToString().Utf8Value(); + auto array = value.As(); + size_t arrayLength = array.Length(); - result[keyStr] = valueStr; - } - nativeArray.push_back(result); + for (uint32_t i = 0; i < arrayLength; ++i) { + Napi::Value arrayItem = array[i]; + if (!arrayItem.IsObject()) { + OPENVINO_THROW(incorrect_argument_message); } - return nativeArray; - - } else { - OPENVINO_THROW(incorrect_argument_message); } + + // TODO Consider using direct native JsonContainer conversion instead of string serialization + auto messages = 
ov::genai::JsonContainer::from_json_string(json_stringify(env, value)); + return ov::genai::ChatHistory(messages); } template <> @@ -274,3 +257,14 @@ Napi::Value cpp_to_js, Napi::Value>(const Napi::Env& env, co bool is_napi_value_int(const Napi::Env& env, const Napi::Value& num) { return env.Global().Get("Number").ToObject().Get("isInteger").As().Call({num}).ToBoolean().Value(); } + +std::string json_stringify(const Napi::Env& env, const Napi::Value& value) { + return env.Global() + .Get("JSON") + .ToObject() + .Get("stringify") + .As() + .Call({ value }) + .ToString() + .Utf8Value(); +} diff --git a/src/js/src/tokenizer.cpp b/src/js/src/tokenizer.cpp index b888312cb9..aab586094d 100644 --- a/src/js/src/tokenizer.cpp +++ b/src/js/src/tokenizer.cpp @@ -36,10 +36,18 @@ Napi::Value TokenizerWrapper::apply_chat_template(const Napi::CallbackInfo& info OPENVINO_ASSERT(!info[1].IsUndefined() && info[1].IsBoolean(), "The argument 'addGenerationPrompt' must be a boolean"); bool add_generation_prompt = info[1].ToBoolean(); std::string chat_template = ""; - if (info.Length() == 3 && !info[2].IsUndefined()) { + if (!info[2].IsUndefined()) { chat_template = info[2].ToString().Utf8Value(); } - auto result = this->_tokenizer.apply_chat_template(history, add_generation_prompt, chat_template); + std::optional tools; + if (!info[3].IsUndefined()) { + tools = ov::genai::JsonContainer::from_json_string(json_stringify(info.Env(), info[3])); + } + std::optional extra_context; + if (!info[4].IsUndefined()) { + extra_context = ov::genai::JsonContainer::from_json_string(json_stringify(info.Env(), info[4])); + } + auto result = this->_tokenizer.apply_chat_template(history, add_generation_prompt, chat_template, tools, extra_context); return Napi::String::New(info.Env(), result); } catch (std::exception& err) { Napi::Error::New(info.Env(), err.what()).ThrowAsJavaScriptException(); diff --git a/src/js/tests/tokenizer.test.js b/src/js/tests/tokenizer.test.js index 3f3ead6db8..31be25b160 100644 
--- a/src/js/tests/tokenizer.test.js +++ b/src/js/tests/tokenizer.test.js @@ -34,21 +34,6 @@ describe("tokenizer", async () => { assert.strictEqual(typeof template, "string"); }); - it("applyChatTemplate with unknown property", async () => { - const testValue = "1234567890"; - const template = tokenizer.applyChatTemplate( - [ - { - role: "user", - content: "continue: 1 2 3", - unknownProp: testValue, - }, - ], - false, - ); - assert.ok(!template.includes(testValue)); - }); - it("applyChatTemplate with true addGenerationPrompt", async () => { const template = tokenizer.applyChatTemplate( [ @@ -62,32 +47,6 @@ describe("tokenizer", async () => { assert.ok(template.includes("assistant")); }); - it("applyChatTemplate with missed role", async () => { - assert.throws(() => - tokenizer.applyChatTemplate( - [ - { - content: "continue: 1 2 3", - }, - ], - false, - ), - ); - }); - - it("applyChatTemplate with missed content", async () => { - assert.throws(() => - tokenizer.applyChatTemplate( - [ - { - role: "user", - }, - ], - false, - ), - ); - }); - it("applyChatTemplate with missed addGenerationPrompt", async () => { assert.throws(() => tokenizer.applyChatTemplate([ @@ -99,8 +58,12 @@ describe("tokenizer", async () => { ); }); - it("applyChatTemplate with incorrect type of history", async () => { + it("applyChatTemplate with incorrect type of history", async () => { assert.throws(() => tokenizer.applyChatTemplate("prompt", false)); + assert.throws(() => tokenizer.applyChatTemplate(["prompt"], false)); + assert.throws(() => + tokenizer.applyChatTemplate([{ role: "user", content: "prompt" }, "not an object"], false), + ); }); it("applyChatTemplate with unknown property", async () => { @@ -136,6 +99,49 @@ describe("tokenizer", async () => { assert.strictEqual(template, `${prompt}\n`); }); + it("applyChatTemplate use tools", async () => { + const prompt = "question"; + const chatHistory = [ + { + role: "user", + content: prompt, + }, + ]; + const chatTemplate = `{% for message 
in messages %} +{{ message['content'] }} +{% for tool in tools %}{{ tool | tojson }}{% endfor %} +{% endfor %}`; + const tools = [{ type: "function", function: { name: "test" } }]; + const templatedHistory = tokenizer.applyChatTemplate(chatHistory, false, chatTemplate, tools); + const expected = `${prompt}\n{"type": "function", "function": {"name": "test"}}`; + assert.strictEqual(templatedHistory, expected); + }); + + it("applyChatTemplate use extra_context", async () => { + const prompt = "question"; + const chatHistory = [ + { + role: "user", + content: prompt, + }, + ]; + const chatTemplate = `{% for message in messages %} +{{ message['content'] }} +{% if enable_thinking is defined and enable_thinking is false %}No thinking{% endif %} +{% endfor %}`; + const tools = []; + const extraContext = { enable_thinking: false }; // eslint-disable-line camelcase + const templatedHistory = tokenizer.applyChatTemplate( + chatHistory, + false, + chatTemplate, + tools, + extraContext, + ); + const expected = `${prompt}\nNo thinking`; + assert.strictEqual(templatedHistory, expected); + }); + it("getBosToken return string", async () => { const token = tokenizer.getBosToken(); assert.strictEqual(typeof token, "string"); diff --git a/src/python/openvino_genai/__init__.py b/src/python/openvino_genai/__init__.py index 218f782e29..4494c33621 100644 --- a/src/python/openvino_genai/__init__.py +++ b/src/python/openvino_genai/__init__.py @@ -50,6 +50,11 @@ StopCriteria ) +# Chat history +from .py_openvino_genai import ( + ChatHistory +) + # Tokenizers from .py_openvino_genai import ( TokenizedInputs, diff --git a/src/python/openvino_genai/__init__.pyi b/src/python/openvino_genai/__init__.pyi index 175df870eb..1517ff8524 100644 --- a/src/python/openvino_genai/__init__.pyi +++ b/src/python/openvino_genai/__init__.pyi @@ -10,6 +10,7 @@ from openvino_genai.py_openvino_genai import AutoencoderKL from openvino_genai.py_openvino_genai import CLIPTextModel from openvino_genai.py_openvino_genai 
import CLIPTextModelWithProjection from openvino_genai.py_openvino_genai import CacheEvictionConfig +from openvino_genai.py_openvino_genai import ChatHistory from openvino_genai.py_openvino_genai import ChunkStreamerBase from openvino_genai.py_openvino_genai import ContinuousBatchingPipeline from openvino_genai.py_openvino_genai import CppStdGenerator @@ -64,5 +65,5 @@ from openvino_genai.py_openvino_genai import draft_model from openvino_genai.py_openvino_genai import get_version import os as os from . import py_openvino_genai -__all__: list[str] = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'ImageGenerationPerfMetrics', 'InpaintingPipeline', 'KVCrushAnchorPointMode', 'KVCrushConfig', 'LLMPipeline', 'PerfMetrics', 'RawImageGenerationPerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'SparseAttentionConfig', 'SparseAttentionMode', 'SpeechGenerationConfig', 'SpeechGenerationPerfMetrics', 'StopCriteria', 'StreamerBase', 'StreamingStatus', 'StructuralTagItem', 'StructuralTagsConfig', 'StructuredOutputConfig', 'T5EncoderModel', 'Text2ImagePipeline', 'Text2SpeechDecodedResults', 'Text2SpeechPipeline', 'TextEmbeddingPipeline', 'TextRerankPipeline', 'TextStreamer', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMPipeline', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version', 'openvino', 'os', 'py_openvino_genai'] +__all__: list[str] = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 
'ChatHistory', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'ImageGenerationPerfMetrics', 'InpaintingPipeline', 'KVCrushAnchorPointMode', 'KVCrushConfig', 'LLMPipeline', 'PerfMetrics', 'RawImageGenerationPerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'SparseAttentionConfig', 'SparseAttentionMode', 'SpeechGenerationConfig', 'SpeechGenerationPerfMetrics', 'StopCriteria', 'StreamerBase', 'StreamingStatus', 'StructuralTagItem', 'StructuralTagsConfig', 'StructuredOutputConfig', 'T5EncoderModel', 'Text2ImagePipeline', 'Text2SpeechDecodedResults', 'Text2SpeechPipeline', 'TextEmbeddingPipeline', 'TextRerankPipeline', 'TextStreamer', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMPipeline', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version', 'openvino', 'os', 'py_openvino_genai'] __version__: str diff --git a/src/python/openvino_genai/py_openvino_genai.pyi b/src/python/openvino_genai/py_openvino_genai.pyi index b300311721..7ccf59e683 100644 --- a/src/python/openvino_genai/py_openvino_genai.pyi +++ b/src/python/openvino_genai/py_openvino_genai.pyi @@ -5,7 +5,7 @@ from __future__ import annotations import collections.abc import openvino._pyopenvino import typing -__all__: list[str] = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedGenerationResult', 'EncodedResults', 'ExtendedPerfMetrics', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationHandle', 'GenerationOutput', 'GenerationResult', 
'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'ImageGenerationPerfMetrics', 'InpaintingPipeline', 'KVCrushAnchorPointMode', 'KVCrushConfig', 'LLMPipeline', 'MeanStdPair', 'PerfMetrics', 'PipelineMetrics', 'RawImageGenerationPerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'SDPerModelsPerfMetrics', 'SDPerfMetrics', 'Scheduler', 'SchedulerConfig', 'SparseAttentionConfig', 'SparseAttentionMode', 'SpeechGenerationConfig', 'SpeechGenerationPerfMetrics', 'StopCriteria', 'StreamerBase', 'StreamingStatus', 'StructuralTagItem', 'StructuralTagsConfig', 'StructuredOutputConfig', 'SummaryStats', 'T5EncoderModel', 'Text2ImagePipeline', 'Text2SpeechDecodedResults', 'Text2SpeechPipeline', 'TextEmbeddingPipeline', 'TextRerankPipeline', 'TextStreamer', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMDecodedResults', 'VLMPerfMetrics', 'VLMPipeline', 'VLMRawPerfMetrics', 'WhisperDecodedResultChunk', 'WhisperDecodedResults', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version'] +__all__: list[str] = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChatHistory', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedGenerationResult', 'EncodedResults', 'ExtendedPerfMetrics', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationHandle', 'GenerationOutput', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'ImageGenerationPerfMetrics', 'InpaintingPipeline', 'KVCrushAnchorPointMode', 'KVCrushConfig', 'LLMPipeline', 'MeanStdPair', 'PerfMetrics', 'PipelineMetrics', 'RawImageGenerationPerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'SDPerModelsPerfMetrics', 'SDPerfMetrics', 'Scheduler', 'SchedulerConfig', 'SparseAttentionConfig', 
'SparseAttentionMode', 'SpeechGenerationConfig', 'SpeechGenerationPerfMetrics', 'StopCriteria', 'StreamerBase', 'StreamingStatus', 'StructuralTagItem', 'StructuralTagsConfig', 'StructuredOutputConfig', 'SummaryStats', 'T5EncoderModel', 'Text2ImagePipeline', 'Text2SpeechDecodedResults', 'Text2SpeechPipeline', 'TextEmbeddingPipeline', 'TextRerankPipeline', 'TextStreamer', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMDecodedResults', 'VLMPerfMetrics', 'VLMPipeline', 'VLMRawPerfMetrics', 'WhisperDecodedResultChunk', 'WhisperDecodedResults', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version'] class Adapter: """ Immutable LoRA Adapter that carries the adaptation matrices and serves as unique adapter identifier. @@ -390,6 +390,79 @@ class CacheEvictionConfig: @snapkv_window_size.setter def snapkv_window_size(self, arg0: typing.SupportsInt) -> None: ... +class ChatHistory: + """ + + ChatHistory stores conversation messages and optional metadata for chat templates. + + Manages: + - Message history (array of message objects) + - Optional tools definitions array (for function calling) + - Optional extra context object (for custom template variables) + + Messages are stored as JSON-like structures but accessed as Python dicts. + Use get_messages() to retrieve the list of all messages, modify them, + and set_messages() to update the history. + + Example: + ```python + history = ChatHistory() + history.append({"role": "user", "content": "Hello"}) + + # Modify messages + messages = history.get_messages() + messages[0]["content"] = "Updated" + history.set_messages(messages) + ``` + """ + def __bool__(self) -> bool: + ... + @typing.overload + def __init__(self) -> None: + """ + Create an empty chat history. + """ + @typing.overload + def __init__(self, messages: list) -> None: + """ + Create chat history from a list of message dicts. + """ + def __len__(self) -> int: + ... 
+ def append(self, message: dict) -> None: + """ + Add a message to the end of chat history. + """ + def clear(self) -> None: + ... + def get_extra_context(self) -> dict: + """ + Get the extra context object. + """ + def get_messages(self) -> list: + """ + Get all messages as a list of dicts (deep copy). + """ + def get_tools(self) -> list: + """ + Get the tools definitions array. + """ + def pop(self) -> dict: + """ + Remove and return the last message. + """ + def set_extra_context(self, extra_context: dict) -> None: + """ + Set the extra context object. + """ + def set_messages(self, messages: list) -> None: + """ + Replace all messages with a new list. + """ + def set_tools(self, tools: list) -> None: + """ + Set the tools definitions array. + """ class ChunkStreamerBase(StreamerBase): """ @@ -3226,7 +3299,7 @@ class Tokenizer: @typing.overload def __init__(self, tokenizer_model: str, tokenizer_weights: openvino._pyopenvino.Tensor, detokenizer_model: str, detokenizer_weights: openvino._pyopenvino.Tensor, **kwargs) -> None: ... - def apply_chat_template(self, history: collections.abc.Sequence[typing.Any], add_generation_prompt: bool, chat_template: str = '', tools: collections.abc.Sequence[typing.Any] = [], extra_context: typing.Any = {}) -> str: + def apply_chat_template(self, history: openvino_genai.py_openvino_genai.ChatHistory | collections.abc.Sequence[dict], add_generation_prompt: bool, chat_template: str = '', tools: collections.abc.Sequence[dict] | None = None, extra_context: dict | None = None) -> str: """ Applies a chat template to format chat history into a prompt string. 
""" diff --git a/src/python/py_chat_history.cpp b/src/python/py_chat_history.cpp new file mode 100644 index 0000000000..8e466fa8b2 --- /dev/null +++ b/src/python/py_chat_history.cpp @@ -0,0 +1,98 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include + +#include "py_utils.hpp" + +namespace { + +constexpr char class_docstring[] = R"( + ChatHistory stores conversation messages and optional metadata for chat templates. + + Manages: + - Message history (array of message objects) + - Optional tools definitions array (for function calling) + - Optional extra context object (for custom template variables) + + Messages are stored as JSON-like structures but accessed as Python dicts. + Use get_messages() to retrieve the list of all messages, modify them, + and set_messages() to update the history. + + Example: + ```python + history = ChatHistory() + history.append({"role": "user", "content": "Hello"}) + + # Modify messages + messages = history.get_messages() + messages[0]["content"] = "Updated" + history.set_messages(messages) + ``` +)"; + +} // namespace + +namespace py = pybind11; +namespace pyutils = ov::genai::pybind::utils; + +using ov::genai::ChatHistory; +using ov::genai::JsonContainer; + +void init_chat_history(py::module_& m) { + py::class_(m, "ChatHistory", class_docstring) + .def(py::init<>(), "Create an empty chat history.") + + .def(py::init([](const py::list& messages) { + JsonContainer history = pyutils::py_object_to_json_container(messages); + return ChatHistory(history); + }), py::arg("messages"), R"(Create chat history from a list of message dicts.)") + + .def("get_messages", [](const ChatHistory& self) -> py::list { + return pyutils::json_container_to_py_object(self.get_messages()); + }, R"(Get all messages as a list of dicts (deep copy).)") + + .def("set_messages", [](ChatHistory& self, const py::list& messages) { + self.get_messages() = pyutils::py_object_to_json_container(messages); + }, 
py::arg("messages"), R"(Replace all messages with a new list.)") + + .def("append", [](ChatHistory& self, const py::dict& message) { + JsonContainer message_jc = pyutils::py_object_to_json_container(message); + self.push_back(message_jc); + }, py::arg("message"), R"(Add a message to the end of chat history.)") + + .def("pop", [](ChatHistory& self) -> py::dict { + if (self.empty()) { + throw py::index_error("Cannot pop from an empty chat history"); + } + JsonContainer last = self.last().copy(); + self.pop_back(); + return pyutils::json_container_to_py_object(last); + }, R"(Remove and return the last message.)") + + .def("clear", &ChatHistory::clear) + + .def("__len__", &ChatHistory::size) + + .def("__bool__", [](const ChatHistory& self) { + return !self.empty(); + }) + + .def("set_tools", [](ChatHistory& self, const py::list& tools) { + self.set_tools(pyutils::py_object_to_json_container(tools)); + }, py::arg("tools"), R"(Set the tools definitions array.)") + + .def("get_tools", [](const ChatHistory& self) -> py::list { + return pyutils::json_container_to_py_object(self.get_tools()); + }, R"(Get the tools definitions array.)") + + .def("set_extra_context", [](ChatHistory& self, const py::dict& extra_context) { + self.set_extra_context(pyutils::py_object_to_json_container(extra_context)); + }, py::arg("extra_context"), R"(Set the extra context object.)") + + .def("get_extra_context", [](const ChatHistory& self) -> py::dict { + return pyutils::json_container_to_py_object(self.get_extra_context()); + }, R"(Get the extra context object.)"); +} diff --git a/src/python/py_openvino_genai.cpp b/src/python/py_openvino_genai.cpp index 8cec4de360..84ce3ec207 100644 --- a/src/python/py_openvino_genai.cpp +++ b/src/python/py_openvino_genai.cpp @@ -31,6 +31,7 @@ using ov::genai::get_version; void init_lora_adapter(py::module_& m); void init_perf_metrics(py::module_& m); +void init_chat_history(py::module_& m); void init_tokenizer(py::module_& m); void init_streamers(py::module_& 
m); void init_generation_config(py::module_& m); @@ -115,6 +116,7 @@ PYBIND11_MODULE(py_openvino_genai, m) { init_lora_adapter(m); init_generation_config(m); + init_chat_history(m); init_tokenizer(m); init_streamers(m); diff --git a/src/python/py_tokenizer.cpp b/src/python/py_tokenizer.cpp index 09ed834cf6..87383973d2 100644 --- a/src/python/py_tokenizer.cpp +++ b/src/python/py_tokenizer.cpp @@ -74,7 +74,7 @@ namespace py = pybind11; namespace pyutils = ov::genai::pybind::utils; using ov::genai::ChatHistory; -using ov::genai::ToolDefinitions; +using ov::genai::JsonContainer; using ov::genai::TokenizedInputs; using ov::genai::Tokenizer; @@ -251,29 +251,38 @@ void init_tokenizer(py::module_& m) { R"(Decode a batch of tokens into a list of string prompt.)") .def("apply_chat_template", [](Tokenizer& tok, - const std::vector& history, + const std::variant>& history, bool add_generation_prompt, const std::string& chat_template, - const std::vector& tools, - const py::object& extra_context) { - auto history_anymap = ChatHistory{}; - for (const auto& message : history) { - ov::AnyMap message_anymap = pyutils::py_object_to_any_map(message); - history_anymap.push_back(message_anymap); + const std::optional>& tools, + const std::optional& extra_context) { + ChatHistory chat_history; + std::visit(pyutils::overloaded { + [&](ChatHistory chat_history_obj) { + chat_history = chat_history_obj; + }, + [&](const std::vector& list_of_dicts) { + chat_history = ChatHistory(pyutils::py_object_to_json_container(py::cast(list_of_dicts))); + } + }, history); + + std::optional tools_jc; + if (tools.has_value()) { + tools_jc = pyutils::py_object_to_json_container(py::cast(tools.value())); } - auto tools_anymap = ToolDefinitions{}; - for (const auto& tool : tools) { - ov::AnyMap tool_anymap = pyutils::py_object_to_any_map(tool); - tools_anymap.push_back(tool_anymap); + + std::optional extra_context_jc; + if (extra_context.has_value()) { + extra_context_jc = 
pyutils::py_object_to_json_container(extra_context.value()); } - ov::AnyMap extra_context_anymap = pyutils::py_object_to_any_map(extra_context); - return tok.apply_chat_template(history_anymap, add_generation_prompt, chat_template, tools_anymap, extra_context_anymap); + + return tok.apply_chat_template(chat_history, add_generation_prompt, chat_template, tools_jc, extra_context_jc); }, py::arg("history"), py::arg("add_generation_prompt"), py::arg("chat_template") = "", - py::arg("tools") = std::vector(), - py::arg("extra_context") = ov::AnyMap({}), + py::arg("tools") = py::none(), + py::arg("extra_context") = py::none(), R"(Applies a chat template to format chat history into a prompt string.)") .def( diff --git a/src/python/py_utils.cpp b/src/python/py_utils.cpp index 531a3c41f4..f6e42bc7cc 100644 --- a/src/python/py_utils.cpp +++ b/src/python/py_utils.cpp @@ -449,4 +449,21 @@ ov::genai::GenerationConfig update_config_from_kwargs(ov::genai::GenerationConfi return config; } +ov::genai::JsonContainer py_object_to_json_container(const py::object& obj) { + if (obj.is_none()) { + return JsonContainer(); + } + // TODO Consider using direct native JsonContainer conversion instead of string serialization + py::module_ json_module = py::module_::import("json"); + std::string json_string = py::cast(json_module.attr("dumps")(obj)); + return JsonContainer::from_json_string(json_string); +} + +py::object json_container_to_py_object(const ov::genai::JsonContainer& container) { + // TODO Consider using direct native JsonContainer conversion instead of string serialization + std::string json_string = container.to_json_string(); + py::module_ json_module = py::module_::import("json"); + return json_module.attr("loads")(json_string); +} + } // namespace ov::genai::pybind::utils diff --git a/src/python/py_utils.hpp b/src/python/py_utils.hpp index 1ee8c8f2d4..8a059e989f 100644 --- a/src/python/py_utils.hpp +++ b/src/python/py_utils.hpp @@ -9,6 +9,7 @@ #include 
"openvino/genai/streamer_base.hpp" #include "openvino/genai/llm_pipeline.hpp" +#include "openvino/genai/json_container.hpp" namespace py = pybind11; using ov::genai::StreamerBase; @@ -47,4 +48,8 @@ ov::genai::StreamerVariant pystreamer_to_streamer(const PyBindStreamerVariant& p ov::AnyMap py_object_to_any_map(const py::object& py_obj); +ov::genai::JsonContainer py_object_to_json_container(const py::object& obj); + +py::object json_container_to_py_object(const ov::genai::JsonContainer& container); + } // namespace ov::genai::pybind::utils diff --git a/tests/cpp/test_json_container.cpp b/tests/cpp/test_json_container.cpp index d5741fa507..c989120fc1 100644 --- a/tests/cpp/test_json_container.cpp +++ b/tests/cpp/test_json_container.cpp @@ -4,6 +4,7 @@ #include #include "openvino/genai/json_container.hpp" +#include "json_utils.hpp" using namespace ov::genai; @@ -183,7 +184,6 @@ TEST(JsonContainerTest, primitive_values) { EXPECT_TRUE(bool_json.as_bool().has_value()); EXPECT_EQ(bool_json.as_bool().value(), BOOL_VALUE); EXPECT_EQ(bool_json.get_bool(), BOOL_VALUE); - EXPECT_EQ(bool_json.to_json(), BOOL_VALUE); EXPECT_TRUE(int_json.is_number()); EXPECT_TRUE(int_json.is_number_integer()); @@ -192,7 +192,6 @@ TEST(JsonContainerTest, primitive_values) { EXPECT_TRUE(int_json.as_int().has_value()); EXPECT_EQ(int_json.as_int().value(), INT_VALUE); EXPECT_EQ(int_json.get_int(), INT_VALUE); - EXPECT_EQ(int_json.to_json(), INT_VALUE); EXPECT_TRUE(int64_json.is_number()); EXPECT_TRUE(int64_json.is_number_integer()); @@ -201,7 +200,6 @@ TEST(JsonContainerTest, primitive_values) { EXPECT_TRUE(int64_json.as_int().has_value()); EXPECT_EQ(int64_json.as_int().value(), INT64_VALUE); EXPECT_EQ(int64_json.get_int(), INT64_VALUE); - EXPECT_EQ(int64_json.to_json(), INT64_VALUE); EXPECT_TRUE(double_json.is_number()); EXPECT_TRUE(double_json.is_number_float()); @@ -210,7 +208,6 @@ TEST(JsonContainerTest, primitive_values) { EXPECT_TRUE(double_json.as_double().has_value()); 
EXPECT_DOUBLE_EQ(double_json.as_double().value(), DOUBLE_VALUE); EXPECT_DOUBLE_EQ(double_json.get_double(), DOUBLE_VALUE); - EXPECT_DOUBLE_EQ(double_json.to_json(), DOUBLE_VALUE); EXPECT_TRUE(float_json.is_number()); EXPECT_TRUE(float_json.is_number_float()); @@ -219,27 +216,23 @@ TEST(JsonContainerTest, primitive_values) { EXPECT_TRUE(float_json.as_double().has_value()); EXPECT_FLOAT_EQ(static_cast(float_json.as_double().value()), FLOAT_VALUE); EXPECT_FLOAT_EQ(static_cast(float_json.get_double()), FLOAT_VALUE); - EXPECT_FLOAT_EQ(static_cast(float_json.to_json()), FLOAT_VALUE); EXPECT_TRUE(string_json.is_string()); EXPECT_EQ(string_json.type_name(), "string"); EXPECT_TRUE(string_json.as_string().has_value()); EXPECT_EQ(string_json.as_string().value(), TEST_STRING); EXPECT_EQ(string_json.get_string(), TEST_STRING); - EXPECT_EQ(string_json.to_json(), TEST_STRING); EXPECT_TRUE(c_string_json.is_string()); EXPECT_EQ(c_string_json.type_name(), "string"); EXPECT_TRUE(c_string_json.as_string().has_value()); EXPECT_EQ(c_string_json.as_string().value(), C_STRING_VALUE); EXPECT_EQ(c_string_json.get_string(), C_STRING_VALUE); - EXPECT_EQ(c_string_json.to_json(), C_STRING_VALUE); EXPECT_TRUE(null_json.is_null()); EXPECT_EQ(null_json.type_name(), "null"); EXPECT_EQ(null_json.size(), 0); EXPECT_EQ(null_json.empty(), true); - EXPECT_EQ(null_json.to_json(), nullptr); null_json = "not null"; EXPECT_EQ(null_json.get_string(), "not null"); null_json = nullptr; @@ -307,7 +300,7 @@ TEST(JsonContainerTest, array_operations) { EXPECT_THROW(jc[0].erase(0), ov::Exception); // test erase by index for primitives // Test out-of-bounds access - EXPECT_THROW(jc[100].to_json(), ov::Exception); + EXPECT_THROW(jc[100].as_string(), ov::Exception); // Test out-of-bounds assignment expands array with nulls jc.to_empty_array(); @@ -465,3 +458,36 @@ TEST(JsonContainerTest, json_string) { EXPECT_EQ(parsed["array"].size(), 2); EXPECT_EQ(parsed["object"]["nested"].get_string(), "value"); } + 
+TEST(JsonContainerTest, subcontainers_modifications) { + JsonContainer messages = JsonContainer::array(); + messages.push_back({{"role", "system"}, {"content", "assistant"}}); + messages.push_back({{"role", "user"}, {"content", "question"}}); + messages.push_back({{"role", "assistant"}, {"content", "answer"}}); + + auto middle = messages[1]; + middle["content"] = "modified_question"; + EXPECT_EQ(messages[1]["content"].get_string(), "modified_question"); + + messages.erase(1); + messages.erase(1); + EXPECT_THROW(middle["content"].get_string(), ov::Exception); +} + +TEST(JsonContainerTest, json_serialization) { + JsonContainer root = JsonContainer::object(); + root["user"]["name"] = "Alice"; + root["user"]["age"] = 30; + root["system"]["version"] = "1.0"; + + JsonContainer user = root["user"]; + + nlohmann::ordered_json json = user; + + EXPECT_TRUE(json.contains("name")); + EXPECT_TRUE(json.contains("age")); + EXPECT_FALSE(json.contains("system")); + + EXPECT_EQ(json["name"].get<std::string>(), "Alice"); + EXPECT_EQ(json["age"].get<int>(), 30); +} diff --git a/tests/python_tests/test_tokenizer.py b/tests/python_tests/test_tokenizer.py index 866139f8cf..cdc646bf13 100644 --- a/tests/python_tests/test_tokenizer.py +++ b/tests/python_tests/test_tokenizer.py @@ -10,7 +10,7 @@ import pytest from data.models import get_models_list from data.tokenizer_configs import get_tokenizer_configs -from openvino_genai import Tokenizer +from openvino_genai import Tokenizer, ChatHistory from openvino_tokenizers import convert_tokenizer from transformers import AutoTokenizer @@ -48,6 +48,10 @@ def get_chat_templates(): return [(k, v) for k, v in get_tokenizer_configs().items() if k not in skipped_models] +def assert_hf_equals_genai(hf_str: str, genai_str: str): + assert hf_str == genai_str, f"HF reference:\n{hf_str}\nGenAI output:\n{genai_str}" + + prompts = [ "table is made of", "你好! 
你好嗎?", @@ -145,7 +149,7 @@ def test_apply_chat_template(model_tmp_path, chat_config: tuple[str, dict], ov_h ov_tokenizer.set_chat_template(tokenizer_config["chat_template"]) ov_full_history_str = ov_tokenizer.apply_chat_template(conversation, add_generation_prompt=False) - assert ov_full_history_str == hf_full_history_str, f"HF reference:\n{hf_full_history_str}\nGenAI output:\n{ov_full_history_str}" + assert_hf_equals_genai(hf_full_history_str, ov_full_history_str) @pytest.mark.precommit @@ -175,7 +179,17 @@ def test_apply_chat_template_nested_content(model_tmp_path, ov_hf_tokenizers, to messages, add_generation_prompt=add_generation_prompt ) - assert ov_full_history_str == hf_full_history_str, f"HF reference:\n{hf_full_history_str}\nGenAI output:\n{ov_full_history_str}" + assert_hf_equals_genai(hf_full_history_str, ov_full_history_str) + + chat_history = ChatHistory(messages) + + assert chat_history.get_messages() == messages + + genai_templated_chat_history = genai_tokenizer.apply_chat_template( + chat_history, add_generation_prompt=add_generation_prompt + ) + + assert genai_templated_chat_history == ov_full_history_str @pytest.mark.precommit @@ -198,8 +212,6 @@ def test_apply_chat_template_with_tools_and_extra_context(model_tmp_path, ov_hf_ } } }] - # In GenAI order of dict keys is not preserved (sorted alphabetically, due to conversion to AnyMap) - tools = [json.loads(json.dumps(tool, sort_keys=True)) for tool in tools] add_generation_prompt = True @@ -215,7 +227,24 @@ def test_apply_chat_template_with_tools_and_extra_context(model_tmp_path, ov_hf_ conversation, add_generation_prompt=add_generation_prompt, tools=tools, extra_context=extra_context ) - assert ov_full_history_str == hf_full_history_str, f"HF reference:\n{hf_full_history_str}\nGenAI output:\n{ov_full_history_str}" + assert_hf_equals_genai(hf_full_history_str, ov_full_history_str) + + # Test tools and extra context set via chat history state + chat_history = ChatHistory(conversation) + 
chat_history.set_tools(tools) + chat_history.set_extra_context(extra_context) + genai_templated_chat_history = genai_tokenizer.apply_chat_template( + chat_history, add_generation_prompt=add_generation_prompt + ) + assert_hf_equals_genai(genai_templated_chat_history, ov_full_history_str) + + # Test apply_chat_template tools and extra_context arguments prioritized over chat history state + chat_history.set_tools([]) + chat_history.set_extra_context({}) + genai_templated_chat_history = genai_tokenizer.apply_chat_template( + chat_history, add_generation_prompt=add_generation_prompt, tools=tools, extra_context=extra_context + ) + assert_hf_equals_genai(genai_templated_chat_history, ov_full_history_str) @pytest.mark.precommit @@ -233,7 +262,7 @@ def test_non_string_chat_template(hf_ov_genai_models): ov_full_history_str = genai_tokenzier.apply_chat_template(conversation, add_generation_prompt=False) - assert ov_full_history_str == hf_full_history_str, f"HF reference:\n{hf_full_history_str}\nGenAI output:\n{ov_full_history_str}" + assert_hf_equals_genai(hf_full_history_str, ov_full_history_str) @pytest.mark.precommit