diff --git a/hdr/sqlite_modern_cpp.h b/hdr/sqlite_modern_cpp.h index 0d328470..0d8065be 100644 --- a/hdr/sqlite_modern_cpp.h +++ b/hdr/sqlite_modern_cpp.h @@ -8,8 +8,6 @@ #include #include #include -#include -#include #define MODERN_SQLITE_VERSION 3002008 @@ -45,6 +43,7 @@ #include "sqlite_modern_cpp/errors.h" #include "sqlite_modern_cpp/utility/function_traits.h" #include "sqlite_modern_cpp/utility/uncaught_exceptions.h" +#include "sqlite_modern_cpp/utility/utf16_utf8.h" #ifdef MODERN_SQLITE_STD_VARIANT_SUPPORT #include "sqlite_modern_cpp/utility/variant.h" @@ -183,15 +182,9 @@ namespace sqlite { } } -#ifdef _MSC_VER sqlite3_stmt* _prepare(const std::u16string& sql) { - return _prepare(std::wstring_convert, wchar_t>().to_bytes(reinterpret_cast(sql.c_str()))); + return _prepare(utility::utf16_to_utf8(sql)); } -#else - sqlite3_stmt* _prepare(const std::u16string& sql) { - return _prepare(std::wstring_convert, char16_t>().to_bytes(sql)); - } -#endif sqlite3_stmt* _prepare(const std::string& sql) { int hresult; @@ -421,11 +414,7 @@ namespace sqlite { } database(const std::u16string &db_name, const sqlite_config &config = {}): _db(nullptr) { -#ifdef _MSC_VER - auto db_name_utf8 = std::wstring_convert, wchar_t>().to_bytes(reinterpret_cast(db_name.c_str())); -#else - auto db_name_utf8 = std::wstring_convert, char16_t>().to_bytes(db_name); -#endif + auto db_name_utf8 = utility::utf16_to_utf8(db_name); sqlite3* tmp = nullptr; auto ret = sqlite3_open_v2(db_name_utf8.data(), &tmp, static_cast(config.flags), config.zVfs); _db = std::shared_ptr(tmp, [=](sqlite3* ptr) { sqlite3_close_v2(ptr); }); // this will close the connection eventually when no longer needed. diff --git a/hdr/sqlite_modern_cpp/errors.h b/hdr/sqlite_modern_cpp/errors.h index 6c75b7ae..2b9ab75d 100644 --- a/hdr/sqlite_modern_cpp/errors.h +++ b/hdr/sqlite_modern_cpp/errors.h @@ -38,6 +38,7 @@ namespace sqlite { class more_rows: public sqlite_exception { using sqlite_exception::sqlite_exception; }; class no_rows: public sqlite_exception { using sqlite_exception::sqlite_exception; }; class more_statements: public sqlite_exception { using sqlite_exception::sqlite_exception; }; // Prepared statements can only contain one statement + class invalid_utf16: public sqlite_exception { using sqlite_exception::sqlite_exception; }; static void throw_sqlite_error(const int& error_code, const std::string &sql = "") { switch(error_code & 0xFF) { diff --git a/hdr/sqlite_modern_cpp/utility/utf16_utf8.h b/hdr/sqlite_modern_cpp/utility/utf16_utf8.h new file mode 100644 index 00000000..f2fa5ad5 --- /dev/null +++ b/hdr/sqlite_modern_cpp/utility/utf16_utf8.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include + +#include "../errors.h" + +namespace sqlite { + namespace utility { + inline std::string utf16_to_utf8(const std::u16string &input) { + struct : std::codecvt { + } codecvt; + std::mbstate_t state; + std::string result(std::max(input.size() * 3 / 2, std::size_t(4)), '\0'); + const char16_t *remaining_input = input.data(); + std::size_t produced_output = 0; + while(true) { + char *used_output; + switch(codecvt.out(state, remaining_input, &input[input.size()], + remaining_input, &result[produced_output], + &result[result.size() - 1] + 1, used_output)) { + case std::codecvt_base::ok: + result.resize(used_output - result.data()); + return result; + case std::codecvt_base::noconv: + // This should be unreachable + case std::codecvt_base::error: + throw errors::invalid_utf16("Invalid UTF-16 input", ""); + case std::codecvt_base::partial: + if(used_output == result.data() + produced_output) + throw errors::invalid_utf16("Unexpected end of input", ""); + produced_output = used_output - result.data(); + result.resize( + result.size() + + std::max((&input[input.size()] - remaining_input) * 3 / 2, + std::ptrdiff_t(4))); + } + } + } + } // namespace utility +} // namespace sqlite