diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml index 0961e66ac8a..56996c98ef5 100644 --- a/.github/workflows/cpp.yml +++ b/.github/workflows/cpp.yml @@ -398,6 +398,7 @@ jobs: ARROW_FLIGHT: ON ARROW_FLIGHT_SQL: ON ARROW_FLIGHT_SQL_ODBC: ON + ARROW_FLIGHT_SQL_ODBC_INSTALLER: ON ARROW_GANDIVA: ON ARROW_GCS: ON ARROW_HDFS: OFF @@ -476,6 +477,10 @@ jobs: PIPX_BASE_PYTHON: ${{ steps.python-install.outputs.python-path }} run: | ci/scripts/install_gcs_testbench.sh default + - name: Register Flight SQL ODBC Driver + shell: cmd + run: | + call "cpp\src\arrow\flight\sql\odbc\install\install_amd64.cmd" ${{github.workspace}}\build\cpp\%ARROW_BUILD_TYPE%\libarrow_flight_sql_odbc.dll - name: Test shell: msys2 {0} run: | diff --git a/.gitignore b/.gitignore index 8354aa8f816..64c713a74ed 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,7 @@ dependency-reduced-pom.xml MANIFEST compile_commands.json build.ninja +build*/ # Generated Visual Studio files *.vcxproj @@ -107,3 +108,6 @@ java/.mvn/.develocity/ # rat filtered_rat.txt rat.txt + +# rc +*.rc diff --git a/ci/scripts/cpp_build.sh b/ci/scripts/cpp_build.sh index 8df5ec2b2d0..ad4f5ac5904 100755 --- a/ci/scripts/cpp_build.sh +++ b/ci/scripts/cpp_build.sh @@ -64,6 +64,7 @@ if [ "${ARROW_ENABLE_THREADING:-ON}" = "OFF" ]; then ARROW_AZURE=OFF ARROW_FLIGHT=OFF ARROW_FLIGHT_SQL=OFF + ARROW_FLIGHT_SQL_ODBC=OFF ARROW_GCS=OFF ARROW_JEMALLOC=OFF ARROW_MIMALLOC=OFF @@ -206,6 +207,8 @@ else -DARROW_FILESYSTEM=${ARROW_FILESYSTEM:-ON} \ -DARROW_FLIGHT=${ARROW_FLIGHT:-OFF} \ -DARROW_FLIGHT_SQL=${ARROW_FLIGHT_SQL:-OFF} \ + -DARROW_FLIGHT_SQL_ODBC=${ARROW_FLIGHT_SQL_ODBC:-OFF} \ + -DARROW_FLIGHT_SQL_ODBC_INSTALLER=${ARROW_FLIGHT_SQL_ODBC_INSTALLER:-OFF} \ -DARROW_FUZZING=${ARROW_FUZZING:-OFF} \ -DARROW_GANDIVA_PC_CXX_FLAGS=${ARROW_GANDIVA_PC_CXX_FLAGS:-} \ -DARROW_GANDIVA=${ARROW_GANDIVA:-OFF} \ diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 18841ac874b..8d9e1f7a1ef 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -713,9 +713,13 @@ endif() install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/../LICENSE.txt ${CMAKE_CURRENT_SOURCE_DIR}/../NOTICE.txt - ${CMAKE_CURRENT_SOURCE_DIR}/README.md DESTINATION "${ARROW_DOC_DIR}") + ${CMAKE_CURRENT_SOURCE_DIR}/README.md + DESTINATION "${ARROW_DOC_DIR}" + COMPONENT arrow_doc) -install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/gdb_arrow.py DESTINATION "${ARROW_GDB_DIR}") +install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/gdb_arrow.py + DESTINATION "${ARROW_GDB_DIR}" + COMPONENT arrow_gdb) # # Validate and print out Arrow configuration options diff --git a/cpp/CMakePresets.json b/cpp/CMakePresets.json index c9e2444389f..db0a4ddd061 100644 --- a/cpp/CMakePresets.json +++ b/cpp/CMakePresets.json @@ -179,6 +179,7 @@ "ARROW_BUILD_EXAMPLES": "ON", "ARROW_BUILD_UTILITIES": "ON", "ARROW_FLIGHT_SQL_ODBC": "ON", + "ARROW_FLIGHT_SQL_ODBC_INSTALLER": "ON", "ARROW_TENSORFLOW": "ON", "PARQUET_BUILD_EXAMPLES": "ON", "PARQUET_BUILD_EXECUTABLES": "ON" diff --git a/cpp/cmake_modules/BuildUtils.cmake b/cpp/cmake_modules/BuildUtils.cmake index db760400f7c..305546572c4 100644 --- a/cpp/cmake_modules/BuildUtils.cmake +++ b/cpp/cmake_modules/BuildUtils.cmake @@ -178,10 +178,12 @@ function(arrow_install_cmake_package PACKAGE_NAME EXPORT_NAME) write_basic_package_version_file("${BUILT_CONFIG_VERSION_CMAKE}" COMPATIBILITY SameMajorVersion) install(FILES "${BUILT_CONFIG_CMAKE}" "${BUILT_CONFIG_VERSION_CMAKE}" - DESTINATION "${ARROW_CMAKE_DIR}/${PACKAGE_NAME}") + DESTINATION "${ARROW_CMAKE_DIR}/${PACKAGE_NAME}" + COMPONENT config_cmake_file) set(TARGETS_CMAKE "${PACKAGE_NAME}Targets.cmake") install(EXPORT ${EXPORT_NAME} DESTINATION "${ARROW_CMAKE_DIR}/${PACKAGE_NAME}" + COMPONENT config_cmake_export NAMESPACE "${PACKAGE_NAME}::" FILE "${TARGETS_CMAKE}") endfunction() @@ -403,8 +405,11 @@ function(ADD_ARROW_LIB LIB_NAME) install(TARGETS ${LIB_NAME}_shared ${INSTALL_IS_OPTIONAL} EXPORT ${LIB_NAME}_targets ARCHIVE DESTINATION ${INSTALL_ARCHIVE_DIR} + COMPONENT ${LIB_NAME}_shared_archive LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} + COMPONENT ${LIB_NAME}_shared_library RUNTIME DESTINATION ${INSTALL_RUNTIME_DIR} + COMPONENT ${LIB_NAME}_shared_runtime INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) endif() @@ -471,8 +476,11 @@ function(ADD_ARROW_LIB LIB_NAME) install(TARGETS ${LIB_NAME}_static ${INSTALL_IS_OPTIONAL} EXPORT ${LIB_NAME}_targets ARCHIVE DESTINATION ${INSTALL_ARCHIVE_DIR} + COMPONENT ${LIB_NAME}_static_library LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} + COMPONENT ${LIB_NAME}_static_library RUNTIME DESTINATION ${INSTALL_RUNTIME_DIR} + COMPONENT ${LIB_NAME}_static_library INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) endif() @@ -934,7 +942,9 @@ function(ARROW_INSTALL_ALL_HEADERS PATH) endif() list(APPEND PUBLIC_HEADERS ${HEADER}) endforeach() - install(FILES ${PUBLIC_HEADERS} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/${PATH}") + install(FILES ${PUBLIC_HEADERS} + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/${PATH}" + COMPONENT ${HEADER}_header) endfunction() function(ARROW_ADD_PKG_CONFIG MODULE) @@ -944,7 +954,8 @@ function(ARROW_ADD_PKG_CONFIG MODULE) OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/$/${MODULE}.pc" INPUT "${CMAKE_CURRENT_BINARY_DIR}/${MODULE}.pc.generate.in") install(FILES "${CMAKE_CURRENT_BINARY_DIR}/$/${MODULE}.pc" - DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig/") + DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig/" + COMPONENT ${MODULE}_pkg_config) endfunction() # Implementations of lisp "car" and "cdr" functions diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake index faac95c4004..1b28f70f4ce 100644 --- a/cpp/cmake_modules/DefineOptions.cmake +++ b/cpp/cmake_modules/DefineOptions.cmake @@ -108,7 +108,7 @@ endmacro() macro(resolve_option_dependencies) # Arrow Flight SQL ODBC is available only for Windows for now. - if(NOT MSVC_TOOLCHAIN) + if(NOT WIN32) set(ARROW_FLIGHT_SQL_ODBC OFF) endif() if(MSVC_TOOLCHAIN) diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 6e7544a707d..be9744828e8 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -236,7 +236,8 @@ function(provide_cmake_module MODULE_NAME ARROW_CMAKE_PACKAGE_NAME) message(STATUS "Providing CMake module for ${MODULE_NAME} as part of ${ARROW_CMAKE_PACKAGE_NAME} CMake package" ) install(FILES "${module}" - DESTINATION "${ARROW_CMAKE_DIR}/${ARROW_CMAKE_PACKAGE_NAME}") + DESTINATION "${ARROW_CMAKE_DIR}/${ARROW_CMAKE_PACKAGE_NAME}" + COMPONENT ${MODULE_NAME}_module) endif() endfunction() @@ -1269,7 +1270,7 @@ if(ARROW_USE_BOOST) endif() if(ARROW_BOOST_REQUIRE_LIBRARY) set(ARROW_BOOST_COMPONENTS filesystem system) - if(ARROW_FLIGHT_SQL_ODBC AND MSVC) + if(ARROW_FLIGHT_SQL_ODBC) list(APPEND ARROW_BOOST_COMPONENTS locale) endif() if(ARROW_ENABLE_THREADING) @@ -2375,20 +2376,22 @@ function(build_gtest) endforeach() install(DIRECTORY "${googletest_SOURCE_DIR}/googlemock/include/" "${googletest_SOURCE_DIR}/googletest/include/" - DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}") + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}" + COMPONENT gtest_dir) add_library(arrow::GTest::gtest_headers INTERFACE IMPORTED) target_include_directories(arrow::GTest::gtest_headers INTERFACE "${googletest_SOURCE_DIR}/googlemock/include/" "${googletest_SOURCE_DIR}/googletest/include/") install(TARGETS gmock gmock_main gtest gtest_main EXPORT arrow_testing_targets - RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}" - ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}" - LIBRARY DESTINATION "${CMAKE_INSTALL_LIBDIR}") + RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}" COMPONENT gtest_runtime + ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}" COMPONENT gtest_archive + LIBRARY DESTINATION "${CMAKE_INSTALL_LIBDIR}" COMPONENT gtest_library) if(MSVC) install(FILES $ $ $ $ DESTINATION "${CMAKE_INSTALL_BINDIR}" + COMPONENT gtest_pdb OPTIONAL) endif() add_library(arrow::GTest::gmock ALIAS gmock) diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 42b0bcc151c..9b2a6567508 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -345,7 +345,8 @@ endmacro() configure_file("util/config.h.cmake" "util/config.h" ESCAPE_QUOTES) configure_file("util/config_internal.h.cmake" "util/config_internal.h" ESCAPE_QUOTES) install(FILES "${CMAKE_CURRENT_BINARY_DIR}/util/config.h" - DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/arrow/util") + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/arrow/util" + COMPONENT arrow_config) set(ARROW_SRCS builder.cc @@ -1039,7 +1040,8 @@ if(ARROW_BUILD_BUNDLED_DEPENDENCIES) get_target_property(arrow_bundled_dependencies_path arrow_bundled_dependencies IMPORTED_LOCATION) install(FILES ${arrow_bundled_dependencies_path} ${INSTALL_IS_OPTIONAL} - DESTINATION ${CMAKE_INSTALL_LIBDIR}) + DESTINATION ${CMAKE_INSTALL_LIBDIR} + COMPONENT arrow_bundled_dependencies) string(PREPEND ARROW_PC_LIBS_PRIVATE " -larrow_bundled_dependencies") list(INSERT ARROW_STATIC_INSTALL_INTERFACE_LIBS 0 "Arrow::arrow_bundled_dependencies") endif() @@ -1156,6 +1158,7 @@ if(ARROW_BUILD_SHARED AND NOT WIN32) if(ARROW_GDB_AUTO_LOAD_LIBARROW_GDB_INSTALL) install(FILES "${CMAKE_CURRENT_BINARY_DIR}/libarrow_gdb.py" DESTINATION "${ARROW_GDB_AUTO_LOAD_LIBARROW_GDB_DIR}" + COMPONENT arrow_gdb RENAME "$-gdb.py") endif() endif() @@ -1219,11 +1222,13 @@ arrow_install_all_headers("arrow") config_summary_cmake_setters("${CMAKE_CURRENT_BINARY_DIR}/ArrowOptions.cmake") install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ArrowOptions.cmake - DESTINATION "${ARROW_CMAKE_DIR}/Arrow") + DESTINATION "${ARROW_CMAKE_DIR}/Arrow" + COMPONENT arrow_options_cmake) # For backward compatibility for find_package(arrow) install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/arrow-config.cmake - DESTINATION "${ARROW_CMAKE_DIR}/Arrow") + DESTINATION "${ARROW_CMAKE_DIR}/Arrow" + COMPONENT arrow_config_cmake) # # Unit tests diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index a827a7307f8..bf7d9c7f1f9 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -295,7 +295,9 @@ if(ARROW_TESTING) STATIC_INSTALL_INTERFACE_LIBS ${ARROW_FLIGHT_TESTING_STATIC_INSTALL_INTERFACE_LIBS} PRIVATE_INCLUDES - "${Protobuf_INCLUDE_DIRS}") + "${Protobuf_INCLUDE_DIRS}" + SHARED_PRIVATE_LINK_LIBS + GTest::gmock) foreach(LIB_TARGET ${ARROW_FLIGHT_TESTING_LIBRARIES}) target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_FLIGHT_EXPORTING) diff --git a/cpp/src/arrow/flight/sql/column_metadata.cc b/cpp/src/arrow/flight/sql/column_metadata.cc index 30f557084b2..8d2d2b4ddca 100644 --- a/cpp/src/arrow/flight/sql/column_metadata.cc +++ b/cpp/src/arrow/flight/sql/column_metadata.cc @@ -58,8 +58,15 @@ const char* ColumnMetadata::kIsSearchable = "ARROW:FLIGHT:SQL:IS_SEARCHABLE"; const char* ColumnMetadata::kRemarks = "ARROW:FLIGHT:SQL:REMARKS"; ColumnMetadata::ColumnMetadata( - std::shared_ptr metadata_map) - : metadata_map_(std::move(metadata_map)) {} + std::shared_ptr metadata_map) { + if (metadata_map) { + metadata_map_ = std::move(metadata_map); + } else { + std::shared_ptr empty_metadata_map( + new arrow::KeyValueMetadata); + metadata_map_ = std::move(empty_metadata_map); + } +} arrow::Result ColumnMetadata::GetCatalogName() const { return metadata_map_->Get(kCatalogName); diff --git a/cpp/src/arrow/flight/sql/odbc/ArrowFlightSqlOdbcConfig.cmake.in b/cpp/src/arrow/flight/sql/odbc/ArrowFlightSqlOdbcConfig.cmake.in new file mode 100644 index 00000000000..da6d44ebc82 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/ArrowFlightSqlOdbcConfig.cmake.in @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This config sets the following variables in your project:: +# +# ArrowFlightSqlOdbc_FOUND - true if Arrow Flight SQL ODBC found on the system +# +# This config sets the following targets in your project:: +# +# ArrowFlightSqlOdbc::arrow_flight_sql_odbc_shared - for linked as shared library if shared library is built +# ArrowFlightSqlOdbc::arrow_flight_sql_odbc_static - for linked as static library if static library is built + +@PACKAGE_INIT@ + +include(CMakeFindDependencyMacro) +find_dependency(ArrowFlightSql) + +include("${CMAKE_CURRENT_LIST_DIR}/ArrowFlightSqlOdbcTargets.cmake") + +arrow_keep_backward_compatibility(ArrowFlightSqlOdbc arrow_flight_sql_odbc) + +check_required_components(ArrowFlightSqlOdbc) + +arrow_show_details(ArrowFlightSqlOdbc ARROW_FLIGHT_SQL_ODBC) diff --git a/cpp/src/arrow/flight/sql/odbc/CMakeLists.txt b/cpp/src/arrow/flight/sql/odbc/CMakeLists.txt index 80be0dee99f..d641873514b 100644 --- a/cpp/src/arrow/flight/sql/odbc/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/odbc/CMakeLists.txt @@ -15,7 +15,169 @@ # specific language governing permissions and limitations # under the License. +# Use C++ 20 for ODBC and its subdirectory +# GH-44792: Arrow will switch to C++ 20 +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + add_custom_target(arrow_flight_sql_odbc) +# Ensure fmt is loaded as header only +add_compile_definitions(FMT_HEADER_ONLY) + +if(WIN32) + if(MSVC_VERSION GREATER_EQUAL 1900) + set(ODBCINST legacy_stdio_definitions odbccp32 shlwapi) + elseif(MINGW) + set(ODBCINST odbccp32 shlwapi) + endif() +elseif(APPLE) + set(ODBCINST iodbcinst) +else() + set(ODBCINST odbcinst) +endif() + +add_definitions(-DUNICODE=1) + +include(FetchContent) +fetchcontent_declare(spdlog + URL https://github.com/gabime/spdlog/archive/refs/tags/v1.15.3.zip + CONFIGURE_COMMAND + "" + BUILD_COMMAND + "") +fetchcontent_makeavailable(spdlog) + add_subdirectory(flight_sql) add_subdirectory(odbcabstraction) +add_subdirectory(tests) + +arrow_install_all_headers("arrow/flight/sql/odbc") + +# ODBC Release information +set(ODBC_PACKAGE_VERSION_MAJOR "1") +set(ODBC_PACKAGE_VERSION_MINOR "0") +set(ODBC_PACKAGE_VERSION_PATCH "0") +set(ODBC_PACKAGE_NAME "Apache Arrow Flight SQL ODBC") +set(ODBC_PACKAGE_VENDOR "Apache Arrow") + +set(ARROW_FLIGHT_SQL_ODBC_SRCS entry_points.cc odbc_api.cc) + +if(WIN32) + set(VER_FILEVERSION + "${ODBC_PACKAGE_VERSION_MAJOR},${ODBC_PACKAGE_VERSION_MINOR},${ODBC_PACKAGE_VERSION_PATCH},0" + ) + set(VER_FILEVERSION_STR + ${ODBC_PACKAGE_VERSION_MAJOR}.${ODBC_PACKAGE_VERSION_MINOR}.${ODBC_PACKAGE_VERSION_PATCH} + ) + set(VER_COMPANYNAME_STR ${ODBC_PACKAGE_VENDOR}) + set(VER_PRODUCTNAME_STR ${ODBC_PACKAGE_NAME}) + + configure_file("install/versioninfo.rc.in" "install/versioninfo.rc" @ONLY) + + list(APPEND ARROW_FLIGHT_SQL_ODBC_SRCS odbc.def install/versioninfo.rc) +endif() + +add_arrow_lib(arrow_flight_sql_odbc + CMAKE_PACKAGE_NAME + ArrowFlightSqlOdbc + PKG_CONFIG_NAME + arrow-flight-sql-odbc + OUTPUTS + ARROW_FLIGHT_SQL_ODBC_LIBRARIES + SOURCES + ${ARROW_FLIGHT_SQL_ODBC_SRCS} + DEPENDENCIES + arrow_flight_sql + SHARED_LINK_FLAGS + ${ARROW_VERSION_SCRIPT_FLAGS} # Defined in cpp/arrow/CMakeLists.txt + SHARED_LINK_LIBS + arrow_flight_sql_shared + SHARED_INSTALL_INTERFACE_LIBS + ArrowFlight::arrow_flight_sql_shared + STATIC_LINK_LIBS + arrow_flight_sql_static + STATIC_INSTALL_INTERFACE_LIBS + ArrowFlight::arrow_flight_sql_static + SHARED_PRIVATE_LINK_LIBS + ${ODBC_LIBRARIES} + ${ODBCINST} + odbcabstraction + arrow_odbc_spi_impl + spdlog::spdlog) + +foreach(LIB_TARGET ${ARROW_FLIGHT_SQL_ODBC_LIBRARIES}) + target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_FLIGHT_SQL_ODBC_EXPORTING) +endforeach() + +# Construct ODBC Windows installer. Only Release installer is supported +if(ARROW_FLIGHT_SQL_ODBC_INSTALLER) + + include(InstallRequiredSystemLibraries) + + set(CPACK_RESOURCE_FILE_LICENSE + "${CMAKE_CURRENT_SOURCE_DIR}/../../../../../../LICENSE.txt") + # Tentative version 1.0.0 + set(CPACK_PACKAGE_VERSION_MAJOR ${ODBC_PACKAGE_VERSION_MAJOR}) + set(CPACK_PACKAGE_VERSION_MINOR ${ODBC_PACKAGE_VERSION_MINOR}) + set(CPACK_PACKAGE_VERSION_PATCH ${ODBC_PACKAGE_VERSION_PATCH}) + + set(CPACK_PACKAGE_NAME ${ODBC_PACKAGE_NAME}) + set(CPACK_PACKAGE_VENDOR ${ODBC_PACKAGE_VENDOR}) + set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "Apache Arrow Flight SQL ODBC Driver") + set(CPACK_PACKAGE_CONTACT "#TODO arrow maintainers") + + # TODO: set up `flight_sql_odbc_lib` component for macOS Installer + # TODO: set up `flight_sql_odbc_lib` component for Linux Installer + if(WIN32) + install(DIRECTORY "${BUILD_OUTPUT_ROOT_DIRECTORY}${CMAKE_BUILD_TYPE}/" + DESTINATION bin + COMPONENT flight_sql_odbc_lib + FILES_MATCHING + # Use regex for dll name patterns with versions + PATTERN "abseil_dll.dll" + PATTERN "arrow.dll" + PATTERN "arrow_compute.dll" + PATTERN "arrow_flight.dll" + PATTERN "arrow_flight_sql.dll" + PATTERN "arrow_flight_sql_odbc.dll" + PATTERN "boost_locale*.dll" + PATTERN "cares.dll" + PATTERN "libcrypto*.dll" + PATTERN "libprotobuf.dll" + PATTERN "libssl*.dll" + PATTERN "re2.dll" + PATTERN "utf8proc.dll" + PATTERN "zlib1.dll") + + set(CPACK_WIX_EXTRA_SOURCES + "${CMAKE_CURRENT_SOURCE_DIR}/install/arrow-flight-sql-odbc.wxs") + set(CPACK_WIX_PATCH_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/install/arrow-flight-sql-odbc-patch.xml") + + set(CPACK_WIX_UI_BANNER "${CMAKE_CURRENT_SOURCE_DIR}/install/arrow-wix-banner.bmp") + endif() + + get_cmake_property(CPACK_COMPONENTS_ALL COMPONENTS) + set(CPACK_COMPONENTS_ALL Unspecified) + list(APPEND CPACK_COMPONENTS_ALL "flight_sql_odbc_lib") + + if(WIN32) + # WiX msi installer on Windows + # CPack is compatible with WiX V.5 and V.6 + set(CPACK_GENERATOR "WIX") + set(CPACK_WIX_VERSION 4) + + # Upgrade GUID is required to be unchanged for ODBC installer to upgrade + set(CPACK_WIX_UPGRADE_GUID "DBF27A18-F8BF-423F-9E3A-957414D52C4B") + endif() + # TODO: create macOS Installer using cpack + # TODO: create Linux Installer using cpack + + # Load CPack after all CPACK* variables are set + include(CPack) + cpack_add_component(flight_sql_odbc_lib + DISPLAY_NAME "ODBC library" + DESCRIPTION "ODBC library bin, required to install" + REQUIRED) +endif() diff --git a/cpp/src/arrow/flight/sql/odbc/README b/cpp/src/arrow/flight/sql/odbc/README new file mode 100644 index 00000000000..da9857b7ecc --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/README @@ -0,0 +1,26 @@ +Steps to Register the 64-bit Apache Arrow ODBC driver on Windows + +After the build succeeds, the ODBC DLL will be located in +`build\debug\Debug` for a debug build and `build\release\Release` for a release build. + +1. Open Power Shell as administrator. + +2. Register your ODBC DLL: + Need to replace with actual path to repository in the commands. + + i. `cd to repo.` + ii. `cd ` + iii. Run script to register your ODBC DLL as Apache Arrow Flight SQL ODBC Driver + `.\cpp\src\arrow\flight\sql\odbc\install\install_amd64.cmd \cpp\build\< release | debug >\< Release | Debug>\arrow_flight_sql_odbc.dll` + Example command for reference: + `.\cpp\src\arrow\flight\sql\odbc\install\install_amd64.cmd C:\path\to\arrow\cpp\build\release\Release\arrow_flight_sql_odbc.dll` + +If the registration is successful, then Apache Arrow Flight SQL ODBC Driver +should show as an available ODBC driver in the x64 ODBC Driver Manager. + +Steps to Generate Windows Installer +1. Build with `ARROW_FLIGHT_SQL_ODBC=ON` and `ARROW_FLIGHT_SQL_ODBC_INSTALLER=ON`. +2. `cd` to `build` folder. +3. Run `cpack`. + +If the generation is successful, you will find `Apache Arrow Flight SQL ODBC--win64.msi` generated under the `build` folder. diff --git a/cpp/src/arrow/flight/sql/odbc/arrow-flight-sql-odbc.pc.in b/cpp/src/arrow/flight/sql/odbc/arrow-flight-sql-odbc.pc.in new file mode 100644 index 00000000000..78959034954 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/arrow-flight-sql-odbc.pc.in @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +prefix=@CMAKE_INSTALL_PREFIX@ +includedir=@ARROW_PKG_CONFIG_INCLUDEDIR@ +libdir=@ARROW_PKG_CONFIG_LIBDIR@ + +Name: Apache Arrow Flight SQL ODBC +Description: Apache Arrow Flight SQL ODBC extension +Version: @ARROW_VERSION@ +Requires: arrow-flight-sql +Libs: -L${libdir} -larrow_flight_sql_odbc +Cflags.private: -DARROW_FLIGHT_SQL_ODBC_STATIC diff --git a/cpp/src/arrow/flight/sql/odbc/entry_points.cc b/cpp/src/arrow/flight/sql/odbc/entry_points.cc new file mode 100644 index 00000000000..38b4a1fc8ed --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/entry_points.cc @@ -0,0 +1,299 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// platform.h includes windows.h, so it needs to be included first +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h" + +#ifdef _WIN32 +# include +#endif + +#include +#include +#include +#include + +#include "arrow/flight/sql/odbc/odbc_api.h" +#include "arrow/flight/sql/odbc/visibility.h" + +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_connection.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_descriptor.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_environment.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_statement.h" + +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/logger.h" + +SQLRETURN SQL_API SQLAllocHandle(SQLSMALLINT type, SQLHANDLE parent, SQLHANDLE* result) { + return arrow::SQLAllocHandle(type, parent, result); +} + +SQLRETURN SQL_API SQLAllocEnv(SQLHENV* env) { + return arrow::SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, env); +} + +SQLRETURN SQL_API SQLAllocConnect(SQLHENV env, SQLHDBC* conn) { + return arrow::SQLAllocHandle(SQL_HANDLE_DBC, env, conn); +} + +SQLRETURN SQL_API SQLAllocStmt(SQLHDBC conn, SQLHSTMT* stmt) { + return arrow::SQLAllocHandle(SQL_HANDLE_STMT, conn, stmt); +} + +SQLRETURN SQL_API SQLFreeHandle(SQLSMALLINT type, SQLHANDLE handle) { + return arrow::SQLFreeHandle(type, handle); +} + +SQLRETURN SQL_API SQLFreeEnv(SQLHENV env) { + return arrow::SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +SQLRETURN SQL_API SQLFreeConnect(SQLHDBC conn) { + return arrow::SQLFreeHandle(SQL_HANDLE_DBC, conn); +} + +SQLRETURN SQL_API SQLFreeStmt(SQLHSTMT stmt, SQLUSMALLINT option) { + return arrow::SQLFreeStmt(stmt, option); +} + +SQLRETURN SQL_API SQLGetDiagField(SQLSMALLINT handleType, SQLHANDLE handle, + SQLSMALLINT recNumber, SQLSMALLINT diagIdentifier, + SQLPOINTER diagInfoPtr, SQLSMALLINT bufferLength, + SQLSMALLINT* stringLengthPtr) { + return arrow::SQLGetDiagField(handleType, handle, recNumber, diagIdentifier, + diagInfoPtr, bufferLength, stringLengthPtr); +} + +SQLRETURN SQL_API SQLGetDiagRec(SQLSMALLINT handleType, SQLHANDLE handle, + SQLSMALLINT recNumber, SQLWCHAR* sqlState, + SQLINTEGER* nativeErrorPtr, SQLWCHAR* messageText, + SQLSMALLINT bufferLength, SQLSMALLINT* textLengthPtr) { + return arrow::SQLGetDiagRec(handleType, handle, recNumber, sqlState, nativeErrorPtr, + messageText, bufferLength, textLengthPtr); +} + +SQLRETURN SQL_API SQLGetEnvAttr(SQLHENV env, SQLINTEGER attr, SQLPOINTER valuePtr, + SQLINTEGER bufferLen, SQLINTEGER* strLenPtr) { + return arrow::SQLGetEnvAttr(env, attr, valuePtr, bufferLen, strLenPtr); +} + +SQLRETURN SQL_API SQLSetEnvAttr(SQLHENV env, SQLINTEGER attr, SQLPOINTER valuePtr, + SQLINTEGER strLen) { + return arrow::SQLSetEnvAttr(env, attr, valuePtr, strLen); +} + +SQLRETURN SQL_API SQLGetConnectAttr(SQLHDBC conn, SQLINTEGER attribute, + SQLPOINTER valuePtr, SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr) { + return arrow::SQLGetConnectAttr(conn, attribute, valuePtr, bufferLength, + stringLengthPtr); +} + +SQLRETURN SQL_API SQLSetConnectAttr(SQLHDBC conn, SQLINTEGER attr, SQLPOINTER value, + SQLINTEGER valueLen) { + return arrow::SQLSetConnectAttr(conn, attr, value, valueLen); +} + +SQLRETURN SQL_API SQLGetInfo(SQLHDBC conn, SQLUSMALLINT infoType, SQLPOINTER infoValuePtr, + SQLSMALLINT bufLen, SQLSMALLINT* length) { + return arrow::SQLGetInfo(conn, infoType, infoValuePtr, bufLen, length); +} + +SQLRETURN SQL_API SQLDriverConnect(SQLHDBC conn, SQLHWND windowHandle, + SQLWCHAR* inConnectionString, + SQLSMALLINT inConnectionStringLen, + SQLWCHAR* outConnectionString, + SQLSMALLINT outConnectionStringBufferLen, + SQLSMALLINT* outConnectionStringLen, + SQLUSMALLINT driverCompletion) { + return arrow::SQLDriverConnect( + conn, windowHandle, inConnectionString, inConnectionStringLen, outConnectionString, + outConnectionStringBufferLen, outConnectionStringLen, driverCompletion); +} + +SQLRETURN SQL_API SQLConnect(SQLHDBC conn, SQLWCHAR* dsnName, SQLSMALLINT dsnNameLen, + SQLWCHAR* userName, SQLSMALLINT userNameLen, + SQLWCHAR* password, SQLSMALLINT passwordLen) { + return arrow::SQLConnect(conn, dsnName, dsnNameLen, userName, userNameLen, password, + passwordLen); +} + +SQLRETURN SQL_API SQLDisconnect(SQLHDBC conn) { return arrow::SQLDisconnect(conn); } + +SQLRETURN SQL_API SQLGetStmtAttr(SQLHSTMT stmt, SQLINTEGER attribute, SQLPOINTER valuePtr, + SQLINTEGER bufferLength, SQLINTEGER* stringLengthPtr) { + return arrow::SQLGetStmtAttr(stmt, attribute, valuePtr, bufferLength, stringLengthPtr); +} + +SQLRETURN SQL_API SQLExecDirect(SQLHSTMT stmt, SQLWCHAR* queryText, + SQLINTEGER textLength) { + return arrow::SQLExecDirect(stmt, queryText, textLength); +} + +SQLRETURN SQL_API SQLFetch(SQLHSTMT stmt) { return arrow::SQLFetch(stmt); } + +SQLRETURN SQL_API SQLExtendedFetch(SQLHSTMT stmt, SQLUSMALLINT fetchOrientation, + SQLLEN fetchOffset, SQLULEN* rowCountPtr, + SQLUSMALLINT* rowStatusArray) { + return arrow::SQLExtendedFetch(stmt, fetchOrientation, fetchOffset, rowCountPtr, + rowStatusArray); +} + +SQLRETURN SQL_API SQLFetchScroll(SQLHSTMT stmt, SQLSMALLINT fetchOrientation, + SQLLEN fetchOffset) { + return arrow::SQLFetchScroll(stmt, fetchOrientation, fetchOffset); +} + +SQLRETURN SQL_API SQLGetData(SQLHSTMT stmt, SQLUSMALLINT recordNumber, SQLSMALLINT cType, + SQLPOINTER dataPtr, SQLLEN bufferLength, + SQLLEN* indicatorPtr) { + return arrow::SQLGetData(stmt, recordNumber, cType, dataPtr, bufferLength, + indicatorPtr); +} + +SQLRETURN SQL_API SQLPrepare(SQLHSTMT stmt, SQLWCHAR* queryText, SQLINTEGER textLength) { + return arrow::SQLPrepare(stmt, queryText, textLength); +} + +SQLRETURN SQL_API SQLExecute(SQLHSTMT stmt) { return arrow::SQLExecute(stmt); } + +SQLRETURN SQL_API SQLBindCol(SQLHSTMT stmt, SQLUSMALLINT recordNumber, SQLSMALLINT cType, + SQLPOINTER dataPtr, SQLLEN bufferLength, + SQLLEN* indicatorPtr) { + return arrow::SQLBindCol(stmt, recordNumber, cType, dataPtr, bufferLength, + indicatorPtr); +} + +SQLRETURN SQL_API SQLCancel(SQLHSTMT stmt) { + LOG_DEBUG("SQLCancel called with stmt: {}", stmt); + return ODBC::ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + throw driver::odbcabstraction::DriverException("SQLCancel is not implemented", + "IM001"); + return SQL_ERROR; + }); +} + +SQLRETURN SQL_API SQLCloseCursor(SQLHSTMT stmt) { return arrow::SQLCloseCursor(stmt); } + +SQLRETURN SQL_API SQLColAttribute(SQLHSTMT stmt, SQLUSMALLINT recordNumber, + SQLUSMALLINT fieldIdentifier, + SQLPOINTER characterAttributePtr, + SQLSMALLINT bufferLength, SQLSMALLINT* outputLength, + SQLLEN* numericAttributePtr) { + return arrow::SQLColAttribute(stmt, recordNumber, fieldIdentifier, + characterAttributePtr, bufferLength, outputLength, + numericAttributePtr); +} + +SQLRETURN SQL_API SQLTables(SQLHSTMT stmt, SQLWCHAR* catalogName, + SQLSMALLINT catalogNameLength, SQLWCHAR* schemaName, + SQLSMALLINT schemaNameLength, SQLWCHAR* tableName, + SQLSMALLINT tableNameLength, SQLWCHAR* tableType, + SQLSMALLINT tableTypeLength) { + return arrow::SQLTables(stmt, catalogName, catalogNameLength, schemaName, + schemaNameLength, tableName, tableNameLength, tableType, + tableTypeLength); +} + +SQLRETURN SQL_API SQLColumns(SQLHSTMT stmt, SQLWCHAR* catalogName, + SQLSMALLINT catalogNameLength, SQLWCHAR* schemaName, + SQLSMALLINT schemaNameLength, SQLWCHAR* tableName, + SQLSMALLINT tableNameLength, SQLWCHAR* columnName, + SQLSMALLINT columnNameLength) { + return arrow::SQLColumns(stmt, catalogName, catalogNameLength, schemaName, + schemaNameLength, tableName, tableNameLength, columnName, + columnNameLength); +} + +SQLRETURN SQL_API SQLForeignKeys(SQLHSTMT stmt, SQLWCHAR* pKCatalogName, + SQLSMALLINT pKCatalogNameLength, SQLWCHAR* pKSchemaName, + SQLSMALLINT pKSchemaNameLength, SQLWCHAR* pKTableName, + SQLSMALLINT pKTableNameLength, SQLWCHAR* fKCatalogName, + SQLSMALLINT fKCatalogNameLength, SQLWCHAR* fKSchemaName, + SQLSMALLINT fKSchemaNameLength, SQLWCHAR* fKTableName, + SQLSMALLINT fKTableNameLength) { + LOG_DEBUG( + "SQLForeignKeysW called with stmt: {}, pKCatalogName: {}, " + "pKCatalogNameLength: " + "{}, pKSchemaName: {}, pKSchemaNameLength: {}, pKTableName: {}, pKTableNameLength: " + "{}, " + "fKCatalogName: {}, fKCatalogNameLength: {}, fKSchemaName: {}, fKSchemaNameLength: " + "{}, " + "fKTableName: {}, fKTableNameLength : {}", + stmt, fmt::ptr(pKCatalogName), pKCatalogNameLength, fmt::ptr(pKSchemaName), + pKSchemaNameLength, fmt::ptr(pKTableName), pKTableNameLength, + fmt::ptr(fKCatalogName), fKCatalogNameLength, fmt::ptr(fKSchemaName), + fKSchemaNameLength, fmt::ptr(fKTableName), fKTableNameLength); + return ODBC::ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + throw driver::odbcabstraction::DriverException("SQLForeignKeysW is not implemented", + "IM001"); + return SQL_ERROR; + }); +} + +SQLRETURN SQL_API SQLGetTypeInfo(SQLHSTMT stmt, SQLSMALLINT dataType) { + return arrow::SQLGetTypeInfo(stmt, dataType); +} + +SQLRETURN SQL_API SQLMoreResults(SQLHSTMT stmt) { return arrow::SQLMoreResults(stmt); } + +SQLRETURN SQL_API SQLNativeSql(SQLHDBC connectionHandle, SQLWCHAR* inStatementText, + SQLINTEGER inStatementTextLength, + SQLWCHAR* outStatementText, SQLINTEGER bufferLength, + SQLINTEGER* outStatementTextLength) { + return arrow::SQLNativeSql(connectionHandle, inStatementText, inStatementTextLength, + outStatementText, bufferLength, outStatementTextLength); +} + +SQLRETURN SQL_API SQLNumResultCols(SQLHSTMT stmt, SQLSMALLINT* columnCountPtr) { + return arrow::SQLNumResultCols(stmt, columnCountPtr); +} + +SQLRETURN SQL_API SQLRowCount(SQLHSTMT stmt, SQLLEN* rowCountPtr) { + return arrow::SQLRowCount(stmt, rowCountPtr); +} + +SQLRETURN SQL_API SQLPrimaryKeys(SQLHSTMT stmt, SQLWCHAR* catalogName, + SQLSMALLINT catalogNameLength, SQLWCHAR* schemaName, + SQLSMALLINT schemaNameLength, SQLWCHAR* tableName, + SQLSMALLINT tableNameLength) { + LOG_DEBUG( + "SQLPrimaryKeysW called with stmt: {}, catalogName: {}, " + "catalogNameLength: " + "{}, schemaName: {}, schemaNameLength: {}, tableName: {}, tableNameLength: {}", + stmt, fmt::ptr(catalogName), catalogNameLength, fmt::ptr(schemaName), + schemaNameLength, fmt::ptr(tableName), tableNameLength); + return ODBC::ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + throw driver::odbcabstraction::DriverException("SQLPrimaryKeysW is not implemented", + "IM001"); + return SQL_ERROR; + }); +} + +SQLRETURN SQL_API SQLSetStmtAttr(SQLHSTMT stmt, SQLINTEGER attribute, SQLPOINTER valuePtr, + SQLINTEGER stringLength) { + return arrow::SQLSetStmtAttr(stmt, attribute, valuePtr, stringLength); +} + +SQLRETURN SQL_API SQLDescribeCol(SQLHSTMT statementHandle, SQLUSMALLINT columnNumber, + SQLWCHAR* columnName, SQLSMALLINT bufferLength, + SQLSMALLINT* nameLengthPtr, SQLSMALLINT* dataTypePtr, + SQLULEN* columnSizePtr, SQLSMALLINT* decimalDigitsPtr, + SQLSMALLINT* nullablePtr) { + return arrow::SQLDescribeCol(statementHandle, columnNumber, columnName, bufferLength, + nameLengthPtr, dataTypePtr, columnSizePtr, + decimalDigitsPtr, nullablePtr); +} diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/CMakeLists.txt b/cpp/src/arrow/flight/sql/odbc/flight_sql/CMakeLists.txt index 56aabb54dbf..bd876804279 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/CMakeLists.txt @@ -76,6 +76,8 @@ add_library(arrow_odbc_spi_impl scalar_function_reporter.h system_trust_store.cc system_trust_store.h + system_dsn.cc + system_dsn.h utils.cc) target_include_directories(arrow_odbc_spi_impl PUBLIC include include/flight_sql @@ -96,13 +98,15 @@ if(WIN32) ui/window.cc ui/dsn_configuration_window.cc ui/add_property_window.cc - system_dsn.cc) + win_system_dsn.cc) endif() -target_link_libraries(arrow_odbc_spi_impl PUBLIC odbcabstraction arrow_flight_sql_shared) +target_link_libraries(arrow_odbc_spi_impl PUBLIC odbcabstraction arrow_flight_sql_shared + arrow_compute_shared Boost::locale) -if(MSVC) - target_link_libraries(arrow_odbc_spi_impl PUBLIC Boost::locale) +# Link libraries on MINGW64 only +if(MINGW AND CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + target_link_libraries(arrow_odbc_spi_impl PUBLIC ${ODBCINST}) endif() set_target_properties(arrow_odbc_spi_impl @@ -132,9 +136,11 @@ add_arrow_test(arrow_odbc_spi_impl_test accessors/time_array_accessor_test.cc accessors/timestamp_array_accessor_test.cc flight_sql_connection_test.cc + flight_sql_stream_chunk_buffer_test.cc parse_table_types_test.cc json_converter_test.cc record_batch_transformer_test.cc utils_test.cc EXTRA_LINK_LIBS - arrow_odbc_spi_impl) + arrow_odbc_spi_impl + arrow_flight_testing_shared) diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/primitive_array_accessor_test.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/primitive_array_accessor_test.cc index 820c0a7bd84..abf18fa9ce8 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/primitive_array_accessor_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/primitive_array_accessor_test.cc @@ -16,7 +16,7 @@ // under the License. #include "arrow/flight/sql/odbc/flight_sql/accessors/primitive_array_accessor.h" -#include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/diagnostics.h" #include "arrow/testing/builder.h" #include "gtest/gtest.h" diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/string_array_accessor_test.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/string_array_accessor_test.cc index 8b568bbffcf..587e7d5eb1c 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/string_array_accessor_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/string_array_accessor_test.cc @@ -134,7 +134,6 @@ TEST(StringArrayAccessor, Test_CDataType_WCHAR_Truncation) { ColumnBinding binding(odbcabstraction::CDataType_WCHAR, 0, 0, buffer.data(), max_strlen, strlen_buffer.data()); - std::basic_stringstream ss; int64_t value_offset = 0; // Construct the whole string by concatenating smaller chunks from diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/timestamp_array_accessor.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/timestamp_array_accessor.cc index 68e9f64fffb..c0225edea01 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/timestamp_array_accessor.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/accessors/timestamp_array_accessor.cc @@ -18,10 +18,13 @@ #include "arrow/flight/sql/odbc/flight_sql/accessors/timestamp_array_accessor.h" #include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/calendar_utils.h" +#include +#include + using arrow::TimeUnit; namespace { -int64_t GetConversionToSecondsDivisor(TimeUnit::type unit) { +inline int64_t GetConversionToSecondsDivisor(TimeUnit::type unit) { int64_t divisor = 1; switch (unit) { case TimeUnit::SECOND: @@ -85,6 +88,10 @@ RowStatus TimestampArrayFlightSqlAccessor::MoveSingleCell_imp ColumnBinding* binding, int64_t arrow_row, int64_t cell_counter, int64_t& value_offset, bool update_value_offset, odbcabstraction::Diagnostics& diagnostics) { + // Times less than the minimum integer number of seconds that can be represented + // for each time unit will not convert correctly. This is mostly interesting for + // nanoseconds as timestamps in other units are outside of the accepted range of + // Gregorian dates. auto* buffer = static_cast(binding->buffer); int64_t value = this->GetArray()->Value(arrow_row); diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/address_info.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/address_info.h index 312d5689a98..91f5a7175d7 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/address_info.h +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/address_info.h @@ -19,8 +19,10 @@ #include -#include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h" + #include +#include #if !_WIN32 # include #endif diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/config/configuration.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/config/configuration.cc index be92be057da..db18239c47a 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/config/configuration.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/config/configuration.cc @@ -17,6 +17,8 @@ #include "arrow/flight/sql/odbc/flight_sql/include/flight_sql/config/configuration.h" #include "arrow/flight/sql/odbc/flight_sql/flight_sql_connection.h" +#include "arrow/result.h" +#include "arrow/util/utf8.h" #include #include @@ -27,7 +29,6 @@ namespace driver { namespace flight_sql { namespace config { - static const char DEFAULT_DSN[] = "Apache Arrow Flight SQL"; static const char DEFAULT_ENABLE_ENCRYPTION[] = TRUE_STR; static const char DEFAULT_USE_CERT_STORE[] = TRUE_STR; @@ -36,23 +37,27 @@ static const char DEFAULT_DISABLE_CERT_VERIFICATION[] = FALSE_STR; namespace { std::string ReadDsnString(const std::string& dsn, const std::string_view& key, const std::string& dflt = "") { -#define BUFFER_SIZE (1024) - std::vector buf(BUFFER_SIZE); + std::wstring wDsn = arrow::util::UTF8ToWideString(dsn).ValueOr(L""); + std::wstring wKey = arrow::util::UTF8ToWideString(key).ValueOr(L""); + std::wstring wDflt = arrow::util::UTF8ToWideString(dflt).ValueOr(L""); - std::string key_str = std::string(key); +#define BUFFER_SIZE (1024) + std::vector buf(BUFFER_SIZE); int ret = - SQLGetPrivateProfileString(dsn.c_str(), key_str.c_str(), dflt.c_str(), buf.data(), - static_cast(buf.size()), "ODBC.INI"); + SQLGetPrivateProfileString(wDsn.c_str(), wKey.c_str(), wDflt.c_str(), buf.data(), + static_cast(buf.size()), L"ODBC.INI"); if (ret > BUFFER_SIZE) { // If there wasn't enough space, try again with the right size buffer. buf.resize(ret + 1); ret = - SQLGetPrivateProfileString(dsn.c_str(), key_str.c_str(), dflt.c_str(), buf.data(), - static_cast(buf.size()), "ODBC.INI"); + SQLGetPrivateProfileString(wDsn.c_str(), wKey.c_str(), wDflt.c_str(), buf.data(), + static_cast(buf.size()), L"ODBC.INI"); } - return std::string(buf.data(), ret); + std::wstring wResult = std::wstring(buf.data(), ret); + std::string result = arrow::util::WideStringToUTF8(wResult).ValueOr(""); + return result; } void RemoveAllKnownKeys(std::vector& keys) { @@ -69,28 +74,32 @@ void RemoveAllKnownKeys(std::vector& keys) { } std::vector ReadAllKeys(const std::string& dsn) { - std::vector buf(BUFFER_SIZE); + std::wstring wDsn = arrow::util::UTF8ToWideString(dsn).ValueOr(L""); + + std::vector buf(BUFFER_SIZE); - int ret = SQLGetPrivateProfileString(dsn.c_str(), NULL, "", buf.data(), - static_cast(buf.size()), "ODBC.INI"); + int ret = SQLGetPrivateProfileString(wDsn.c_str(), NULL, L"", buf.data(), + static_cast(buf.size()), L"ODBC.INI"); if (ret > BUFFER_SIZE) { // If there wasn't enough space, try again with the right size buffer. buf.resize(ret + 1); - ret = SQLGetPrivateProfileString(dsn.c_str(), NULL, "", buf.data(), - static_cast(buf.size()), "ODBC.INI"); + ret = SQLGetPrivateProfileString(wDsn.c_str(), NULL, L"", buf.data(), + static_cast(buf.size()), L"ODBC.INI"); } // When you pass NULL to SQLGetPrivateProfileString it gives back a \0 delimited list of // all the keys. The below loop simply tokenizes all the keys and places them into a // vector. std::vector keys; - char* begin = buf.data(); + wchar_t* begin = buf.data(); while (begin && *begin != '\0') { - char* cur; + wchar_t* cur; for (cur = begin; *cur != '\0'; ++cur) { } - keys.emplace_back(begin, cur); + + std::string key = arrow::util::WideStringToUTF8(std::wstring(begin, cur)).ValueOr(""); + keys.emplace_back(key); begin = ++cur; } return keys; @@ -142,11 +151,11 @@ void Configuration::LoadDsn(const std::string& dsn) { void Configuration::Clear() { this->properties.clear(); } bool Configuration::IsSet(const std::string_view& key) const { - return 0 != this->properties.count(key); + return 0 != this->properties.count(std::string(key)); } const std::string& Configuration::Get(const std::string_view& key) const { - const auto itr = this->properties.find(key); + const auto itr = this->properties.find(std::string(key)); if (itr == this->properties.cend()) { static const std::string empty(""); return empty; @@ -154,10 +163,23 @@ const std::string& Configuration::Get(const std::string_view& key) const { return itr->second; } +void Configuration::Set(const std::string_view& key, const std::wstring& wValue) { + std::string value = arrow::util::WideStringToUTF8(wValue).ValueOr(""); + Set(key, value); +} + void Configuration::Set(const std::string_view& key, const std::string& value) { const std::string copy = boost::trim_copy(value); if (!copy.empty()) { - this->properties[key] = value; + this->properties[std::string(key)] = value; + } +} + +void Configuration::Emplace(const std::string_view& key, std::string&& value) { + const std::string copy = boost::trim_copy(value); + if (!copy.empty()) { + this->properties.emplace( + std::make_pair(std::move(std::string(key)), std::move(value))); } } @@ -166,16 +188,15 @@ const driver::odbcabstraction::Connection::ConnPropertyMap& Configuration::GetPr return this->properties; } -std::vector Configuration::GetCustomKeys() const { +std::vector Configuration::GetCustomKeys() const { driver::odbcabstraction::Connection::ConnPropertyMap copyProps(properties); for (auto& key : FlightSqlConnection::ALL_KEYS) { - copyProps.erase(key); + copyProps.erase(std::string(key)); } - std::vector keys; + std::vector keys; boost::copy(copyProps | boost::adaptors::map_keys, std::back_inserter(keys)); return keys; } - } // namespace config } // namespace flight_sql } // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_auth_method.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_auth_method.cc index fcf951270e6..1b36f4916fe 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_auth_method.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_auth_method.cc @@ -45,6 +45,10 @@ class NoOpAuthMethod : public FlightSqlAuthMethod { void Authenticate(FlightSqlConnection& connection, FlightCallOptions& call_options) override { // Do nothing + + // TODO: implement NoOpAuthMethod to validate server address. + // Can use NoOpClientAuthHandler. + // https://github.com/apache/arrow/issues/46733 } }; @@ -54,8 +58,8 @@ class NoOpClientAuthHandler : public arrow::flight::ClientAuthHandler { arrow::Status Authenticate(arrow::flight::ClientAuthSender* outgoing, arrow::flight::ClientAuthReader* incoming) override { - // Write a blank string. The server should ignore this and just accept any Handshake - // request. + // The server should ignore this and just accept any Handshake + // request. Some servers do not allow authentication with no handshakes. return outgoing->Write(std::string()); } @@ -103,7 +107,9 @@ class UserPasswordAuthMethod : public FlightSqlAuthMethod { throw odbcabstraction::DriverException(bearer_result.status().message()); } - call_options.headers.push_back(bearer_result.ValueOrDie()); + // call_options may have already been populated with data from the connection string + // or DSN. Ensure auth-generated headers are placed at the front of the header list. + call_options.headers.insert(call_options.headers.begin(), bearer_result.ValueOrDie()); } std::string GetUser() override { return user_; } @@ -125,10 +131,11 @@ class TokenAuthMethod : public FlightSqlAuthMethod { void Authenticate(FlightSqlConnection& connection, FlightCallOptions& call_options) override { - // add the token to the headers + // add the token to the front of the headers. For consistency auth headers should be + // at the front. const std::pair token_header("authorization", "Bearer " + token_); - call_options.headers.push_back(token_header); + call_options.headers.insert(call_options.headers.begin(), token_header); const arrow::Status status = client_.Authenticate( call_options, @@ -153,22 +160,22 @@ std::unique_ptr FlightSqlAuthMethod::FromProperties( const std::unique_ptr& client, const Connection::ConnPropertyMap& properties) { // Check if should use user-password authentication - auto it_user = properties.find(FlightSqlConnection::USER); + auto it_user = properties.find(std::string(FlightSqlConnection::USER)); if (it_user == properties.end()) { // The Microsoft OLE DB to ODBC bridge provider (MSDASQL) will write // "User ID" and "Password" properties instead of mapping // to ODBC compliant UID/PWD keys. - it_user = properties.find(FlightSqlConnection::USER_ID); + it_user = properties.find(std::string(FlightSqlConnection::USER_ID)); } - auto it_password = properties.find(FlightSqlConnection::PASSWORD); - auto it_token = properties.find(FlightSqlConnection::TOKEN); + auto it_password = properties.find(std::string(FlightSqlConnection::PASSWORD)); + auto it_token = properties.find(std::string(FlightSqlConnection::TOKEN)); if (it_user == properties.end() || it_password == properties.end()) { // Accept UID/PWD as aliases for User/Password. These are suggested as // standard properties in the documentation for SQLDriverConnect. - it_user = properties.find(FlightSqlConnection::UID); - it_password = properties.find(FlightSqlConnection::PWD); + it_user = properties.find(std::string(FlightSqlConnection::UID)); + it_password = properties.find(std::string(FlightSqlConnection::PWD)); } if (it_user != properties.end() || it_password != properties.end()) { const std::string& user = it_user != properties.end() ? it_user->second : ""; diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection.cc index 09764e5c18b..c87c394fc31 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection.cc @@ -83,7 +83,7 @@ namespace { #if _WIN32 || _WIN64 constexpr auto SYSTEM_TRUST_STORE_DEFAULT = true; -constexpr auto STORES = {"CA", "MY", "ROOT", "SPC"}; +constexpr auto STORES = {L"CA", L"MY", L"ROOT", L"SPC"}; inline std::string GetCerts() { std::string certs; @@ -111,26 +111,28 @@ inline std::string GetCerts() { return ""; } #endif -const std::set - BUILT_IN_PROPERTIES = {FlightSqlConnection::HOST, - FlightSqlConnection::PORT, - FlightSqlConnection::USER, - FlightSqlConnection::USER_ID, - FlightSqlConnection::UID, - FlightSqlConnection::PASSWORD, - FlightSqlConnection::PWD, - FlightSqlConnection::TOKEN, - FlightSqlConnection::USE_ENCRYPTION, - FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION, - FlightSqlConnection::TRUSTED_CERTS, - FlightSqlConnection::USE_SYSTEM_TRUST_STORE, - FlightSqlConnection::STRING_COLUMN_LENGTH, - FlightSqlConnection::USE_WIDE_CHAR}; +const std::set BUILT_IN_PROPERTIES = { + FlightSqlConnection::DRIVER, + FlightSqlConnection::DSN, + FlightSqlConnection::HOST, + FlightSqlConnection::PORT, + FlightSqlConnection::USER, + FlightSqlConnection::USER_ID, + FlightSqlConnection::UID, + FlightSqlConnection::PASSWORD, + FlightSqlConnection::PWD, + FlightSqlConnection::TOKEN, + FlightSqlConnection::USE_ENCRYPTION, + FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION, + FlightSqlConnection::TRUSTED_CERTS, + FlightSqlConnection::USE_SYSTEM_TRUST_STORE, + FlightSqlConnection::STRING_COLUMN_LENGTH, + FlightSqlConnection::USE_WIDE_CHAR}; Connection::ConnPropertyMap::const_iterator TrackMissingRequiredProperty( const std::string_view& property, const Connection::ConnPropertyMap& properties, std::vector& missing_attr) { - auto prop_iter = properties.find(property); + auto prop_iter = properties.find(std::string(property)); if (properties.end() == prop_iter) { missing_attr.push_back(property); } @@ -149,7 +151,8 @@ std::shared_ptr LoadFlightSslConfigs( AsBool(connPropertyMap, FlightSqlConnection::USE_SYSTEM_TRUST_STORE) .value_or(SYSTEM_TRUST_STORE_DEFAULT); - auto trusted_certs_iterator = connPropertyMap.find(FlightSqlConnection::TRUSTED_CERTS); + auto trusted_certs_iterator = + connPropertyMap.find(std::string(FlightSqlConnection::TRUSTED_CERTS)); auto trusted_certs = trusted_certs_iterator != connPropertyMap.end() ? trusted_certs_iterator->second : ""; @@ -164,15 +167,18 @@ void FlightSqlConnection::Connect(const ConnPropertyMap& properties, auto flight_ssl_configs = LoadFlightSslConfigs(properties); Location location = BuildLocation(properties, missing_attr, flight_ssl_configs); - FlightClientOptions client_options = + client_options_ = BuildFlightClientOptions(properties, missing_attr, flight_ssl_configs); const std::shared_ptr& cookie_factory = arrow::flight::GetCookieFactory(); - client_options.middleware.push_back(cookie_factory); + client_options_.middleware.push_back(cookie_factory); std::unique_ptr flight_client; - ThrowIfNotOK(FlightClient::Connect(location, client_options).Value(&flight_client)); + ThrowIfNotOK(FlightClient::Connect(location, client_options_).Value(&flight_client)); + + PopulateMetadataSettings(properties); + PopulateCallOptions(properties); std::unique_ptr auth_method = FlightSqlAuthMethod::FromProperties(flight_client, properties); @@ -187,9 +193,6 @@ void FlightSqlConnection::Connect(const ConnPropertyMap& properties, info_.SetProperty(SQL_USER_NAME, auth_method->GetUser()); attribute_[CONNECTION_DEAD] = static_cast(SQL_FALSE); - - PopulateMetadataSettings(properties); - PopulateCallOptions(properties); } catch (...) { attribute_[CONNECTION_DEAD] = static_cast(SQL_TRUE); sql_client_.reset(); @@ -376,7 +379,7 @@ void FlightSqlConnection::Close() { std::shared_ptr FlightSqlConnection::CreateStatement() { return std::shared_ptr(new FlightSqlStatement( - diagnostics_, *sql_client_, call_options_, metadata_settings_)); + diagnostics_, *sql_client_, client_options_, call_options_, metadata_settings_)); } bool FlightSqlConnection::SetAttribute(Connection::AttributeId attribute, @@ -422,7 +425,7 @@ FlightSqlConnection::FlightSqlConnection(OdbcVersion odbc_version, const std::string& driver_version) : diagnostics_("Apache Arrow", "Flight SQL", odbc_version), odbc_version_(odbc_version), - info_(call_options_, sql_client_, driver_version), + info_(client_options_, call_options_, sql_client_, driver_version), closed_(true) { attribute_[CONNECTION_DEAD] = static_cast(SQL_TRUE); attribute_[LOGIN_TIMEOUT] = static_cast(0); diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection.h index 0ee6d5d5391..0a4b213229f 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection.h +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection.h @@ -29,6 +29,13 @@ namespace driver { namespace flight_sql { +/// \brief Case insensitive comparator that takes string_view +struct CaseInsensitiveComparatorStrView { + bool operator()(const std::string_view& s1, const std::string_view& s2) const { + return boost::lexicographical_compare(s1, s2, boost::is_iless()); + } +}; + class FlightSqlSslConfig; /// \brief Create an instance of the FlightSqlSslConfig class, from the properties passed diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection_test.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection_test.cc index 6a519138b63..a7a0fc10c29 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_connection_test.cc @@ -69,10 +69,12 @@ TEST(MetadataSettingsTest, StringColumnLengthTest) { const int32_t expected_string_column_length = 100000; const Connection::ConnPropertyMap properties = { - {FlightSqlConnection::HOST, std::string("localhost")}, // expect not used - {FlightSqlConnection::PORT, std::string("32010")}, // expect not used - {FlightSqlConnection::USE_ENCRYPTION, std::string("false")}, // expect not used - {FlightSqlConnection::STRING_COLUMN_LENGTH, + {std::string(FlightSqlConnection::HOST), + std::string("localhost")}, // expect not used + {std::string(FlightSqlConnection::PORT), std::string("32010")}, // expect not used + {std::string(FlightSqlConnection::USE_ENCRYPTION), + std::string("false")}, // expect not used + {std::string(FlightSqlConnection::STRING_COLUMN_LENGTH), std::to_string(expected_string_column_length)}, }; @@ -90,10 +92,10 @@ TEST(MetadataSettingsTest, UseWideCharTest) { connection.SetClosed(false); const Connection::ConnPropertyMap properties1 = { - {FlightSqlConnection::USE_WIDE_CHAR, std::string("true")}, + {std::string(FlightSqlConnection::USE_WIDE_CHAR), std::string("true")}, }; const Connection::ConnPropertyMap properties2 = { - {FlightSqlConnection::USE_WIDE_CHAR, std::string("false")}, + {std::string(FlightSqlConnection::USE_WIDE_CHAR), std::string("false")}, }; EXPECT_EQ(true, connection.GetUseWideChar(properties1)); @@ -105,9 +107,9 @@ TEST(MetadataSettingsTest, UseWideCharTest) { TEST(BuildLocationTests, ForTcp) { std::vector missing_attr; Connection::ConnPropertyMap properties = { - {FlightSqlConnection::HOST, std::string("localhost")}, - {FlightSqlConnection::PORT, std::string("32010")}, - {FlightSqlConnection::USE_ENCRYPTION, std::string("false")}, + {std::string(FlightSqlConnection::HOST), std::string("localhost")}, + {std::string(FlightSqlConnection::PORT), std::string("32010")}, + {std::string(FlightSqlConnection::USE_ENCRYPTION), std::string("false")}, }; const std::shared_ptr& ssl_config = @@ -117,8 +119,8 @@ TEST(BuildLocationTests, ForTcp) { FlightSqlConnection::BuildLocation(properties, missing_attr, ssl_config); const Location& actual_location2 = FlightSqlConnection::BuildLocation( { - {FlightSqlConnection::HOST, std::string("localhost")}, - {FlightSqlConnection::PORT, std::string("32011")}, + {std::string(FlightSqlConnection::HOST), std::string("localhost")}, + {std::string(FlightSqlConnection::PORT), std::string("32011")}, }, missing_attr, ssl_config); @@ -131,9 +133,9 @@ TEST(BuildLocationTests, ForTcp) { TEST(BuildLocationTests, ForTls) { std::vector missing_attr; Connection::ConnPropertyMap properties = { - {FlightSqlConnection::HOST, std::string("localhost")}, - {FlightSqlConnection::PORT, std::string("32010")}, - {FlightSqlConnection::USE_ENCRYPTION, std::string("1")}, + {std::string(FlightSqlConnection::HOST), std::string("localhost")}, + {std::string(FlightSqlConnection::PORT), std::string("32010")}, + {std::string(FlightSqlConnection::USE_ENCRYPTION), std::string("1")}, }; const std::shared_ptr& ssl_config = @@ -143,9 +145,9 @@ TEST(BuildLocationTests, ForTls) { FlightSqlConnection::BuildLocation(properties, missing_attr, ssl_config); Connection::ConnPropertyMap second_properties = { - {FlightSqlConnection::HOST, std::string("localhost")}, - {FlightSqlConnection::PORT, std::string("32011")}, - {FlightSqlConnection::USE_ENCRYPTION, std::string("1")}, + {std::string(FlightSqlConnection::HOST), std::string("localhost")}, + {std::string(FlightSqlConnection::PORT), std::string("32011")}, + {std::string(FlightSqlConnection::USE_ENCRYPTION), std::string("1")}, }; const std::shared_ptr& second_ssl_config = diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_driver.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_driver.cc index 1949d2f15ad..0736dac8486 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_driver.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_driver.cc @@ -16,13 +16,17 @@ // under the License. #include "arrow/flight/sql/odbc/flight_sql/include/flight_sql/flight_sql_driver.h" +#include "arrow/compute/api.h" #include "arrow/flight/sql/odbc/flight_sql/flight_sql_connection.h" +#include "arrow/flight/sql/odbc/flight_sql/utils.h" #include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h" #include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spd_logger.h" #include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/utils.h" +#include "arrow/util/io_util.h" #define DEFAULT_MAXIMUM_FILE_SIZE 16777216 #define CONFIG_FILE_NAME "arrow-odbc.ini" +#define CONFIG_FILE_PATH "CONFIG_FILE_PATH" namespace driver { namespace flight_sql { @@ -52,7 +56,11 @@ LogLevel ToLogLevel(int64_t level) { } // namespace FlightSqlDriver::FlightSqlDriver() - : diagnostics_("Apache Arrow", "Flight SQL", OdbcVersion::V_3), version_("0.9.0.0") {} + : diagnostics_("Apache Arrow", "Flight SQL", OdbcVersion::V_3), version_("0.9.0.0") { + RegisterLog(); + // Register Kernel functions to library + ThrowIfNotOK(arrow::compute::Initialize()); +} std::shared_ptr FlightSqlDriver::CreateConnection(OdbcVersion odbc_version) { return std::make_shared(odbc_version, version_); @@ -63,24 +71,29 @@ odbcabstraction::Diagnostics& FlightSqlDriver::GetDiagnostics() { return diagnos void FlightSqlDriver::SetVersion(std::string version) { version_ = std::move(version); } void FlightSqlDriver::RegisterLog() { + std::string config_path = arrow::internal::GetEnvVar(CONFIG_FILE_PATH).ValueOr(""); + if (config_path.empty()) { + return; + } + odbcabstraction::PropertyMap propertyMap; - driver::odbcabstraction::ReadConfigFile(propertyMap, CONFIG_FILE_NAME); + driver::odbcabstraction::ReadConfigFile(propertyMap, config_path, CONFIG_FILE_NAME); - auto log_enable_iterator = propertyMap.find(SPDLogger::LOG_ENABLED); + auto log_enable_iterator = propertyMap.find(std::string(SPDLogger::LOG_ENABLED)); auto log_enabled = log_enable_iterator != propertyMap.end() ? odbcabstraction::AsBool(log_enable_iterator->second) : false; - if (!log_enabled) { + if (!log_enabled.get()) { return; } - auto log_path_iterator = propertyMap.find(SPDLogger::LOG_PATH); + auto log_path_iterator = propertyMap.find(std::string(SPDLogger::LOG_PATH)); auto log_path = log_path_iterator != propertyMap.end() ? log_path_iterator->second : ""; if (log_path.empty()) { return; } - auto log_level_iterator = propertyMap.find(SPDLogger::LOG_LEVEL); + auto log_level_iterator = propertyMap.find(std::string(SPDLogger::LOG_LEVEL)); auto log_level = ToLogLevel(log_level_iterator != propertyMap.end() ? std::stoi(log_level_iterator->second) : 1); @@ -88,12 +101,14 @@ void FlightSqlDriver::RegisterLog() { return; } - auto maximum_file_size_iterator = propertyMap.find(SPDLogger::MAXIMUM_FILE_SIZE); + auto maximum_file_size_iterator = + propertyMap.find(std::string(SPDLogger::MAXIMUM_FILE_SIZE)); auto maximum_file_size = maximum_file_size_iterator != propertyMap.end() ? std::stoi(maximum_file_size_iterator->second) : DEFAULT_MAXIMUM_FILE_SIZE; - auto maximum_file_quantity_iterator = propertyMap.find(SPDLogger::FILE_QUANTITY); + auto maximum_file_quantity_iterator = + propertyMap.find(std::string(SPDLogger::FILE_QUANTITY)); auto maximum_file_quantity = maximum_file_quantity_iterator != propertyMap.end() ? std::stoi(maximum_file_quantity_iterator->second) : 1; diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_get_tables_reader.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_get_tables_reader.cc index ccd6058f8cd..b048d1984c5 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_get_tables_reader.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_get_tables_reader.cc @@ -80,9 +80,10 @@ std::shared_ptr GetTablesReader::GetSchema() { const arrow::Result>& result = arrow::ipc::ReadSchema(&dataset_schema_reader, &in_memo); if (!result.ok()) { - // TODO: Ignoring this error until we fix the problem on Dremio server - // The problem is that complex types columns are being returned without the children - // types. + // TODO: Test and build the driver against a server that returns + // complex types columns with the children + // types and handle the failure properly + // https://github.com/apache/arrow/issues/46561 return nullptr; } diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set.cc index 824260a6868..258c810996a 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set.cc @@ -18,6 +18,8 @@ #include "arrow/flight/sql/odbc/flight_sql/flight_sql_result_set.h" #include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h" +#include + #include #include "arrow/flight/types.h" #include "arrow/scalar.h" @@ -42,13 +44,14 @@ using odbcabstraction::DriverException; FlightSqlResultSet::FlightSqlResultSet( FlightSqlClient& flight_sql_client, + const arrow::flight::FlightClientOptions& client_options, const arrow::flight::FlightCallOptions& call_options, const std::shared_ptr& flight_info, const std::shared_ptr& transformer, odbcabstraction::Diagnostics& diagnostics, const odbcabstraction::MetadataSettings& metadata_settings) : metadata_settings_(metadata_settings), - chunk_buffer_(flight_sql_client, call_options, flight_info, + chunk_buffer_(flight_sql_client, client_options, call_options, flight_info, metadata_settings_.chunk_buffer_capacity_), transformer_(transformer), metadata_(transformer @@ -226,14 +229,14 @@ void FlightSqlResultSet::Cancel() { current_chunk_.data = nullptr; } -bool FlightSqlResultSet::GetData(int column_n, int16_t target_type, int precision, - int scale, void* buffer, size_t buffer_length, - ssize_t* strlen_buffer) { +SQLRETURN FlightSqlResultSet::GetData(int column_n, int16_t target_type, int precision, + int scale, void* buffer, size_t buffer_length, + ssize_t* strlen_buffer) { reset_get_data_ = true; // Check if the offset is already at the end. int64_t& value_offset = get_data_offsets_[column_n - 1]; if (value_offset == -1) { - return false; + return SQL_NO_DATA; } ColumnBinding binding(ConvertCDataTypeFromV2ToV3(target_type), precision, scale, buffer, @@ -249,7 +252,11 @@ bool FlightSqlResultSet::GetData(int column_n, int16_t target_type, int precisio diagnostics_, nullptr); // If there was truncation, the converter would have reported it to the diagnostics. - return diagnostics_.HasWarning(); + if (diagnostics_.HasWarning()) { + return SQL_SUCCESS_WITH_INFO; + } else { + return SQL_SUCCESS; + } } std::shared_ptr FlightSqlResultSet::GetMetadata() { return metadata_; } diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set.h index d1f20979a24..5a03b16f066 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set.h +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set.h @@ -63,6 +63,7 @@ class FlightSqlResultSet : public ResultSet { ~FlightSqlResultSet() override; FlightSqlResultSet(FlightSqlClient& flight_sql_client, + const arrow::flight::FlightClientOptions& client_options, const arrow::flight::FlightCallOptions& call_options, const std::shared_ptr& flight_info, const std::shared_ptr& transformer, @@ -73,8 +74,8 @@ class FlightSqlResultSet : public ResultSet { void Cancel() override; - bool GetData(int column_n, int16_t target_type, int precision, int scale, void* buffer, - size_t buffer_length, ssize_t* strlen_buffer) override; + SQLRETURN GetData(int column_n, int16_t target_type, int precision, int scale, + void* buffer, size_t buffer_length, ssize_t* strlen_buffer) override; size_t Move(size_t rows, size_t bind_offset, size_t bind_type, uint16_t* row_status_array) override; diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_accessors.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_accessors.h index 3f7d6856083..1d5014140ef 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_accessors.h +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_accessors.h @@ -17,9 +17,9 @@ #pragma once -#include -#include #include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/types.h" +#include "arrow/type_fwd.h" namespace driver { namespace flight_sql { diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_metadata.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_metadata.cc index f863d4bc489..0fa6b03c4a7 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_metadata.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_metadata.cc @@ -16,10 +16,10 @@ // under the License. #include "arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_metadata.h" -#include -#include +#include "arrow/flight/sql/column_metadata.h" #include "arrow/flight/sql/odbc/flight_sql/utils.h" #include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h" +#include "arrow/util/key_value_metadata.h" #include #include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/exceptions.h" @@ -42,15 +42,9 @@ constexpr int32_t DefaultDecimalPrecision = 38; constexpr int32_t DefaultLengthForVariableLengthColumns = 1024; namespace { -std::shared_ptr empty_metadata_map( - new arrow::KeyValueMetadata); - inline arrow::flight::sql::ColumnMetadata GetMetadata( const std::shared_ptr& field) { - const auto& metadata_map = field->metadata(); - - arrow::flight::sql::ColumnMetadata metadata(metadata_map ? metadata_map - : empty_metadata_map); + arrow::flight::sql::ColumnMetadata metadata(field->metadata()); return metadata; } @@ -260,18 +254,29 @@ bool FlightSqlResultSetMetadata::IsUnsigned(int column_position) { const std::shared_ptr& field = schema_->field(column_position - 1); switch (field->type()->id()) { + case arrow::Type::INT8: + case arrow::Type::INT16: + case arrow::Type::INT32: + case arrow::Type::INT64: + case arrow::Type::DOUBLE: + case arrow::Type::FLOAT: + case arrow::Type::HALF_FLOAT: + case arrow::Type::DECIMAL32: + case arrow::Type::DECIMAL64: + case arrow::Type::DECIMAL128: + case arrow::Type::DECIMAL256: + return false; case arrow::Type::UINT8: case arrow::Type::UINT16: case arrow::Type::UINT32: case arrow::Type::UINT64: - return true; default: - return false; + return true; } } bool FlightSqlResultSetMetadata::IsFixedPrecScale(int column_position) { - // TODO: Flight SQL column metadata does not have this, should we add to the spec? + // Precision for Arrow data types are modifiable by the user return false; } diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_metadata.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_metadata.h index f8e78eb2d6d..29901652c52 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_metadata.h +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_result_set_metadata.h @@ -89,6 +89,7 @@ class FlightSqlResultSetMetadata : public odbcabstraction::ResultSetMetadata { odbcabstraction::Searchability IsSearchable(int column_position) override; + /// \brief Returns true if the column is unsigned (not numeric) bool IsUnsigned(int column_position) override; bool IsFixedPrecScale(int column_position) override; diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_ssl_config.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_ssl_config.h index 76a54f13ce1..2369f0aab4d 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_ssl_config.h +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_ssl_config.h @@ -17,9 +17,9 @@ #pragma once -#include -#include #include +#include "arrow/flight/types.h" +#include "arrow/status.h" namespace driver { namespace flight_sql { diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement.cc index 1e8498ad7e3..efe333d836a 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement.cc @@ -52,9 +52,10 @@ using driver::odbcabstraction::Statement; namespace { void ClosePreparedStatementIfAny( - std::shared_ptr& prepared_statement) { + std::shared_ptr& prepared_statement, + const FlightCallOptions& options) { if (prepared_statement != nullptr) { - ThrowIfNotOK(prepared_statement->Close()); + ThrowIfNotOK(prepared_statement->Close(options)); prepared_statement.reset(); } } @@ -63,11 +64,12 @@ void ClosePreparedStatementIfAny( FlightSqlStatement::FlightSqlStatement( const odbcabstraction::Diagnostics& diagnostics, FlightSqlClient& sql_client, - FlightCallOptions call_options, + arrow::flight::FlightClientOptions client_options, FlightCallOptions call_options, const odbcabstraction::MetadataSettings& metadata_settings) : diagnostics_("Apache Arrow", diagnostics.GetDataSourceComponent(), diagnostics.GetOdbcVersion()), sql_client_(sql_client), + client_options_(std::move(client_options)), call_options_(std::move(call_options)), metadata_settings_(metadata_settings) { attribute_[METADATA_ID] = static_cast(SQL_FALSE); @@ -77,6 +79,10 @@ FlightSqlStatement::FlightSqlStatement( call_options_.timeout = TimeoutDuration{-1}; } +FlightSqlStatement::~FlightSqlStatement() { + ClosePreparedStatementIfAny(prepared_statement_, call_options_); +} + bool FlightSqlStatement::SetAttribute(StatementAttributeId attribute, const Attribute& value) { switch (attribute) { @@ -108,7 +114,7 @@ boost::optional FlightSqlStatement::GetAttribute( boost::optional> FlightSqlStatement::Prepare( const std::string& query) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.Prepare(call_options_, query); @@ -124,25 +130,27 @@ boost::optional> FlightSqlStatement::Prepare( bool FlightSqlStatement::ExecutePrepared() { assert(prepared_statement_.get() != nullptr); - Result> result = prepared_statement_->Execute(); + Result> result = + prepared_statement_->Execute(call_options_); + ThrowIfNotOK(result.status()); current_result_set_ = std::make_shared( - sql_client_, call_options_, result.ValueOrDie(), nullptr, diagnostics_, - metadata_settings_); + sql_client_, client_options_, call_options_, result.ValueOrDie(), nullptr, + diagnostics_, metadata_settings_); return true; } bool FlightSqlStatement::Execute(const std::string& query) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.Execute(call_options_, query); ThrowIfNotOK(result.status()); current_result_set_ = std::make_shared( - sql_client_, call_options_, result.ValueOrDie(), nullptr, diagnostics_, - metadata_settings_); + sql_client_, client_options_, call_options_, result.ValueOrDie(), nullptr, + diagnostics_, metadata_settings_); return true; } @@ -157,33 +165,35 @@ std::shared_ptr FlightSqlStatement::GetTables( const std::string* catalog_name, const std::string* schema_name, const std::string* table_name, const std::string* table_type, const ColumnNames& column_names) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); std::vector table_types; if ((catalog_name && *catalog_name == "%") && (schema_name && schema_name->empty()) && (table_name && table_name->empty())) { - current_result_set_ = GetTablesForSQLAllCatalogs( - column_names, call_options_, sql_client_, diagnostics_, metadata_settings_); + current_result_set_ = + GetTablesForSQLAllCatalogs(column_names, client_options_, call_options_, + sql_client_, diagnostics_, metadata_settings_); } else if ((catalog_name && catalog_name->empty()) && (schema_name && *schema_name == "%") && (table_name && table_name->empty())) { - current_result_set_ = - GetTablesForSQLAllDbSchemas(column_names, call_options_, sql_client_, schema_name, - diagnostics_, metadata_settings_); + current_result_set_ = GetTablesForSQLAllDbSchemas( + column_names, client_options_, call_options_, sql_client_, schema_name, + diagnostics_, metadata_settings_); } else if ((catalog_name && catalog_name->empty()) && (schema_name && schema_name->empty()) && (table_name && table_name->empty()) && (table_type && *table_type == "%")) { - current_result_set_ = GetTablesForSQLAllTableTypes( - column_names, call_options_, sql_client_, diagnostics_, metadata_settings_); + current_result_set_ = + GetTablesForSQLAllTableTypes(column_names, client_options_, call_options_, + sql_client_, diagnostics_, metadata_settings_); } else { if (table_type) { ParseTableTypes(*table_type, table_types); } current_result_set_ = GetTablesForGenericUse( - column_names, call_options_, sql_client_, catalog_name, schema_name, table_name, - table_types, diagnostics_, metadata_settings_); + column_names, client_options_, call_options_, sql_client_, catalog_name, + schema_name, table_name, table_types, diagnostics_, metadata_settings_); } return current_result_set_; @@ -210,7 +220,7 @@ std::shared_ptr FlightSqlStatement::GetTables_V3( std::shared_ptr FlightSqlStatement::GetColumns_V2( const std::string* catalog_name, const std::string* schema_name, const std::string* table_name, const std::string* column_name) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.GetTables( call_options_, catalog_name, schema_name, table_name, true, nullptr); @@ -221,9 +231,9 @@ std::shared_ptr FlightSqlStatement::GetColumns_V2( auto transformer = std::make_shared( metadata_settings_, odbcabstraction::V_2, column_name); - current_result_set_ = - std::make_shared(sql_client_, call_options_, flight_info, - transformer, diagnostics_, metadata_settings_); + current_result_set_ = std::make_shared( + sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_, + metadata_settings_); return current_result_set_; } @@ -231,7 +241,7 @@ std::shared_ptr FlightSqlStatement::GetColumns_V2( std::shared_ptr FlightSqlStatement::GetColumns_V3( const std::string* catalog_name, const std::string* schema_name, const std::string* table_name, const std::string* column_name) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.GetTables( call_options_, catalog_name, schema_name, table_name, true, nullptr); @@ -242,15 +252,15 @@ std::shared_ptr FlightSqlStatement::GetColumns_V3( auto transformer = std::make_shared( metadata_settings_, odbcabstraction::V_3, column_name); - current_result_set_ = - std::make_shared(sql_client_, call_options_, flight_info, - transformer, diagnostics_, metadata_settings_); + current_result_set_ = std::make_shared( + sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_, + metadata_settings_); return current_result_set_; } std::shared_ptr FlightSqlStatement::GetTypeInfo_V2(int16_t data_type) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.GetXdbcTypeInfo(call_options_); ThrowIfNotOK(result.status()); @@ -260,15 +270,15 @@ std::shared_ptr FlightSqlStatement::GetTypeInfo_V2(int16_t data_type) auto transformer = std::make_shared( metadata_settings_, odbcabstraction::V_2, data_type); - current_result_set_ = - std::make_shared(sql_client_, call_options_, flight_info, - transformer, diagnostics_, metadata_settings_); + current_result_set_ = std::make_shared( + sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_, + metadata_settings_); return current_result_set_; } std::shared_ptr FlightSqlStatement::GetTypeInfo_V3(int16_t data_type) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.GetXdbcTypeInfo(call_options_); ThrowIfNotOK(result.status()); @@ -278,9 +288,9 @@ std::shared_ptr FlightSqlStatement::GetTypeInfo_V3(int16_t data_type) auto transformer = std::make_shared( metadata_settings_, odbcabstraction::V_3, data_type); - current_result_set_ = - std::make_shared(sql_client_, call_options_, flight_info, - transformer, diagnostics_, metadata_settings_); + current_result_set_ = std::make_shared( + sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_, + metadata_settings_); return current_result_set_; } diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement.h index 7ffb02ba40b..00fe9137f51 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement.h +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement.h @@ -33,6 +33,7 @@ class FlightSqlStatement : public odbcabstraction::Statement { private: odbcabstraction::Diagnostics diagnostics_; std::map attribute_; + arrow::flight::FlightClientOptions client_options_; arrow::flight::FlightCallOptions call_options_; arrow::flight::sql::FlightSqlClient& sql_client_; std::shared_ptr current_result_set_; @@ -48,8 +49,10 @@ class FlightSqlStatement : public odbcabstraction::Statement { public: FlightSqlStatement(const odbcabstraction::Diagnostics& diagnostics, arrow::flight::sql::FlightSqlClient& sql_client, + arrow::flight::FlightClientOptions client_options, arrow::flight::FlightCallOptions call_options, const odbcabstraction::MetadataSettings& metadata_settings); + ~FlightSqlStatement(); bool SetAttribute(StatementAttributeId attribute, const Attribute& value) override; diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_columns.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_columns.cc index 0e250d1af9b..f61c198cd23 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_columns.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_columns.cc @@ -98,10 +98,10 @@ Result> Transform_inner( const auto& table_name = reader.GetTableName(); const std::shared_ptr& schema = reader.GetSchema(); if (schema == nullptr) { - // TODO: Remove this if after fixing TODO on GetTablesReader::GetSchema() - // This is because of a problem on Dremio server, where complex types columns - // are being returned without the children types, so we are simply ignoring - // it by now. + // TODO: Test and build the driver against a server that returns + // complex types columns with the children + // types and handle the failure properly. + // https://github.com/apache/arrow/issues/46561 continue; } for (int i = 0; i < schema->num_fields(); ++i) { @@ -125,8 +125,8 @@ Result> Transform_inner( ? data_type_v3 : ConvertSqlDataTypeFromV3ToV2(data_type_v3); - // TODO: Use `metadata.GetTypeName()` when ARROW-16064 is merged. - const auto& type_name_result = field->metadata()->Get("ARROW:FLIGHT:SQL:TYPE_NAME"); + const auto& type_name_result = metadata.GetTypeName(); + data.type_name = type_name_result.ok() ? type_name_result.ValueOrDie() : GetTypeNameFromSqlDataType(data_type_v3); diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_tables.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_tables.cc index a3cdf9768d2..7dfedac6392 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_tables.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_tables.cc @@ -22,6 +22,7 @@ #include "arrow/flight/sql/odbc/flight_sql/utils.h" #include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h" #include "arrow/flight/types.h" +#include "arrow/util/string.h" namespace driver { namespace flight_sql { @@ -31,6 +32,15 @@ using arrow::flight::FlightClientOptions; using arrow::flight::FlightInfo; using arrow::flight::sql::FlightSqlClient; +static void AddTableType(std::string& table_type, std::vector& table_types) { + std::string trimmed_type = arrow::internal::TrimString(table_type); + + // Only put the string if the trimmed result is non-empty + if (trimmed_type.length() > 0) { + table_types.emplace_back(trimmed_type); + } +} + void ParseTableTypes(const std::string& table_type, std::vector& table_types) { bool encountered = false; // for checking if there is a single quote @@ -39,41 +49,29 @@ void ParseTableTypes(const std::string& table_type, for (char temp : table_type) { // while still in the string switch (temp) { // switch depending on the character case '\'': // if the character is a single quote - if (encountered) { - encountered = false; // if we already found a single quote, reset encountered - } else { - encountered = - true; // if we haven't found a single quote, set encountered to true - } + // track when we've encountered a single opening quote + // and are still looking for the closing quote + encountered = !encountered; break; - case ',': // if it is a comma - if (!encountered) { // if we have not found a single quote - table_types.push_back(curr_parse); // put our current string into our vector - curr_parse = ""; // reset the current string + case ',': // if it is a comma + if (!encountered) { // if we have not found a single quote + AddTableType(curr_parse, table_types); // put current string into vector + curr_parse = ""; // reset the current string break; } - default: // if it is a normal character - if (encountered && isspace(temp)) { - curr_parse.push_back(temp); // if we have found a single quote put the - // whitespace, we don't care - } else if (temp == '\'' || temp == ' ') { - break; // if the current character is a single quote, trash it and go to - // the next character. - } else { - curr_parse.push_back(temp); // if all of the above failed, put the - // character into the current string - } - break; // go to the next character + [[fallthrough]]; + default: // if it is a normal character + curr_parse.push_back(temp); // put the character into the current string + break; // go to the next character } } - table_types.emplace_back( - curr_parse); // if we have found a single quote put the whitespace, - // we don't care + AddTableType(curr_parse, table_types); } std::shared_ptr GetTablesForSQLAllCatalogs( - const ColumnNames& names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, odbcabstraction::Diagnostics& diagnostics, + const ColumnNames& names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + odbcabstraction::Diagnostics& diagnostics, const odbcabstraction::MetadataSettings& metadata_settings) { Result> result = sql_client.GetCatalogs(call_options); @@ -92,14 +90,15 @@ std::shared_ptr GetTablesForSQLAllCatalogs( .AddFieldOfNulls(names.remarks_column, arrow::utf8()) .Build(); - return std::make_shared( - sql_client, call_options, flight_info, transformer, diagnostics, metadata_settings); + return std::make_shared(sql_client, client_options, call_options, + flight_info, transformer, diagnostics, + metadata_settings); } std::shared_ptr GetTablesForSQLAllDbSchemas( - const ColumnNames& names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, const std::string* schema_name, - odbcabstraction::Diagnostics& diagnostics, + const ColumnNames& names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + const std::string* schema_name, odbcabstraction::Diagnostics& diagnostics, const odbcabstraction::MetadataSettings& metadata_settings) { Result> result = sql_client.GetDbSchemas(call_options, nullptr, schema_name); @@ -119,13 +118,15 @@ std::shared_ptr GetTablesForSQLAllDbSchemas( .AddFieldOfNulls(names.remarks_column, arrow::utf8()) .Build(); - return std::make_shared( - sql_client, call_options, flight_info, transformer, diagnostics, metadata_settings); + return std::make_shared(sql_client, client_options, call_options, + flight_info, transformer, diagnostics, + metadata_settings); } std::shared_ptr GetTablesForSQLAllTableTypes( - const ColumnNames& names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, odbcabstraction::Diagnostics& diagnostics, + const ColumnNames& names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + odbcabstraction::Diagnostics& diagnostics, const odbcabstraction::MetadataSettings& metadata_settings) { Result> result = sql_client.GetTableTypes(call_options); @@ -144,15 +145,16 @@ std::shared_ptr GetTablesForSQLAllTableTypes( .AddFieldOfNulls(names.remarks_column, arrow::utf8()) .Build(); - return std::make_shared( - sql_client, call_options, flight_info, transformer, diagnostics, metadata_settings); + return std::make_shared(sql_client, client_options, call_options, + flight_info, transformer, diagnostics, + metadata_settings); } std::shared_ptr GetTablesForGenericUse( - const ColumnNames& names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, const std::string* catalog_name, - const std::string* schema_name, const std::string* table_name, - const std::vector& table_types, + const ColumnNames& names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + const std::string* catalog_name, const std::string* schema_name, + const std::string* table_name, const std::vector& table_types, odbcabstraction::Diagnostics& diagnostics, const odbcabstraction::MetadataSettings& metadata_settings) { Result> result = sql_client.GetTables( @@ -173,8 +175,9 @@ std::shared_ptr GetTablesForGenericUse( .AddFieldOfNulls(names.remarks_column, arrow::utf8()) .Build(); - return std::make_shared( - sql_client, call_options, flight_info, transformer, diagnostics, metadata_settings); + return std::make_shared(sql_client, client_options, call_options, + flight_info, transformer, diagnostics, + metadata_settings); } } // namespace flight_sql diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_tables.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_tables.h index 8f0dc5fef6d..0f5ac461f3f 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_tables.h +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_tables.h @@ -30,6 +30,7 @@ namespace driver { namespace flight_sql { using arrow::flight::FlightCallOptions; +using arrow::flight::FlightClientOptions; using arrow::flight::sql::FlightSqlClient; using odbcabstraction::MetadataSettings; using odbcabstraction::ResultSet; @@ -46,26 +47,28 @@ void ParseTableTypes(const std::string& table_type, std::vector& table_types); std::shared_ptr GetTablesForSQLAllCatalogs( - const ColumnNames& column_names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, odbcabstraction::Diagnostics& diagnostics, + const ColumnNames& column_names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + odbcabstraction::Diagnostics& diagnostics, const odbcabstraction::MetadataSettings& metadata_settings); std::shared_ptr GetTablesForSQLAllDbSchemas( - const ColumnNames& column_names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, const std::string* schema_name, - odbcabstraction::Diagnostics& diagnostics, + const ColumnNames& column_names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + const std::string* schema_name, odbcabstraction::Diagnostics& diagnostics, const odbcabstraction::MetadataSettings& metadata_settings); std::shared_ptr GetTablesForSQLAllTableTypes( - const ColumnNames& column_names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, odbcabstraction::Diagnostics& diagnostics, + const ColumnNames& column_names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + odbcabstraction::Diagnostics& diagnostics, const odbcabstraction::MetadataSettings& metadata_settings); std::shared_ptr GetTablesForGenericUse( - const ColumnNames& column_names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, const std::string* catalog_name, - const std::string* schema_name, const std::string* table_name, - const std::vector& table_types, + const ColumnNames& column_names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + const std::string* catalog_name, const std::string* schema_name, + const std::string* table_name, const std::vector& table_types, odbcabstraction::Diagnostics& diagnostics, const odbcabstraction::MetadataSettings& metadata_settings); } // namespace flight_sql diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_type_info.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_type_info.cc index 3faf607c5f2..3f59f4abcc1 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_type_info.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_statement_get_type_info.cc @@ -105,7 +105,7 @@ Result> Transform_inner( data.literal_suffix = reader.GetLiteralSuffix(); const auto& create_params = reader.GetCreateParams(); - if (create_params) { + if (create_params && !create_params->empty()) { data.create_params = boost::algorithm::join(*create_params, ","); } else { data.create_params = nullopt; @@ -114,6 +114,8 @@ Result> Transform_inner( data.nullable = reader.GetNullable() ? odbcabstraction::NULLABILITY_NULLABLE : odbcabstraction::NULLABILITY_NO_NULLS; data.case_sensitive = reader.GetCaseSensitive(); + // GH-47237 return SEARCHABILITY_LIKE_ONLY and SEARCHABILITY_ALL_EXPECT_LIKE for + // appropriate data types data.searchable = reader.GetSearchable() ? odbcabstraction::SEARCHABILITY_ALL : odbcabstraction::SEARCHABILITY_NONE; data.unsigned_attribute = reader.GetUnsignedAttribute(); @@ -122,9 +124,9 @@ Result> Transform_inner( data.local_type_name = reader.GetLocalTypeName(); data.minimum_scale = reader.GetMinimumScale(); data.maximum_scale = reader.GetMaximumScale(); - data.sql_data_type = EnsureRightSqlCharType( + data.sql_data_type = GetNonConciseDataType(EnsureRightSqlCharType( static_cast(reader.GetSqlDataType()), - metadata_settings_.use_wide_char_); + metadata_settings_.use_wide_char_)); data.sql_datetime_sub = GetSqlDateTimeSubCode(static_cast(data.data_type)); data.num_prec_radix = reader.GetNumPrecRadix(); diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_stream_chunk_buffer.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_stream_chunk_buffer.cc index 093a46dfe83..7da2d6ca89d 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_stream_chunk_buffer.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_stream_chunk_buffer.cc @@ -21,38 +21,71 @@ namespace driver { namespace flight_sql { +using arrow::flight::FlightClient; using arrow::flight::FlightEndpoint; FlightStreamChunkBuffer::FlightStreamChunkBuffer( FlightSqlClient& flight_sql_client, + const arrow::flight::FlightClientOptions& client_options, const arrow::flight::FlightCallOptions& call_options, const std::shared_ptr& flight_info, size_t queue_capacity) : queue_(queue_capacity) { - // FIXME: Endpoint iteration should consider endpoints may be at different hosts for (const auto& endpoint : flight_info->endpoints()) { const arrow::flight::Ticket& ticket = endpoint.ticket; - auto result = flight_sql_client.DoGet(call_options, ticket); + arrow::Result> result; + std::shared_ptr temp_flight_sql_client; + auto endpoint_locations = endpoint.locations; + if (endpoint_locations.empty()) { + // list of Locations needs to be empty to proceed + result = flight_sql_client.DoGet(call_options, ticket); + } else { + // If it is non-empty, the driver should create a FlightSqlClient to connect to one + // of the specified Locations directly. + + // GH-47117: Currently a new FlightClient will be made for each partition that + // returns a non-empty Location, which is then disposed of. It may be better to + // cache clients because a server may report the same Locations. It would also be + // good to identify when the reported Location is the same as the original + // connection's Location and skip creating a FlightClient in that scenario. + + std::unique_ptr temp_flight_client; + ThrowIfNotOK(FlightClient::Connect(endpoint_locations[0], client_options) + .Value(&temp_flight_client)); + temp_flight_sql_client.reset(new FlightSqlClient(std::move(temp_flight_client))); + + result = temp_flight_sql_client->DoGet(call_options, ticket); + } + ThrowIfNotOK(result.status()); std::shared_ptr stream_reader_ptr(std::move(result.ValueOrDie())); - BlockingQueue>::Supplier supplier = [=] { + BlockingQueue, + std::shared_ptr>>::Supplier supplier = [=] { auto result = stream_reader_ptr->Next(); bool isNotOk = !result.ok(); bool isNotEmpty = result.ok() && (result.ValueOrDie().data != nullptr); - return boost::make_optional(isNotOk || isNotEmpty, std::move(result)); + // If result is valid, save the temp Flight SQL Client for future stream reader + // call. temp_flight_sql_client is intentionally null if the list of endpoint + // Locations is empty. + // After all data is fetched from reader, the temp client is closed. + return boost::make_optional( + isNotOk || isNotEmpty, + std::make_pair(std::move(result), temp_flight_sql_client)); }; queue_.AddProducer(std::move(supplier)); } } bool FlightStreamChunkBuffer::GetNext(FlightStreamChunk* chunk) { - Result result; - if (!queue_.Pop(&result)) { + std::pair, std::shared_ptr> + closeableEndpointStreamPair; + if (!queue_.Pop(&closeableEndpointStreamPair)) { return false; } + Result result = closeableEndpointStreamPair.first; if (!result.status().ok()) { Close(); throw odbcabstraction::DriverException(result.status().message()); diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_stream_chunk_buffer.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_stream_chunk_buffer.h index 4a84bcbede0..5d5616a4f02 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_stream_chunk_buffer.h +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_stream_chunk_buffer.h @@ -17,9 +17,9 @@ #pragma once -#include -#include -#include +#include "arrow/flight/client.h" +#include "arrow/flight/sql/client.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/blocking_queue.h" namespace driver { namespace flight_sql { @@ -32,10 +32,12 @@ using arrow::flight::sql::FlightSqlClient; using driver::odbcabstraction::BlockingQueue; class FlightStreamChunkBuffer { - BlockingQueue> queue_; + BlockingQueue, std::shared_ptr>> + queue_; public: FlightStreamChunkBuffer(FlightSqlClient& flight_sql_client, + const arrow::flight::FlightClientOptions& client_options, const arrow::flight::FlightCallOptions& call_options, const std::shared_ptr& flight_info, size_t queue_capacity = 5); diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_stream_chunk_buffer_test.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_stream_chunk_buffer_test.cc new file mode 100644 index 00000000000..6857b53f5c2 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_stream_chunk_buffer_test.cc @@ -0,0 +1,137 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/array.h" + +#include "arrow/testing/gtest_util.h" + +#include "arrow/flight/sql/odbc/flight_sql/flight_sql_stream_chunk_buffer.h" +#include "arrow/flight/sql/odbc/flight_sql/json_converter.h" +#include "arrow/flight/test_flight_server.h" +#include "arrow/flight/test_util.h" +#include "gtest/gtest.h" + +namespace driver { +namespace flight_sql { + +using arrow::Array; +using arrow::flight::FlightCallOptions; +using arrow::flight::FlightClientOptions; +using arrow::flight::FlightDescriptor; +using arrow::flight::FlightEndpoint; +using arrow::flight::Location; +using arrow::flight::Ticket; +using arrow::flight::sql::FlightSqlClient; + +class FlightStreamChunkBufferTest : public ::testing::Test { + // Sets up two mock servers for each test case. + // This is for testing endpoint iteration only. + + protected: + void SetUp() override { + // Set up server 1 + server1 = std::make_shared(); + ASSERT_OK_AND_ASSIGN(auto location1, Location::ForGrpcTcp("0.0.0.0", 0)); + arrow::flight::FlightServerOptions options1(location1); + ASSERT_OK(server1->Init(options1)); + ASSERT_OK_AND_ASSIGN(server_location1, + Location::ForGrpcTcp("localhost", server1->port())); + + // Set up server 2 + server2 = std::make_shared(); + ASSERT_OK_AND_ASSIGN(auto location2, Location::ForGrpcTcp("0.0.0.0", 0)); + arrow::flight::FlightServerOptions options2(location2); + ASSERT_OK(server2->Init(options2)); + ASSERT_OK_AND_ASSIGN(server_location2, + Location::ForGrpcTcp("localhost", server2->port())); + + // Make SQL Client that is connected to server 1 + ASSERT_OK_AND_ASSIGN(auto client, arrow::flight::FlightClient::Connect(location1)); + sql_client.reset(new FlightSqlClient(std::move(client))); + } + + void TearDown() override { + ASSERT_OK(server1->Shutdown()); + ASSERT_OK(server2->Shutdown()); + } + + public: + arrow::flight::Location server_location1; + std::shared_ptr server1; + arrow::flight::Location server_location2; + std::shared_ptr server2; + std::shared_ptr sql_client; +}; + +FlightInfo MultipleEndpointsFlightInfo(Location location1, Location location2) { + // Sever will generate random data for `ticket-ints-1` + FlightEndpoint endpoint1({Ticket{"ticket-ints-1"}, {location1}, std::nullopt, {}}); + FlightEndpoint endpoint2({Ticket{"ticket-ints-1"}, {location2}, std::nullopt, {}}); + + FlightDescriptor descr1{FlightDescriptor::PATH, "", {"examples", "ints"}}; + + auto schema1 = arrow::flight::ExampleIntSchema(); + + return arrow::flight::MakeFlightInfo(*schema1, descr1, {endpoint1, endpoint2}, 1000, + 100000, false, ""); +} + +void verifyArraysContainIntsOnly(std::shared_ptr intArray) { + for (int64_t i = 0; i < intArray->length(); ++i) { + // null values are accepted + if (!intArray->IsNull(i)) { + auto scalar_data = intArray->GetScalar(i).ValueOrDie(); + std::string scalar_str = ConvertToJson(*scalar_data); + ASSERT_TRUE(std::all_of(scalar_str.begin(), scalar_str.end(), ::isdigit)); + } + } +} + +TEST_F(FlightStreamChunkBufferTest, TestMultipleEndpointsInt) { + FlightClientOptions client_options = FlightClientOptions::Defaults(); + FlightCallOptions options; + FlightInfo info = MultipleEndpointsFlightInfo(server_location1, server_location2); + std::shared_ptr info_ptr = std::make_shared(info); + + FlightStreamChunkBuffer chunk_buffer(*sql_client, client_options, options, info_ptr); + + FlightStreamChunk current_chunk; + + // Server returns 5 batch of results from each endpoints. + // Each batch contains 8 columns + int num_chunks = 0; + while (chunk_buffer.GetNext(¤t_chunk)) { + num_chunks++; + + int num_cols = current_chunk.data->num_columns(); + EXPECT_EQ(num_cols, 8); + + for (int i = 0; i < num_cols; i++) { + auto array = current_chunk.data->column(i); + // Each array has random length + EXPECT_GT(array->length(), 0); + + verifyArraysContainIntsOnly(array); + } + } + + // Verify 5 batches of data is returned by each of the two endpoints. + // In total 10 batches should be returned. + EXPECT_EQ(num_chunks, 10); +} +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/get_info_cache.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/get_info_cache.cc index d18322badbe..1e4cdbeb65d 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/get_info_cache.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/get_info_cache.cc @@ -204,15 +204,20 @@ inline void SetDefaultIfMissing( namespace driver { namespace flight_sql { using arrow::flight::FlightCallOptions; +using arrow::flight::FlightClientOptions; using arrow::flight::sql::FlightSqlClient; using arrow::flight::sql::SqlInfoOptions; using driver::odbcabstraction::Connection; using driver::odbcabstraction::DriverException; -GetInfoCache::GetInfoCache(FlightCallOptions& call_options, +GetInfoCache::GetInfoCache(FlightClientOptions& client_options, + FlightCallOptions& call_options, std::unique_ptr& client, const std::string& driver_version) - : call_options_(call_options), sql_client_(client), has_server_info_(false) { + : client_options_(client_options), + call_options_(call_options), + sql_client_(client), + has_server_info_(false) { info_[SQL_DRIVER_NAME] = "Arrow Flight ODBC Driver"; info_[SQL_DRIVER_VER] = ConvertToDBMSVer(driver_version); @@ -294,7 +299,8 @@ bool GetInfoCache::LoadInfoFromServer() { arrow::Result> result = sql_client_->GetSqlInfo(call_options_, {}); ThrowIfNotOK(result.status()); - FlightStreamChunkBuffer chunk_iter(*sql_client_, call_options_, result.ValueOrDie()); + FlightStreamChunkBuffer chunk_iter(*sql_client_, client_options_, call_options_, + result.ValueOrDie()); FlightStreamChunk chunk; bool supports_correlation_name = false; @@ -1173,6 +1179,7 @@ void GetInfoCache::LoadDefaultsForMissingEntries() { SetDefaultIfMissing(info_, SQL_CONVERT_DECIMAL, static_cast(0)); SetDefaultIfMissing(info_, SQL_CONVERT_DOUBLE, static_cast(0)); SetDefaultIfMissing(info_, SQL_CONVERT_FLOAT, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_FUNCTIONS, static_cast(0)); SetDefaultIfMissing(info_, SQL_CONVERT_GUID, static_cast(0)); SetDefaultIfMissing(info_, SQL_CONVERT_INTEGER, static_cast(0)); SetDefaultIfMissing(info_, SQL_CONVERT_INTERVAL_YEAR_MONTH, static_cast(0)); @@ -1251,6 +1258,7 @@ void GetInfoCache::LoadDefaultsForMissingEntries() { SetDefaultIfMissing(info_, SQL_MAX_COLUMNS_IN_ORDER_BY, static_cast(0)); SetDefaultIfMissing(info_, SQL_MAX_COLUMNS_IN_SELECT, static_cast(0)); SetDefaultIfMissing(info_, SQL_MAX_COLUMNS_IN_TABLE, static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_CONCURRENT_ACTIVITIES, static_cast(0)); SetDefaultIfMissing(info_, SQL_MAX_CURSOR_NAME_LEN, static_cast(0)); SetDefaultIfMissing(info_, SQL_MAX_DRIVER_CONNECTIONS, static_cast(0)); SetDefaultIfMissing(info_, SQL_MAX_IDENTIFIER_LEN, static_cast(65535)); @@ -1270,6 +1278,7 @@ void GetInfoCache::LoadDefaultsForMissingEntries() { SetDefaultIfMissing(info_, SQL_OJ_CAPABILITIES, static_cast(SQL_OJ_LEFT | SQL_OJ_RIGHT | SQL_OJ_FULL)); SetDefaultIfMissing(info_, SQL_ORDER_BY_COLUMNS_IN_SELECT, "Y"); + SetDefaultIfMissing(info_, SQL_OUTER_JOINS, "N"); SetDefaultIfMissing(info_, SQL_PROCEDURE_TERM, ""); SetDefaultIfMissing(info_, SQL_PROCEDURES, "N"); SetDefaultIfMissing(info_, SQL_QUOTED_IDENTIFIER_CASE, @@ -1278,6 +1287,7 @@ void GetInfoCache::LoadDefaultsForMissingEntries() { SetDefaultIfMissing(info_, SQL_SCHEMA_USAGE, static_cast(SQL_SU_DML_STATEMENTS)); SetDefaultIfMissing(info_, SQL_SEARCH_PATTERN_ESCAPE, "\\"); + SetDefaultIfMissing(info_, SQL_SPECIAL_CHARACTERS, ""); SetDefaultIfMissing( info_, SQL_SERVER_NAME, "Arrow Flight SQL Server"); // This might actually need to be the hostname. @@ -1332,6 +1342,16 @@ void GetInfoCache::LoadDefaultsForMissingEntries() { SQL_FN_TSI_FRAC_SECOND | SQL_FN_TSI_SECOND | SQL_FN_TSI_MINUTE | SQL_FN_TSI_HOUR | SQL_FN_TSI_DAY | SQL_FN_TSI_WEEK | SQL_FN_TSI_MONTH | SQL_FN_TSI_QUARTER | SQL_FN_TSI_YEAR)); + SetDefaultIfMissing( + info_, SQL_TIMEDATE_FUNCTIONS, + static_cast( + SQL_FN_TD_CURRENT_DATE | SQL_FN_TD_CURRENT_TIME | SQL_FN_TD_CURRENT_TIMESTAMP | + SQL_FN_TD_CURDATE | SQL_FN_TD_CURTIME | SQL_FN_TD_DAYNAME | + SQL_FN_TD_DAYOFMONTH | SQL_FN_TD_DAYOFWEEK | SQL_FN_TD_DAYOFYEAR | + SQL_FN_TD_EXTRACT | SQL_FN_TD_HOUR | SQL_FN_TD_MINUTE | SQL_FN_TD_MONTH | + SQL_FN_TD_MONTHNAME | SQL_FN_TD_NOW | SQL_FN_TD_QUARTER | SQL_FN_TD_SECOND | + SQL_FN_TD_TIMESTAMPADD | SQL_FN_TD_TIMESTAMPDIFF | SQL_FN_TD_WEEK | + SQL_FN_TD_YEAR)); SetDefaultIfMissing(info_, SQL_UNION, static_cast(SQL_U_UNION | SQL_U_UNION_ALL)); SetDefaultIfMissing(info_, SQL_XOPEN_CLI_YEAR, "1995"); diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/get_info_cache.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/get_info_cache.h index a54dda2e13b..547fb1cdf28 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/get_info_cache.h +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/get_info_cache.h @@ -17,12 +17,12 @@ #pragma once -#include -#include #include #include #include #include +#include "arrow/flight/sql/client.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/connection.h" namespace driver { namespace flight_sql { @@ -30,13 +30,15 @@ namespace flight_sql { class GetInfoCache { private: std::unordered_map info_; + arrow::flight::FlightClientOptions& client_options_; arrow::flight::FlightCallOptions& call_options_; std::unique_ptr& sql_client_; std::mutex mutex_; std::atomic has_server_info_; public: - GetInfoCache(arrow::flight::FlightCallOptions& call_options, + GetInfoCache(arrow::flight::FlightClientOptions& client_options, + arrow::flight::FlightCallOptions& call_options, std::unique_ptr& client, const std::string& driver_version); void SetProperty(uint16_t property, driver::odbcabstraction::Connection::Info value); diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/config/configuration.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/config/configuration.h index 69fa8a8696c..c94cc5b7832 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/config/configuration.h +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/config/configuration.h @@ -46,27 +46,21 @@ class Configuration { */ ~Configuration(); - /** - * Convert configure to connect string. - * - * @return Connect string. - */ - std::string ToConnectString() const; - void LoadDefaults(); void LoadDsn(const std::string& dsn); void Clear(); bool IsSet(const std::string_view& key) const; const std::string& Get(const std::string_view& key) const; + void Set(const std::string_view& key, const std::wstring& wValue); void Set(const std::string_view& key, const std::string& value); - + void Emplace(const std::string_view& key, std::string&& value); /** * Get properties map. */ const driver::odbcabstraction::Connection::ConnPropertyMap& GetProperties() const; - std::vector GetCustomKeys() const; + std::vector GetCustomKeys() const; private: driver::odbcabstraction::Connection::ConnPropertyMap properties; diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/flight_sql_driver.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/flight_sql_driver.h index 88460cdf5b2..48f2a16416a 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/flight_sql_driver.h +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/flight_sql_driver.h @@ -17,8 +17,8 @@ #pragma once -#include -#include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/diagnostics.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/driver.h" namespace driver { namespace flight_sql { diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/add_property_window.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/add_property_window.h index 01d93829a46..b7a8016447c 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/add_property_window.h +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/add_property_window.h @@ -70,7 +70,7 @@ class AddPropertyWindow : public CustomWindow { * * @return true if the dialog was OK'd, false otherwise. */ - bool GetProperty(std::string& key, std::string& value); + bool GetProperty(std::wstring& key, std::wstring& value); private: /** @@ -97,9 +97,9 @@ class AddPropertyWindow : public CustomWindow { std::unique_ptr valueEdit; - std::string key; + std::wstring key; - std::string value; + std::wstring value; /** Window width. */ int width; diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/custom_window.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/custom_window.h index 0fc3737ed8b..649f0ef6547 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/custom_window.h +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/custom_window.h @@ -65,7 +65,7 @@ class CustomWindow : public Window { * @param className Window class name. * @param title Window title. */ - CustomWindow(Window* parent, const char* className, const char* title); + CustomWindow(Window* parent, const wchar_t* className, const wchar_t* title); /** * Destructor. diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/window.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/window.h index e56ad88dec6..596ff47c577 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/window.h +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/window.h @@ -44,7 +44,7 @@ class Window { * @param className Window class name. * @param title Window title. */ - Window(Window* parent, const char* className, const char* title); + Window(Window* parent, const wchar_t* className, const wchar_t* title); /** * Constructor for the existing window. @@ -102,7 +102,7 @@ class Window { * @return Auto pointer containing new window. */ std::unique_ptr CreateGroupBox(int posX, int posY, int sizeX, int sizeY, - const char* title, int id); + const wchar_t* title, int id); /** * Create child label window. @@ -116,7 +116,7 @@ class Window { * @return Auto pointer containing new window. */ std::unique_ptr CreateLabel(int posX, int posY, int sizeX, int sizeY, - const char* title, int id); + const wchar_t* title, int id); /** * Create child Edit window. @@ -131,7 +131,7 @@ class Window { * @return Auto pointer containing new window. */ std::unique_ptr CreateEdit(int posX, int posY, int sizeX, int sizeY, - const char* title, int id, int style = 0); + const wchar_t* title, int id, int style = 0); /** * Create child button window. @@ -146,7 +146,7 @@ class Window { * @return Auto pointer containing new window. */ std::unique_ptr CreateButton(int posX, int posY, int sizeX, int sizeY, - const char* title, int id, int style = 0); + const wchar_t* title, int id, int style = 0); /** * Create child CheckBox window. @@ -161,7 +161,7 @@ class Window { * @return Auto pointer containing new window. */ std::unique_ptr CreateCheckBox(int posX, int posY, int sizeX, int sizeY, - const char* title, int id, bool state); + const wchar_t* title, int id, bool state); /** * Create child ComboBox window. @@ -175,7 +175,7 @@ class Window { * @return Auto pointer containing new window. */ std::unique_ptr CreateComboBox(int posX, int posY, int sizeX, int sizeY, - const char* title, int id); + const wchar_t* title, int id); /** * Show window. @@ -201,15 +201,15 @@ class Window { void SetVisible(bool isVisible); - void ListAddColumn(const std::string& name, int index, int width); + void ListAddColumn(const std::wstring& name, int index, int width); - void ListAddItem(const std::vector& items); + void ListAddItem(const std::vector& items); void ListDeleteSelectedItem(); - std::vector > ListGetAll(); + std::vector > ListGetAll(); - void AddTab(const std::string& name, int index); + void AddTab(const std::wstring& name, int index); bool IsTextEmpty() const; @@ -218,14 +218,14 @@ class Window { * * @param text Text. */ - void GetText(std::string& text) const; + void GetText(std::wstring& text) const; /** * Set window text. * * @param text Text. */ - void SetText(const std::string& text) const; + void SetText(const std::wstring& text) const; /** * Get CheckBox state. @@ -246,7 +246,7 @@ class Window { * * @param str String. */ - void AddString(const std::string& str); + void AddString(const std::wstring& str); /** * Set current ComboBox selection. @@ -285,10 +285,10 @@ class Window { void SetHandle(HWND value) { handle = value; } /** Window class name. */ - std::string className; + std::wstring className; /** Window title. */ - std::string title; + std::wstring title; /** Window handle. */ HWND handle; diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/json_converter.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/json_converter.h index de466af4f77..83809265df4 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/json_converter.h +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/json_converter.h @@ -17,8 +17,8 @@ #pragma once -#include #include +#include "arrow/type_fwd.h" namespace driver { namespace flight_sql { diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/main.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/main.cc index e112fdf67c0..aaf267cc268 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/main.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/main.cc @@ -43,7 +43,7 @@ using driver::odbcabstraction::Statement; void TestBindColumn(const std::shared_ptr& connection) { const std::shared_ptr& statement = connection->CreateStatement(); - statement->Execute("SELECT IncidntNum, Category FROM \"@dremio\".Test LIMIT 10"); + statement->Execute("SELECT IncidntNum, Category FROM \"@apache\".Test LIMIT 10"); const std::shared_ptr& result_set = statement->GetResultSet(); @@ -105,7 +105,7 @@ void TestBindColumnBigInt(const std::shared_ptr& connection) { " SELECT CONVERT_TO_INTEGER(IncidntNum, 1, 1, 0) AS IncidntNum, " "Category\n" " FROM (\n" - " SELECT IncidntNum, Category FROM \"@dremio\".Test LIMIT 10\n" + " SELECT IncidntNum, Category FROM \"@apache\".Test LIMIT 10\n" " ) nested_0\n" ") nested_0"); @@ -202,11 +202,11 @@ int main() { driver.CreateConnection(driver::odbcabstraction::V_3); Connection::ConnPropertyMap properties = { - {FlightSqlConnection::HOST, std::string("automaster.drem.io")}, - {FlightSqlConnection::PORT, std::string("32010")}, - {FlightSqlConnection::USER, std::string("dremio")}, - {FlightSqlConnection::PASSWORD, std::string("dremio123")}, - {FlightSqlConnection::USE_ENCRYPTION, std::string("false")}, + {std::string(FlightSqlConnection::HOST), std::string("automaster.apache")}, + {std::string(FlightSqlConnection::PORT), std::string("32010")}, + {std::string(FlightSqlConnection::USER), std::string("apache")}, + {std::string(FlightSqlConnection::PASSWORD), std::string("apache123")}, + {std::string(FlightSqlConnection::USE_ENCRYPTION), std::string("false")}, }; std::vector missing_attr; connection->Connect(properties, missing_attr); diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/parse_table_types_test.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/parse_table_types_test.cc index 9bfcabb0bbd..9e6e93ed21c 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/parse_table_types_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/parse_table_types_test.cc @@ -49,5 +49,18 @@ TEST(TableTypeParser, ParsingWithSingleQuotesWithoutLeadingWhiteSpace) { TEST(TableTypeParser, ParsingWithCommaInsideSingleQuotes) { AssertParseTest("'TABLE, TEST', 'VIEW, TEMPORARY'", {"TABLE, TEST", "VIEW, TEMPORARY"}); } + +TEST(TableTypeParser, ParsingWithManyLeadingAndTrailingWhiteSpaces) { + AssertParseTest(" TABLE , VIEW ", {"TABLE", "VIEW"}); +} + +TEST(TableTypeParser, ParsingWithOnlyWhiteSpaceBetweenCommas) { + AssertParseTest("TABLE, ,VIEW", {"TABLE", "VIEW"}); +} + +TEST(TableTypeParser, ParsingWithWhiteSpaceInsideValue) { + AssertParseTest("BASE TABLE", {"BASE TABLE"}); +} + } // namespace flight_sql } // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/record_batch_transformer.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/record_batch_transformer.h index 261b8c1d7c0..15c482cc631 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/record_batch_transformer.h +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/record_batch_transformer.h @@ -17,9 +17,9 @@ #pragma once -#include -#include #include +#include "arrow/flight/client.h" +#include "arrow/type.h" namespace driver { namespace flight_sql { diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/scalar_function_reporter.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/scalar_function_reporter.h index 5c2ae06cdba..fd6abf6420e 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/scalar_function_reporter.h +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/scalar_function_reporter.h @@ -17,7 +17,7 @@ #pragma once -#include +#include "arrow/type.h" namespace driver { namespace flight_sql { diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/system_dsn.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/system_dsn.cc index 504b62a81eb..f0006b36c9a 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/system_dsn.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/system_dsn.cc @@ -15,75 +15,30 @@ // specific language governing permissions and limitations // under the License. -// platform.h includes windows.h, so it needs to be included -// before winuser.h -#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h" +#include "arrow/flight/sql/odbc/flight_sql/system_dsn.h" -#include -#include #include "arrow/flight/sql/odbc/flight_sql/flight_sql_connection.h" #include "arrow/flight/sql/odbc/flight_sql/include/flight_sql/config/configuration.h" -#include "arrow/flight/sql/odbc/flight_sql/include/flight_sql/config/connection_string_parser.h" -#include "arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/dsn_configuration_window.h" -#include "arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/window.h" -#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/exceptions.h" +#include "arrow/result.h" +#include "arrow/util/utf8.h" #include -#include -#include #include using driver::flight_sql::FlightSqlConnection; using driver::flight_sql::config::Configuration; -using driver::flight_sql::config::ConnectionStringParser; -using driver::flight_sql::config::DsnConfigurationWindow; -using driver::flight_sql::config::Result; -using driver::flight_sql::config::Window; - -BOOL CALLBACK ConfigDriver(HWND hwndParent, WORD fRequest, LPCSTR lpszDriver, - LPCSTR lpszArgs, LPSTR lpszMsg, WORD cbMsgMax, - WORD* pcbMsgOut) { - return false; -} - -bool DisplayConnectionWindow(void* windowParent, Configuration& config) { - HWND hwndParent = (HWND)windowParent; - - if (!hwndParent) return true; - - try { - Window parent(hwndParent); - DsnConfigurationWindow window(&parent, config); - - window.Create(); - - window.Show(); - window.Update(); - - return ProcessMessages(window) == Result::OK; - } catch (driver::odbcabstraction::DriverException& err) { - std::stringstream buf; - buf << "Message: " << err.GetMessageText() << ", Code: " << err.GetNativeError(); - std::string message = buf.str(); - MessageBox(NULL, message.c_str(), "Error!", MB_ICONEXCLAMATION | MB_OK); - - SQLPostInstallerError(err.GetNativeError(), err.GetMessageText().c_str()); - } - - return false; -} void PostLastInstallerError() { #define BUFFER_SIZE (1024) DWORD code; - char msg[BUFFER_SIZE]; + wchar_t msg[BUFFER_SIZE]; SQLInstallerError(1, &code, msg, BUFFER_SIZE, NULL); - std::stringstream buf; - buf << "Message: \"" << msg << "\", Code: " << code; - std::string errorMsg = buf.str(); + std::wstringstream buf; + buf << L"Message: \"" << msg << L"\", Code: " << code; + std::wstring errorMsg = buf.str(); - MessageBox(NULL, errorMsg.c_str(), "Error!", MB_ICONEXCLAMATION | MB_OK); + MessageBox(NULL, errorMsg.c_str(), L"Error!", MB_ICONEXCLAMATION | MB_OK); SQLPostInstallerError(code, errorMsg.c_str()); } @@ -93,7 +48,7 @@ void PostLastInstallerError() { * @param dsn DSN name. * @return True on success and false on fail. */ -bool UnregisterDsn(const std::string& dsn) { +bool UnregisterDsn(const std::wstring& dsn) { if (SQLRemoveDSNFromIni(dsn.c_str())) { return true; } @@ -109,10 +64,11 @@ bool UnregisterDsn(const std::string& dsn) { * @param driver Driver. * @return True on success and false on fail. */ -bool RegisterDsn(const Configuration& config, LPCSTR driver) { +bool RegisterDsn(const Configuration& config, LPCWSTR driver) { const std::string& dsn = config.Get(FlightSqlConnection::DSN); + std::wstring wDsn = arrow::util::UTF8ToWideString(dsn).ValueOr(L""); - if (!SQLWriteDSNToIni(dsn.c_str(), driver)) { + if (!SQLWriteDSNToIni(wDsn.c_str(), driver)) { PostLastInstallerError(); return false; } @@ -125,9 +81,10 @@ bool RegisterDsn(const Configuration& config, LPCSTR driver) { continue; } - std::string key_str = std::string(key); - if (!SQLWritePrivateProfileString(dsn.c_str(), key_str.c_str(), it->second.c_str(), - "ODBC.INI")) { + std::wstring wKey = arrow::util::UTF8ToWideString(key).ValueOr(L""); + std::wstring wValue = arrow::util::UTF8ToWideString(it->second).ValueOr(L""); + if (!SQLWritePrivateProfileString(wDsn.c_str(), wKey.c_str(), wValue.c_str(), + L"ODBC.INI")) { PostLastInstallerError(); return false; } @@ -135,45 +92,3 @@ bool RegisterDsn(const Configuration& config, LPCSTR driver) { return true; } - -BOOL INSTAPI ConfigDSN(HWND hwndParent, WORD req, LPCSTR driver, LPCSTR attributes) { - Configuration config; - ConnectionStringParser parser(config); - parser.ParseConfigAttributes(attributes); - - switch (req) { - case ODBC_ADD_DSN: { - config.LoadDefaults(); - if (!DisplayConnectionWindow(hwndParent, config) || !RegisterDsn(config, driver)) - return FALSE; - - break; - } - - case ODBC_CONFIG_DSN: { - const std::string& dsn = config.Get(FlightSqlConnection::DSN); - if (!SQLValidDSN(dsn.c_str())) return FALSE; - - Configuration loaded(config); - loaded.LoadDsn(dsn); - - if (!DisplayConnectionWindow(hwndParent, loaded) || !UnregisterDsn(dsn.c_str()) || - !RegisterDsn(loaded, driver)) - return FALSE; - - break; - } - - case ODBC_REMOVE_DSN: { - const std::string& dsn = config.Get(FlightSqlConnection::DSN); - if (!SQLValidDSN(dsn.c_str()) || !UnregisterDsn(dsn)) return FALSE; - - break; - } - - default: - return FALSE; - } - - return TRUE; -} diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/system_dsn.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/system_dsn.h new file mode 100644 index 00000000000..1ac9b4d9b80 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/system_dsn.h @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// platform.h includes windows.h, so it needs to be included first +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h" + +#include "arrow/flight/sql/odbc/flight_sql/include/flight_sql/config/configuration.h" + +using driver::flight_sql::config::Configuration; +using driver::odbcabstraction::Connection; + +#if defined _WIN32 || defined _WIN64 +/** + * Display connection window for user to configure connection parameters. + * + * @param windowParent Parent window handle. + * @param config Output configuration. + * @return True on success and false on fail. + */ +bool DisplayConnectionWindow(void* windowParent, Configuration& config); + +/** + * For SQLDriverConnect. + * Display connection window for user to configure connection parameters. + * + * @param windowParent Parent window handle. + * @param config Output configuration, presumed to be empty, it will be using values from + * properties. + * @param properties Output properties. + * @return True on success and false on fail. + */ +bool DisplayConnectionWindow(void* windowParent, Configuration& config, + Connection::ConnPropertyMap& properties); +#endif + +/** + * Register DSN with specified configuration. + * + * @param config Configuration. + * @param driver Driver. + * @return True on success and false on fail. + */ +bool RegisterDsn(const Configuration& config, LPCWSTR driver); + +/** + * Unregister specified DSN. + * + * @param dsn DSN name. + * @return True on success and false on fail. + */ +bool UnregisterDsn(const std::wstring& dsn); diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/system_trust_store.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/system_trust_store.cc index 67db1fc35be..ebc8fd90adf 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/system_trust_store.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/system_trust_store.cc @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +#include "arrow/result.h" +#include "arrow/util/utf8.h" + #include "arrow/flight/sql/odbc/flight_sql/system_trust_store.h" #if defined _WIN32 || defined _WIN64 @@ -32,18 +35,20 @@ std::string SystemTrustStore::GetNext() const { CryptBinaryToString(p_context_->pbCertEncoded, p_context_->cbCertEncoded, CRYPT_STRING_BASE64HEADER, nullptr, &size); - std::string cert; - cert.resize(size); + std::wstring wCert; + wCert.resize(size); CryptBinaryToString(p_context_->pbCertEncoded, p_context_->cbCertEncoded, - CRYPT_STRING_BASE64HEADER, &cert[0], &size); - cert.resize(size); + CRYPT_STRING_BASE64HEADER, &wCert[0], &size); + wCert.resize(size); + + std::string cert = arrow::util::WideStringToUTF8(wCert).ValueOr(""); return cert; } bool SystemTrustStore::SystemHasStore() { return h_store_ != nullptr; } -SystemTrustStore::SystemTrustStore(const char* store) +SystemTrustStore::SystemTrustStore(const wchar_t* store) : stores_(store), h_store_(CertOpenSystemStore(NULL, store)), p_context_(nullptr) {} SystemTrustStore::~SystemTrustStore() { diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/system_trust_store.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/system_trust_store.h index 71175b09709..0ff3adc2f48 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/system_trust_store.h +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/system_trust_store.h @@ -24,6 +24,9 @@ # include # include + +# include + # include # include @@ -38,12 +41,12 @@ namespace flight_sql { /// https://github.com/apache/drill/blob/master/contrib/native/client/src/clientlib/wincert.ipp. class SystemTrustStore { private: - const char* stores_; + const wchar_t* stores_; HCERTSTORE h_store_; PCCERT_CONTEXT p_context_; public: - explicit SystemTrustStore(const char* store); + explicit SystemTrustStore(const wchar_t* store); ~SystemTrustStore(); diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/add_property_window.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/add_property_window.cc index 64cc1797f7e..15799c1f9a2 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/add_property_window.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/add_property_window.cc @@ -24,7 +24,7 @@ #include -#include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/exceptions.h" #include "ui/custom_window.h" #include "ui/window.h" @@ -33,7 +33,7 @@ namespace flight_sql { namespace config { AddPropertyWindow::AddPropertyWindow(Window* parent) - : CustomWindow(parent, "AddProperty", "Add Property"), + : CustomWindow(parent, L"AddProperty", L"Add Property"), width(300), height(120), accepted(false), @@ -69,7 +69,7 @@ void AddPropertyWindow::Create() { } } -bool AddPropertyWindow::GetProperty(std::string& key, std::string& value) { +bool AddPropertyWindow::GetProperty(std::wstring& key, std::wstring& value) { if (accepted) { key = this->key; value = this->value; @@ -87,10 +87,10 @@ void AddPropertyWindow::OnCreate() { int cancelPosX = width - MARGIN - BUTTON_WIDTH; int okPosX = cancelPosX - INTERVAL - BUTTON_WIDTH; - okButton = CreateButton(okPosX, groupPosY, BUTTON_WIDTH, BUTTON_HEIGHT, "Ok", + okButton = CreateButton(okPosX, groupPosY, BUTTON_WIDTH, BUTTON_HEIGHT, L"Ok", ChildId::OK_BUTTON, BS_DEFPUSHBUTTON); cancelButton = CreateButton(cancelPosX, groupPosY, BUTTON_WIDTH, BUTTON_HEIGHT, - "Cancel", ChildId::CANCEL_BUTTON); + L"Cancel", ChildId::CANCEL_BUTTON); isInitialized = true; CheckEnableOk(); } @@ -104,15 +104,15 @@ int AddPropertyWindow::CreateEdits(int posX, int posY, int sizeX) { int rowPos = posY; labels.push_back( - CreateLabel(posX, rowPos, LABEL_WIDTH, ROW_HEIGHT, "Key:", ChildId::KEY_LABEL)); - keyEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, "", ChildId::KEY_EDIT); + CreateLabel(posX, rowPos, LABEL_WIDTH, ROW_HEIGHT, L"Key:", ChildId::KEY_LABEL)); + keyEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, L"", ChildId::KEY_EDIT); rowPos += INTERVAL + ROW_HEIGHT; - labels.push_back( - CreateLabel(posX, rowPos, LABEL_WIDTH, ROW_HEIGHT, "Value:", ChildId::VALUE_LABEL)); + labels.push_back(CreateLabel(posX, rowPos, LABEL_WIDTH, ROW_HEIGHT, L"Value:", + ChildId::VALUE_LABEL)); valueEdit = - CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, "", ChildId::VALUE_EDIT); + CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, L"", ChildId::VALUE_EDIT); rowPos += INTERVAL + ROW_HEIGHT; diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/custom_window.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/custom_window.cc index 5443ea0ec8d..bde7967c7e9 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/custom_window.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/custom_window.cc @@ -17,15 +17,16 @@ // platform.h includes windows.h, so it needs to be included // before Windowsx.h and commctrl.h -#include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h" #include #include #include +#include #include -#include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/exceptions.h" #include "ui/custom_window.h" namespace driver { @@ -53,7 +54,7 @@ LRESULT CALLBACK CustomWindow::WndProc(HWND hwnd, UINT msg, WPARAM wParam, switch (msg) { case WM_NCCREATE: { - _ASSERT(lParam != NULL); + assert(lParam != NULL); CREATESTRUCT* createStruct = reinterpret_cast(lParam); @@ -65,7 +66,7 @@ LRESULT CALLBACK CustomWindow::WndProc(HWND hwnd, UINT msg, WPARAM wParam, } case WM_CREATE: { - _ASSERT(window != NULL); + assert(window != NULL); window->SetHandle(hwnd); @@ -83,7 +84,7 @@ LRESULT CALLBACK CustomWindow::WndProc(HWND hwnd, UINT msg, WPARAM wParam, return DefWindowProc(hwnd, msg, wParam, lParam); } -CustomWindow::CustomWindow(Window* parent, const char* className, const char* title) +CustomWindow::CustomWindow(Window* parent, const wchar_t* className, const wchar_t* title) : Window(parent, className, title) { WNDCLASS wcx; diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/dsn_configuration_window.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/dsn_configuration_window.cc index 42741c5a3e5..a3a6c30ff51 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/dsn_configuration_window.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/dsn_configuration_window.cc @@ -15,16 +15,19 @@ // specific language governing permissions and limitations // under the License. -#include "arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/dsn_configuration_window.h" +#include "arrow/result.h" +#include "arrow/util/utf8.h" + #include "arrow/flight/sql/odbc/flight_sql/flight_sql_connection.h" +#include "arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/dsn_configuration_window.h" #include #include -#include #include #include #include #include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/utils.h" #include "arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/add_property_window.h" @@ -55,7 +58,7 @@ namespace config { DsnConfigurationWindow::DsnConfigurationWindow(Window* parent, config::Configuration& config) - : CustomWindow(parent, "FlightConfigureDSN", "Configure Apache Arrow Flight SQL"), + : CustomWindow(parent, L"FlightConfigureDSN", L"Configure Apache Arrow Flight SQL"), width(480), height(375), config(config), @@ -94,8 +97,8 @@ void DsnConfigurationWindow::Create() { void DsnConfigurationWindow::OnCreate() { tabControl = CreateTabControl(ChildId::TAB_CONTROL); - tabControl->AddTab("Common", COMMON_TAB); - tabControl->AddTab("Advanced", ADVANCED_TAB); + tabControl->AddTab(L"Common", COMMON_TAB); + tabControl->AddTab(L"Advanced", ADVANCED_TAB); int groupPosY = 3 * MARGIN; int groupSizeY = width - 2 * MARGIN; @@ -118,11 +121,11 @@ void DsnConfigurationWindow::OnCreate() { int buttonPosY = std::max(commonGroupPosY, advancedGroupPosY); testButton = CreateButton(testPosX, buttonPosY, BUTTON_WIDTH + 20, BUTTON_HEIGHT, - "Test Connection", ChildId::TEST_CONNECTION_BUTTON); - okButton = CreateButton(okPosX, buttonPosY, BUTTON_WIDTH, BUTTON_HEIGHT, "Ok", + L"Test Connection", ChildId::TEST_CONNECTION_BUTTON); + okButton = CreateButton(okPosX, buttonPosY, BUTTON_WIDTH, BUTTON_HEIGHT, L"Ok", ChildId::OK_BUTTON); cancelButton = CreateButton(cancelPosX, buttonPosY, BUTTON_WIDTH, BUTTON_HEIGHT, - "Cancel", ChildId::CANCEL_BUTTON); + L"Cancel", ChildId::CANCEL_BUTTON); isInitialized = true; CheckEnableOk(); SelectTab(COMMON_TAB); @@ -138,31 +141,35 @@ int DsnConfigurationWindow::CreateConnectionSettingsGroup(int posX, int posY, in int rowPos = posY + 2 * INTERVAL; - const char* val = config.Get(FlightSqlConnection::DSN).c_str(); + std::string val = config.Get(FlightSqlConnection::DSN); + std::wstring wVal = arrow::util::UTF8ToWideString(val).ValueOr(L""); labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, - "Data Source Name:", ChildId::NAME_LABEL)); - nameEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, val, ChildId::NAME_EDIT); + L"Data Source Name:", ChildId::NAME_LABEL)); + nameEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, wVal.c_str(), + ChildId::NAME_EDIT); rowPos += INTERVAL + ROW_HEIGHT; - val = config.Get(FlightSqlConnection::HOST).c_str(); - labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, - "Host Name:", ChildId::SERVER_LABEL)); - serverEdit = - CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, val, ChildId::SERVER_EDIT); + val = config.Get(FlightSqlConnection::HOST); + wVal = arrow::util::UTF8ToWideString(val).ValueOr(L""); + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, L"Host Name:", + ChildId::SERVER_LABEL)); + serverEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, wVal.c_str(), + ChildId::SERVER_EDIT); rowPos += INTERVAL + ROW_HEIGHT; - val = config.Get(FlightSqlConnection::PORT).c_str(); - labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, - "Port:", ChildId::PORT_LABEL)); - portEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, val, ChildId::PORT_EDIT, - ES_NUMBER); + val = config.Get(FlightSqlConnection::PORT); + wVal = arrow::util::UTF8ToWideString(val).ValueOr(L""); + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, L"Port:", + ChildId::PORT_LABEL)); + portEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, wVal.c_str(), + ChildId::PORT_EDIT, ES_NUMBER); rowPos += INTERVAL + ROW_HEIGHT; connectionSettingsGroupBox = - CreateGroupBox(posX, posY, sizeX, rowPos - posY, "Connection settings", + CreateGroupBox(posX, posY, sizeX, rowPos - posY, L"Connection settings", ChildId::CONNECTION_SETTINGS_GROUP_BOX); return rowPos - posY; @@ -179,36 +186,39 @@ int DsnConfigurationWindow::CreateAuthSettingsGroup(int posX, int posY, int size int rowPos = posY + 2 * INTERVAL; labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, - "Authentication Type:", ChildId::AUTH_TYPE_LABEL)); + L"Authentication Type:", ChildId::AUTH_TYPE_LABEL)); authTypeComboBox = CreateComboBox(editPosX, rowPos, editSizeX, ROW_HEIGHT, - "Authentication Type:", ChildId::AUTH_TYPE_COMBOBOX); - authTypeComboBox->AddString("Basic Authentication"); - authTypeComboBox->AddString("Token Authentication"); + L"Authentication Type:", ChildId::AUTH_TYPE_COMBOBOX); + authTypeComboBox->AddString(L"Basic Authentication"); + authTypeComboBox->AddString(L"Token Authentication"); rowPos += INTERVAL + ROW_HEIGHT; - const char* val = config.Get(FlightSqlConnection::UID).c_str(); + std::string val = config.Get(FlightSqlConnection::UID); + std::wstring wVal = arrow::util::UTF8ToWideString(val).ValueOr(L""); - labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, - "User:", ChildId::USER_LABEL)); - userEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, val, ChildId::USER_EDIT); + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, L"User:", + ChildId::USER_LABEL)); + userEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, wVal.c_str(), + ChildId::USER_EDIT); rowPos += INTERVAL + ROW_HEIGHT; - val = config.Get(FlightSqlConnection::PWD).c_str(); - labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, - "Password:", ChildId::PASSWORD_LABEL)); - passwordEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, val, + val = config.Get(FlightSqlConnection::PWD); + wVal = arrow::util::UTF8ToWideString(val).ValueOr(L""); + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, L"Password:", + ChildId::PASSWORD_LABEL)); + passwordEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, wVal.c_str(), ChildId::USER_EDIT, ES_PASSWORD); rowPos += INTERVAL + ROW_HEIGHT; const auto& token = config.Get(FlightSqlConnection::TOKEN); - val = token.c_str(); + wVal = arrow::util::UTF8ToWideString(token).ValueOr(L""); labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, - "Authentication Token:", ChildId::AUTH_TOKEN_LABEL)); - authTokenEdit = - CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, val, ChildId::AUTH_TOKEN_EDIT); + L"Authentication Token:", ChildId::AUTH_TOKEN_LABEL)); + authTokenEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, wVal.c_str(), + ChildId::AUTH_TOKEN_EDIT); authTokenEdit->SetEnabled(false); // Ensure the right elements are selected. @@ -218,7 +228,7 @@ int DsnConfigurationWindow::CreateAuthSettingsGroup(int posX, int posY, int size rowPos += INTERVAL + ROW_HEIGHT; authSettingsGroupBox = - CreateGroupBox(posX, posY, sizeX, rowPos - posY, "Authentication settings", + CreateGroupBox(posX, posY, sizeX, rowPos - posY, L"Authentication settings", ChildId::AUTH_SETTINGS_GROUP_BOX); return rowPos - posY; @@ -234,37 +244,40 @@ int DsnConfigurationWindow::CreateEncryptionSettingsGroup(int posX, int posY, in int rowPos = posY + 2 * INTERVAL; - const char* val = config.Get(FlightSqlConnection::USE_ENCRYPTION).c_str(); + std::string val = config.Get(FlightSqlConnection::USE_ENCRYPTION); + // Enable encryption default value is true const bool enableEncryption = driver::odbcabstraction::AsBool(val).value_or(true); labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, - "Use Encryption:", ChildId::ENABLE_ENCRYPTION_LABEL)); + L"Use Encryption:", ChildId::ENABLE_ENCRYPTION_LABEL)); enableEncryptionCheckBox = - CreateCheckBox(editPosX, rowPos - 2, editSizeX, ROW_HEIGHT, "", + CreateCheckBox(editPosX, rowPos - 2, editSizeX, ROW_HEIGHT, L"", ChildId::ENABLE_ENCRYPTION_CHECKBOX, enableEncryption); rowPos += INTERVAL + ROW_HEIGHT; - val = config.Get(FlightSqlConnection::TRUSTED_CERTS).c_str(); + val = config.Get(FlightSqlConnection::TRUSTED_CERTS); + std::wstring wVal = arrow::util::UTF8ToWideString(val).ValueOr(L""); labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, - "Certificate:", ChildId::CERTIFICATE_LABEL)); + L"Certificate:", ChildId::CERTIFICATE_LABEL)); certificateEdit = CreateEdit(editPosX, rowPos, editSizeX - MARGIN - BUTTON_WIDTH, - ROW_HEIGHT, val, ChildId::CERTIFICATE_EDIT); + ROW_HEIGHT, wVal.c_str(), ChildId::CERTIFICATE_EDIT); certificateBrowseButton = CreateButton(editPosX + editSizeX - BUTTON_WIDTH, rowPos - 2, BUTTON_WIDTH, - BUTTON_HEIGHT, "Browse", ChildId::CERTIFICATE_BROWSE_BUTTON); + BUTTON_HEIGHT, L"Browse", ChildId::CERTIFICATE_BROWSE_BUTTON); rowPos += INTERVAL + ROW_HEIGHT; val = config.Get(FlightSqlConnection::USE_SYSTEM_TRUST_STORE).c_str(); + // System trust store default value is true const bool useSystemCertStore = driver::odbcabstraction::AsBool(val).value_or(true); - labels.push_back( - CreateLabel(labelPosX, rowPos, LABEL_WIDTH, 2 * ROW_HEIGHT, - "Use System Certificate Store:", ChildId::USE_SYSTEM_CERT_STORE_LABEL)); + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, 2 * ROW_HEIGHT, + L"Use System Certificate Store:", + ChildId::USE_SYSTEM_CERT_STORE_LABEL)); useSystemCertStoreCheckBox = - CreateCheckBox(editPosX, rowPos - 2, 20, 2 * ROW_HEIGHT, "", + CreateCheckBox(editPosX, rowPos - 2, 20, 2 * ROW_HEIGHT, L"", ChildId::USE_SYSTEM_CERT_STORE_CHECKBOX, useSystemCertStore); val = config.Get(FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION).c_str(); @@ -273,19 +286,24 @@ int DsnConfigurationWindow::CreateEncryptionSettingsGroup(int posX, int posY, in const int rightCheckPosX = rightPosX + (editPosX - labelPosX); const bool disableCertVerification = driver::odbcabstraction::AsBool(val).value_or(false); - labels.push_back(CreateLabel( - rightPosX, rowPos, LABEL_WIDTH, 2 * ROW_HEIGHT, - "Disable Certificate Verification:", ChildId::DISABLE_CERT_VERIFICATION_LABEL)); + labels.push_back(CreateLabel(rightPosX, rowPos, LABEL_WIDTH, 2 * ROW_HEIGHT, + L"Disable Certificate Verification:", + ChildId::DISABLE_CERT_VERIFICATION_LABEL)); disableCertVerificationCheckBox = CreateCheckBox( - rightCheckPosX, rowPos - 2, 20, 2 * ROW_HEIGHT, "", + rightCheckPosX, rowPos - 2, 20, 2 * ROW_HEIGHT, L"", ChildId::DISABLE_CERT_VERIFICATION_CHECKBOX, disableCertVerification); - rowPos += INTERVAL + static_cast(1.5 * ROW_HEIGHT); + rowPos += INTERVAL + static_cast(1.5 * static_cast(ROW_HEIGHT)); encryptionSettingsGroupBox = - CreateGroupBox(posX, posY, sizeX, rowPos - posY, "Encryption settings", + CreateGroupBox(posX, posY, sizeX, rowPos - posY, L"Encryption settings", ChildId::AUTH_SETTINGS_GROUP_BOX); + certificateEdit->SetEnabled(enableEncryption); + certificateBrowseButton->SetEnabled(enableEncryption); + useSystemCertStoreCheckBox->SetEnabled(enableEncryption); + disableCertVerificationCheckBox->SetEnabled(enableEncryption); + return rowPos - posY; } @@ -301,12 +319,15 @@ int DsnConfigurationWindow::CreatePropertiesGroup(int posX, int posY, int sizeX) propertyList = CreateList(labelPosX, rowPos, listSize, listHeight, ChildId::PROPERTY_LIST); - propertyList->ListAddColumn("Key", 0, columnSize); - propertyList->ListAddColumn("Value", 1, columnSize); + propertyList->ListAddColumn(L"Key", 0, columnSize); + propertyList->ListAddColumn(L"Value", 1, columnSize); const auto keys = config.GetCustomKeys(); for (const auto& key : keys) { - propertyList->ListAddItem({std::string(key), config.Get(key)}); + std::wstring wKey = arrow::util::UTF8ToWideString(key).ValueOr(L""); + std::wstring wVal = arrow::util::UTF8ToWideString(config.Get(key)).ValueOr(L""); + + propertyList->ListAddItem({wKey, wVal}); } SendMessage(propertyList->GetHandle(), LVM_SETEXTENDEDLISTVIEWSTYLE, @@ -316,15 +337,15 @@ int DsnConfigurationWindow::CreatePropertiesGroup(int posX, int posY, int sizeX) int deletePosX = width - INTERVAL - MARGIN - BUTTON_WIDTH; int addPosX = deletePosX - INTERVAL - BUTTON_WIDTH; - addButton = CreateButton(addPosX, rowPos, BUTTON_WIDTH, BUTTON_HEIGHT, "Add", + addButton = CreateButton(addPosX, rowPos, BUTTON_WIDTH, BUTTON_HEIGHT, L"Add", ChildId::ADD_BUTTON); - deleteButton = CreateButton(deletePosX, rowPos, BUTTON_WIDTH, BUTTON_HEIGHT, "Delete", + deleteButton = CreateButton(deletePosX, rowPos, BUTTON_WIDTH, BUTTON_HEIGHT, L"Delete", ChildId::DELETE_BUTTON); rowPos += INTERVAL + BUTTON_HEIGHT; propertyGroupBox = CreateGroupBox(posX, posY, sizeX, rowPos - posY, - "Advanced properties", ChildId::PROPERTY_GROUP_BOX); + L"Advanced properties", ChildId::PROPERTY_GROUP_BOX); return rowPos - posY; } @@ -384,7 +405,7 @@ void DsnConfigurationWindow::CheckEnableOk() { void DsnConfigurationWindow::SaveParameters(Configuration& targetConfig) { targetConfig.Clear(); - std::string text; + std::wstring text; nameEdit->GetText(text); targetConfig.Set(FlightSqlConnection::DSN, text); serverEdit->GetText(text); @@ -421,13 +442,17 @@ void DsnConfigurationWindow::SaveParameters(Configuration& targetConfig) { targetConfig.Set(FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION, disableCertVerificationCheckBox->IsChecked() ? TRUE_STR : FALSE_STR); } else { + // System trust store verification requires encryption targetConfig.Set(FlightSqlConnection::USE_ENCRYPTION, FALSE_STR); + targetConfig.Set(FlightSqlConnection::USE_SYSTEM_TRUST_STORE, FALSE_STR); } // Get all the list properties. const auto properties = propertyList->ListGetAll(); for (const auto& property : properties) { - targetConfig.Set(property[0], property[1]); + std::string propertyKey = arrow::util::WideStringToUTF8(property[0]).ValueOr(""); + std::string propertyValue = arrow::util::WideStringToUTF8(property[1]).ValueOr(""); + targetConfig.Set(propertyKey, propertyValue); } } @@ -463,10 +488,13 @@ bool DsnConfigurationWindow::OnMessage(UINT msg, WPARAM wParam, LPARAM lParam) { SaveParameters(testConfig); std::string testMessage = TestConnection(testConfig); - MessageBox(NULL, testMessage.c_str(), "Test Connection Success", MB_OK); + std::wstring wTestMessage = + arrow::util::UTF8ToWideString(testMessage).ValueOr(L""); + MessageBox(NULL, wTestMessage.c_str(), L"Test Connection Success", MB_OK); } catch (odbcabstraction::DriverException& err) { - MessageBox(NULL, err.GetMessageText().c_str(), "Error!", - MB_ICONEXCLAMATION | MB_OK); + std::wstring wMessageText = + arrow::util::UTF8ToWideString(err.GetMessageText()).ValueOr(L""); + MessageBox(NULL, wMessageText.c_str(), L"Error!", MB_ICONEXCLAMATION | MB_OK); } break; @@ -477,8 +505,9 @@ bool DsnConfigurationWindow::OnMessage(UINT msg, WPARAM wParam, LPARAM lParam) { accepted = true; PostMessage(GetHandle(), WM_CLOSE, 0, 0); } catch (odbcabstraction::DriverException& err) { - MessageBox(NULL, err.GetMessageText().c_str(), "Error!", - MB_ICONEXCLAMATION | MB_OK); + std::wstring wMessageText = + arrow::util::UTF8ToWideString(err.GetMessageText()).ValueOr(L""); + MessageBox(NULL, wMessageText.c_str(), L"Error!", MB_ICONEXCLAMATION | MB_OK); } break; @@ -520,7 +549,7 @@ bool DsnConfigurationWindow::OnMessage(UINT msg, WPARAM wParam, LPARAM lParam) { case ChildId::CERTIFICATE_BROWSE_BUTTON: { OPENFILENAME openFileName; - char fileName[FILENAME_MAX]; + wchar_t fileName[FILENAME_MAX]; ZeroMemory(&openFileName, sizeof(openFileName)); openFileName.lStructSize = sizeof(openFileName); @@ -529,7 +558,7 @@ bool DsnConfigurationWindow::OnMessage(UINT msg, WPARAM wParam, LPARAM lParam) { openFileName.lpstrFile[0] = '\0'; openFileName.nMaxFile = FILENAME_MAX; // TODO: What type should this be? - openFileName.lpstrFilter = "All\0*.*"; + openFileName.lpstrFilter = L"All\0*.*"; openFileName.nFilterIndex = 1; openFileName.lpstrFileTitle = NULL; openFileName.nMaxFileTitle = 0; @@ -566,8 +595,8 @@ bool DsnConfigurationWindow::OnMessage(UINT msg, WPARAM wParam, LPARAM lParam) { addWindow.Update(); if (ProcessMessages(addWindow) == Result::OK) { - std::string key; - std::string value; + std::wstring key; + std::wstring value; addWindow.GetProperty(key, value); propertyList->ListAddItem({key, value}); } diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/window.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/window.cc index f88cd8a3f88..2940c95578a 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/window.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/ui/window.cc @@ -49,7 +49,7 @@ HINSTANCE GetHInstance() { return hInstance; } -Window::Window(Window* parent, const char* className, const char* title) +Window::Window(Window* parent, const wchar_t* className, const wchar_t* title) : className(className), title(title), handle(NULL), parent(parent), created(false) { // No-op. } @@ -88,7 +88,7 @@ void Window::Create(DWORD style, int posX, int posY, int width, int height, int } std::unique_ptr Window::CreateTabControl(int id) { - std::unique_ptr child(new Window(this, WC_TABCONTROL, "")); + std::unique_ptr child(new Window(this, WC_TABCONTROL, L"")); // Get the dimensions of the parent window's client area, and // create a tab control child window of that size. @@ -103,7 +103,7 @@ std::unique_ptr Window::CreateTabControl(int id) { std::unique_ptr Window::CreateList(int posX, int posY, int sizeX, int sizeY, int id) { - std::unique_ptr child(new Window(this, WC_LISTVIEW, "")); + std::unique_ptr child(new Window(this, WC_LISTVIEW, L"")); child->Create( WS_CHILD | WS_VISIBLE | WS_BORDER | LVS_REPORT | LVS_EDITLABELS | WS_TABSTOP, posX, @@ -113,8 +113,8 @@ std::unique_ptr Window::CreateList(int posX, int posY, int sizeX, int si } std::unique_ptr Window::CreateGroupBox(int posX, int posY, int sizeX, int sizeY, - const char* title, int id) { - std::unique_ptr child(new Window(this, "Button", title)); + const wchar_t* title, int id) { + std::unique_ptr child(new Window(this, L"Button", title)); child->Create(WS_CHILD | WS_VISIBLE | BS_GROUPBOX, posX, posY, sizeX, sizeY, id); @@ -122,8 +122,8 @@ std::unique_ptr Window::CreateGroupBox(int posX, int posY, int sizeX, in } std::unique_ptr Window::CreateLabel(int posX, int posY, int sizeX, int sizeY, - const char* title, int id) { - std::unique_ptr child(new Window(this, "Static", title)); + const wchar_t* title, int id) { + std::unique_ptr child(new Window(this, L"Static", title)); child->Create(WS_CHILD | WS_VISIBLE, posX, posY, sizeX, sizeY, id); @@ -131,8 +131,8 @@ std::unique_ptr Window::CreateLabel(int posX, int posY, int sizeX, int s } std::unique_ptr Window::CreateEdit(int posX, int posY, int sizeX, int sizeY, - const char* title, int id, int style) { - std::unique_ptr child(new Window(this, "Edit", title)); + const wchar_t* title, int id, int style) { + std::unique_ptr child(new Window(this, L"Edit", title)); child->Create(WS_CHILD | WS_VISIBLE | WS_BORDER | ES_AUTOHSCROLL | WS_TABSTOP | style, posX, posY, sizeX, sizeY, id); @@ -141,8 +141,8 @@ std::unique_ptr Window::CreateEdit(int posX, int posY, int sizeX, int si } std::unique_ptr Window::CreateButton(int posX, int posY, int sizeX, int sizeY, - const char* title, int id, int style) { - std::unique_ptr child(new Window(this, "Button", title)); + const wchar_t* title, int id, int style) { + std::unique_ptr child(new Window(this, L"Button", title)); child->Create(WS_CHILD | WS_VISIBLE | WS_TABSTOP | style, posX, posY, sizeX, sizeY, id); @@ -150,8 +150,8 @@ std::unique_ptr Window::CreateButton(int posX, int posY, int sizeX, int } std::unique_ptr Window::CreateCheckBox(int posX, int posY, int sizeX, int sizeY, - const char* title, int id, bool state) { - std::unique_ptr child(new Window(this, "Button", title)); + const wchar_t* title, int id, bool state) { + std::unique_ptr child(new Window(this, L"Button", title)); child->Create(WS_CHILD | WS_VISIBLE | BS_CHECKBOX | WS_TABSTOP, posX, posY, sizeX, sizeY, id); @@ -162,8 +162,8 @@ std::unique_ptr Window::CreateCheckBox(int posX, int posY, int sizeX, in } std::unique_ptr Window::CreateComboBox(int posX, int posY, int sizeX, int sizeY, - const char* title, int id) { - std::unique_ptr child(new Window(this, "Combobox", title)); + const wchar_t* title, int id) { + std::unique_ptr child(new Window(this, L"Combobox", title)); child->Create(WS_CHILD | WS_VISIBLE | CBS_DROPDOWNLIST | WS_TABSTOP, posX, posY, sizeX, sizeY, id); @@ -194,12 +194,12 @@ bool Window::IsTextEmpty() const { return (len <= 0); } -void Window::ListAddColumn(const std::string& name, int index, int width) { +void Window::ListAddColumn(const std::wstring& name, int index, int width) { LVCOLUMN lvc; lvc.mask = LVCF_FMT | LVCF_WIDTH | LVCF_TEXT | LVCF_SUBITEM; lvc.fmt = LVCFMT_LEFT; lvc.cx = width; - lvc.pszText = const_cast(name.c_str()); + lvc.pszText = const_cast(name.c_str()); lvc.iSubItem = index; if (ListView_InsertColumn(handle, index, &lvc) == -1) { @@ -209,10 +209,10 @@ void Window::ListAddColumn(const std::string& name, int index, int width) { } } -void Window::ListAddItem(const std::vector& items) { +void Window::ListAddItem(const std::vector& items) { LVITEM lvi = {0}; lvi.mask = LVIF_TEXT; - lvi.pszText = const_cast(items[0].c_str()); + lvi.pszText = const_cast(items[0].c_str()); int ret = ListView_InsertItem(handle, &lvi); if (ret < 0) { @@ -223,7 +223,7 @@ void Window::ListAddItem(const std::vector& items) { for (size_t i = 1; i < items.size(); ++i) { ListView_SetItemText(handle, ret, static_cast(i), - const_cast(items[i].c_str())); + const_cast(items[i].c_str())); } } @@ -238,15 +238,15 @@ void Window::ListDeleteSelectedItem() { } } -std::vector > Window::ListGetAll() { +std::vector > Window::ListGetAll() { #define BUF_LEN 1024 - char buf[BUF_LEN]; + wchar_t buf[BUF_LEN]; - std::vector > values; + std::vector > values; const int numColumns = Header_GetItemCount(ListView_GetHeader(handle)); const int numItems = ListView_GetItemCount(handle); for (int i = 0; i < numItems; ++i) { - std::vector row; + std::vector row; for (int j = 0; j < numColumns; ++j) { ListView_GetItemText(handle, i, j, buf, BUF_LEN); row.emplace_back(buf); @@ -257,11 +257,11 @@ std::vector > Window::ListGetAll() { return values; } -void Window::AddTab(const std::string& name, int index) { +void Window::AddTab(const std::wstring& name, int index) { TCITEM tabControlItem; tabControlItem.mask = TCIF_TEXT | TCIF_IMAGE; tabControlItem.iImage = -1; - tabControlItem.pszText = const_cast(name.c_str()); + tabControlItem.pszText = const_cast(name.c_str()); if (TabCtrl_InsertItem(handle, index, &tabControlItem) == -1) { std::stringstream buf; buf << "Can not add tab, error code: " << GetLastError(); @@ -269,7 +269,7 @@ void Window::AddTab(const std::string& name, int index) { } } -void Window::GetText(std::string& text) const { +void Window::GetText(std::wstring& text) const { if (!IsEnabled()) { text.clear(); @@ -292,7 +292,7 @@ void Window::GetText(std::string& text) const { boost::algorithm::trim(text); } -void Window::SetText(const std::string& text) const { +void Window::SetText(const std::wstring& text) const { SNDMSG(handle, WM_SETTEXT, 0, reinterpret_cast(text.c_str())); } @@ -304,7 +304,7 @@ void Window::SetChecked(bool state) { Button_SetCheck(handle, state ? BST_CHECKED : BST_UNCHECKED); } -void Window::AddString(const std::string& str) { +void Window::AddString(const std::wstring& str) { SNDMSG(handle, CB_ADDSTRING, 0, reinterpret_cast(str.c_str())); } diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/utils.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/utils.cc index 33a11aaabed..945d3c9f0da 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/utils.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/utils.cc @@ -59,6 +59,10 @@ odbcabstraction::SqlDataType GetDefaultSqlVarcharType(bool useWideChar) { return useWideChar ? odbcabstraction::SqlDataType_WVARCHAR : odbcabstraction::SqlDataType_VARCHAR; } +odbcabstraction::SqlDataType GetDefaultSqlLongVarcharType(bool useWideChar) { + return useWideChar ? odbcabstraction::SqlDataType_WLONGVARCHAR + : odbcabstraction::SqlDataType_LONGVARCHAR; +} odbcabstraction::CDataType GetDefaultCCharType(bool useWideChar) { return useWideChar ? odbcabstraction::CDataType_WCHAR : odbcabstraction::CDataType_CHAR; } @@ -155,6 +159,9 @@ SqlDataType EnsureRightSqlCharType(SqlDataType data_type, bool useWideChar) { case odbcabstraction::SqlDataType_VARCHAR: case odbcabstraction::SqlDataType_WVARCHAR: return GetDefaultSqlVarcharType(useWideChar); + case odbcabstraction::SqlDataType_LONGVARCHAR: + case odbcabstraction::SqlDataType_WLONGVARCHAR: + return GetDefaultSqlLongVarcharType(useWideChar); default: return data_type; } diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/utils.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/utils.h index 586cfb22a30..8b3e14599a7 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/utils.h +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/utils.h @@ -17,13 +17,13 @@ #pragma once -#include -#include -#include #include #include #include #include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/exceptions.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/types.h" +#include "arrow/flight/types.h" namespace driver { namespace flight_sql { diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/utils_test.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/utils_test.cc index 1575bf09fab..f5d61da50bf 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/utils_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/utils_test.cc @@ -19,6 +19,7 @@ #include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/calendar_utils.h" +#include "arrow/compute/initialize.h" #include "arrow/testing/builder.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/util.h" @@ -27,6 +28,13 @@ namespace driver { namespace flight_sql { +class UtilTestsWithCompute : public ::testing::Test { + public: + // This must be done before using the compute kernels in order to + // register them to the FunctionRegistry. + void SetUp() override { ASSERT_OK(arrow::compute::Initialize()); } +}; + void AssertConvertedArray(const std::shared_ptr& expected_array, const std::shared_ptr& converted_array, uint64_t size, arrow::Type::type arrow_type) { @@ -80,7 +88,7 @@ void TestTime64ArrayConversion(const std::vector& input, AssertConvertedArray(expected_array, converted_array, input.size(), arrow_type); } -TEST(Utils, Time32ToTimeStampArray) { +TEST_F(UtilTestsWithCompute, Time32ToTimeStampArray) { std::vector input_data = {14896, 17820}; const auto seconds_from_epoch = odbcabstraction::GetTodayTimeFromEpoch(); @@ -100,7 +108,7 @@ TEST(Utils, Time32ToTimeStampArray) { arrow::Type::TIMESTAMP); } -TEST(Utils, Time64ToTimeStampArray) { +TEST_F(UtilTestsWithCompute, Time64ToTimeStampArray) { std::vector input_data = {1579489200000, 1646881200000}; const auto seconds_from_epoch = odbcabstraction::GetTodayTimeFromEpoch(); @@ -120,7 +128,7 @@ TEST(Utils, Time64ToTimeStampArray) { arrow::Type::TIMESTAMP); } -TEST(Utils, StringToDateArray) { +TEST_F(UtilTestsWithCompute, StringToDateArray) { std::shared_ptr expected; arrow::ArrayFromVector({1579489200000, 1646881200000}, &expected); @@ -129,7 +137,7 @@ TEST(Utils, StringToDateArray) { odbcabstraction::CDataType_DATE, arrow::Type::DATE64); } -TEST(Utils, StringToTimeArray) { +TEST_F(UtilTestsWithCompute, StringToTimeArray) { std::shared_ptr expected; arrow::ArrayFromVector( time64(arrow::TimeUnit::MICRO), {36000000000, 43200000000}, &expected); diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/win_system_dsn.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/win_system_dsn.cc new file mode 100644 index 00000000000..2017936dd90 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/win_system_dsn.cc @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// platform.h includes windows.h, so it needs to be included +// before winuser.h +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h" + +#include +#include + +#include "arrow/result.h" +#include "arrow/util/utf8.h" + +#include "arrow/flight/sql/odbc/flight_sql/flight_sql_connection.h" +#include "arrow/flight/sql/odbc/flight_sql/include/flight_sql/config/configuration.h" +#include "arrow/flight/sql/odbc/flight_sql/include/flight_sql/config/connection_string_parser.h" +#include "arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/dsn_configuration_window.h" +#include "arrow/flight/sql/odbc/flight_sql/include/flight_sql/ui/window.h" +#include "arrow/flight/sql/odbc/flight_sql/system_dsn.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/exceptions.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/logger.h" + +#include +#include +#include +#include + +using driver::flight_sql::FlightSqlConnection; +using driver::flight_sql::config::Configuration; +using driver::flight_sql::config::ConnectionStringParser; +using driver::flight_sql::config::DsnConfigurationWindow; +using driver::flight_sql::config::Result; +using driver::flight_sql::config::Window; +using driver::odbcabstraction::DriverException; + +bool DisplayConnectionWindow(void* windowParent, Configuration& config) { + HWND hwndParent = (HWND)windowParent; + + if (!hwndParent) return true; + + try { + Window parent(hwndParent); + DsnConfigurationWindow window(&parent, config); + + window.Create(); + + window.Show(); + window.Update(); + + return ProcessMessages(window) == Result::OK; + } catch (const DriverException& err) { + std::stringstream buf; + buf << "SQL State: " << err.GetSqlState() << ", Message: " << err.GetMessageText() + << ", Code: " << err.GetNativeError(); + std::wstring wMessage = arrow::util::UTF8ToWideString(buf.str()).ValueOr(L""); + MessageBox(NULL, wMessage.c_str(), L"Error!", MB_ICONEXCLAMATION | MB_OK); + + std::wstring wMessageText = + arrow::util::UTF8ToWideString(err.GetMessageText()).ValueOr(L""); + SQLPostInstallerError(err.GetNativeError(), wMessageText.c_str()); + } + + return false; +} + +bool DisplayConnectionWindow(void* windowParent, Configuration& config, + Connection::ConnPropertyMap& properties) { + for (const auto& [key, value] : properties) { + config.Set(key, value); + } + + if (DisplayConnectionWindow(windowParent, config)) { + properties = config.GetProperties(); + return true; + } else { + LOG_INFO("Dialog is cancelled by user"); + return false; + } +} + +BOOL INSTAPI ConfigDSNW(HWND hwndParent, WORD req, LPCWSTR wDriver, LPCWSTR wAttributes) { + Configuration config; + ConnectionStringParser parser(config); + std::string attributes = + arrow::util::WideStringToUTF8(std::wstring(wAttributes)).ValueOr(""); + parser.ParseConfigAttributes(attributes.c_str()); + + switch (req) { + case ODBC_ADD_DSN: { + config.LoadDefaults(); + if (!DisplayConnectionWindow(hwndParent, config) || !RegisterDsn(config, wDriver)) + return FALSE; + + break; + } + + case ODBC_CONFIG_DSN: { + const std::string& dsn = config.Get(FlightSqlConnection::DSN); + std::wstring wDsn = arrow::util::UTF8ToWideString(dsn).ValueOr(L""); + if (!SQLValidDSN(wDsn.c_str())) return FALSE; + + Configuration loaded(config); + loaded.LoadDsn(dsn); + + if (!DisplayConnectionWindow(hwndParent, loaded) || !UnregisterDsn(wDsn.c_str()) || + !RegisterDsn(loaded, wDriver)) + return FALSE; + + break; + } + + case ODBC_REMOVE_DSN: { + const std::string& dsn = config.Get(FlightSqlConnection::DSN); + std::wstring wDsn = arrow::util::UTF8ToWideString(dsn).ValueOr(L""); + if (!SQLValidDSN(wDsn.c_str()) || !UnregisterDsn(wDsn)) return FALSE; + + break; + } + + default: + return FALSE; + } + + return TRUE; +} diff --git a/cpp/src/arrow/flight/sql/odbc/install/arrow-flight-sql-odbc-patch.xml b/cpp/src/arrow/flight/sql/odbc/install/arrow-flight-sql-odbc-patch.xml new file mode 100644 index 00000000000..f1a63ce5d3b --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/install/arrow-flight-sql-odbc-patch.xml @@ -0,0 +1,22 @@ + + + + + + + diff --git a/cpp/src/arrow/flight/sql/odbc/install/arrow-flight-sql-odbc.wxs b/cpp/src/arrow/flight/sql/odbc/install/arrow-flight-sql-odbc.wxs new file mode 100644 index 00000000000..bd0216aa766 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/install/arrow-flight-sql-odbc.wxs @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/cpp/src/arrow/flight/sql/odbc/install/arrow-wix-banner.bmp b/cpp/src/arrow/flight/sql/odbc/install/arrow-wix-banner.bmp new file mode 100644 index 00000000000..0c82036f4ec Binary files /dev/null and b/cpp/src/arrow/flight/sql/odbc/install/arrow-wix-banner.bmp differ diff --git a/cpp/src/arrow/flight/sql/odbc/install/install_amd64.cmd b/cpp/src/arrow/flight/sql/odbc/install/install_amd64.cmd new file mode 100644 index 00000000000..b1fd85d578e --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/install/install_amd64.cmd @@ -0,0 +1,53 @@ +@REM Licensed to the Apache Software Foundation (ASF) under one +@REM or more contributor license agreements. See the NOTICE file +@REM distributed with this work for additional information +@REM regarding copyright ownership. The ASF licenses this file +@REM to you under the Apache License, Version 2.0 (the +@REM "License"); you may not use this file except in compliance +@REM with the License. You may obtain a copy of the License at +@REM +@REM http://www.apache.org/licenses/LICENSE-2.0 +@REM +@REM Unless required by applicable law or agreed to in writing, +@REM software distributed under the License is distributed on an +@REM "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +@REM KIND, either express or implied. See the License for the +@REM specific language governing permissions and limitations +@REM under the License. + +@echo off + +set ODBC_AMD64=%1 + +@REM enable delayed variable expansion to make environment variables enclosed with "!" to be evaluated +@REM when the command is executed instead of when the command is parsed +setlocal enableextensions enabledelayedexpansion + +if [%ODBC_AMD64%] == [] ( + echo error: 64-bit driver is not specified. Call format: install_amd64 abs_path_to_64_bit_driver + pause + exit /b 1 +) + +if exist %ODBC_AMD64% ( + for %%i IN (%ODBC_AMD64%) DO IF EXIST %%~si\NUL ( + echo warning: The path you have specified seems to be a directory. Note that you have to specify path to driver file itself instead. + ) + echo Installing 64-bit driver: %ODBC_AMD64% + reg add "HKEY_LOCAL_MACHINE\SOFTWARE\ODBC\ODBCINST.INI\Apache Arrow Flight SQL ODBC Driver" /v DriverODBCVer /t REG_SZ /d "03.80" /f + reg add "HKEY_LOCAL_MACHINE\SOFTWARE\ODBC\ODBCINST.INI\Apache Arrow Flight SQL ODBC Driver" /v UsageCount /t REG_DWORD /d 00000001 /f + reg add "HKEY_LOCAL_MACHINE\SOFTWARE\ODBC\ODBCINST.INI\Apache Arrow Flight SQL ODBC Driver" /v Driver /t REG_SZ /d %ODBC_AMD64% /f + reg add "HKEY_LOCAL_MACHINE\SOFTWARE\ODBC\ODBCINST.INI\Apache Arrow Flight SQL ODBC Driver" /v Setup /t REG_SZ /d %ODBC_AMD64% /f + reg add "HKEY_LOCAL_MACHINE\SOFTWARE\ODBC\ODBCINST.INI\ODBC Drivers" /v "Apache Arrow Flight SQL ODBC Driver" /t REG_SZ /d "Installed" /f + + IF !ERRORLEVEL! NEQ 0 ( + echo Error occurred while registering 64-bit driver. Exiting. + echo ERRORLEVEL: !ERRORLEVEL! + exit !ERRORLEVEL! + ) +) else ( + echo 64-bit driver can not be found: %ODBC_AMD64% + echo Call format: install_amd64 abs_path_to_64_bit_driver + pause + exit /b 1 +) diff --git a/cpp/src/arrow/flight/sql/odbc/install/versioninfo.rc.in b/cpp/src/arrow/flight/sql/odbc/install/versioninfo.rc.in new file mode 100644 index 00000000000..13024a7a50b --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/install/versioninfo.rc.in @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#define VER_FILEVERSION @VER_FILEVERSION@ +#define VER_FILEVERSION_STR "@VER_FILEVERSION_STR@\0" + +#define VER_PRODUCTVERSION @VER_FILEVERSION@ +#define VER_PRODUCTVERSION_STR "@VER_FILEVERSION_STR@\0" + +#define VER_COMPANYNAME_STR "@VER_COMPANYNAME_STR@\0" +#define VER_PRODUCTNAME_STR "@VER_PRODUCTNAME_STR@\0" + +1 VERSIONINFO +FILEVERSION VER_FILEVERSION +PRODUCTVERSION VER_PRODUCTVERSION +BEGIN + BLOCK "StringFileInfo" + BEGIN + BLOCK "040904E4" + BEGIN + VALUE "CompanyName", VER_COMPANYNAME_STR + VALUE "FileVersion", VER_FILEVERSION_STR + VALUE "ProductName", VER_PRODUCTNAME_STR + VALUE "ProductVersion", VER_PRODUCTVERSION_STR + END + END + + BLOCK "VarFileInfo" + BEGIN + /* The following line should only be modified for localized versions. */ + /* It consists of any number of WORD,WORD pairs, with each pair */ + /* describing a language,codepage combination supported by the file. */ + /* */ + /* For example, a file might have values "0x409,1252" indicating that it */ + /* supports English language (0x409) in the Windows ANSI codepage (1252). */ + + VALUE "Translation", 0x409, 1252 + + END +END diff --git a/cpp/src/arrow/flight/sql/odbc/odbc.def b/cpp/src/arrow/flight/sql/odbc/odbc.def new file mode 100644 index 00000000000..8ba5b3fff78 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbc.def @@ -0,0 +1,61 @@ +; Licensed to the Apache Software Foundation (ASF) under one +; or more contributor license agreements. See the NOTICE file +; distributed with this work for additional information +; regarding copyright ownership. The ASF licenses this file +; to you under the Apache License, Version 2.0 (the +; "License"); you may not use this file except in compliance +; with the License. You may obtain a copy of the License at +; +; http://www.apache.org/licenses/LICENSE-2.0 +; +; Unless required by applicable law or agreed to in writing, +; software distributed under the License is distributed on an +; "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +; KIND, either express or implied. See the License for the +; specific language governing permissions and limitations +; under the License. + +LIBRARY arrow_flight_sql_odbc +EXPORTS + ConfigDSNW + SQLAllocConnect + SQLAllocEnv + SQLAllocHandle + SQLAllocStmt + SQLBindCol + SQLCancel + SQLCloseCursor + SQLColAttributeW + SQLColumnsW + SQLConnectW + SQLDescribeColW + SQLDisconnect + SQLDriverConnectW + SQLExecDirectW + SQLExecute + SQLExtendedFetch + SQLFetch + SQLFetchScroll + SQLForeignKeysW + SQLFreeEnv + SQLFreeConnect + SQLFreeHandle + SQLFreeStmt + SQLGetConnectAttrW + SQLGetData + SQLGetDiagFieldW + SQLGetDiagRecW + SQLGetEnvAttr + SQLGetInfoW + SQLGetStmtAttrW + SQLGetTypeInfoW + SQLRowCount + SQLMoreResults + SQLNativeSqlW + SQLNumResultCols + SQLPrepareW + SQLPrimaryKeysW + SQLSetConnectAttrW + SQLSetEnvAttr + SQLSetStmtAttrW + SQLTablesW diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc new file mode 100644 index 00000000000..82a167b3c16 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc @@ -0,0 +1,1521 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// flight_sql_connection.h needs to be included first due to conflicts with windows.h +#include "arrow/flight/sql/odbc/flight_sql/flight_sql_connection.h" + +#include "arrow/flight/sql/odbc/flight_sql/include/flight_sql/config/configuration.h" +#include "arrow/flight/sql/odbc/flight_sql/include/flight_sql/flight_sql_driver.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/diagnostics.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/logger.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/attribute_utils.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/encoding_utils.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_connection.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_descriptor.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_environment.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_statement.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/connection.h" + +#if defined _WIN32 || defined _WIN64 +// For displaying DSN Window +# include "arrow/flight/sql/odbc/flight_sql/system_dsn.h" +#endif + +// odbc_api includes windows.h, which needs to be put behind winsock2.h. +// odbc_environment.h includes winsock2.h +#include "arrow/flight/sql/odbc/odbc_api.h" + +namespace arrow { +SQLRETURN SQLAllocHandle(SQLSMALLINT type, SQLHANDLE parent, SQLHANDLE* result) { + LOG_DEBUG("SQLAllocHandle called with type: {}, parent: {}, result: {}", type, parent, + fmt::ptr(result)); + + *result = nullptr; + + switch (type) { + case SQL_HANDLE_ENV: { + using driver::flight_sql::FlightSqlDriver; + using ODBC::ODBCEnvironment; + + *result = SQL_NULL_HENV; + + try { + static std::shared_ptr odbc_driver = + std::make_shared(); + *result = reinterpret_cast(new ODBCEnvironment(odbc_driver)); + + return SQL_SUCCESS; + } catch (const std::bad_alloc&) { + // allocating environment failed so cannot log diagnostic error here + return SQL_ERROR; + } + } + + case SQL_HANDLE_DBC: { + using ODBC::ODBCConnection; + using ODBC::ODBCEnvironment; + + *result = SQL_NULL_HDBC; + + ODBCEnvironment* environment = reinterpret_cast(parent); + + return ODBCEnvironment::ExecuteWithDiagnostics(environment, SQL_ERROR, [=]() { + std::shared_ptr conn = environment->CreateConnection(); + + if (conn) { + *result = reinterpret_cast(conn.get()); + + return SQL_SUCCESS; + } + + return SQL_ERROR; + }); + } + + case SQL_HANDLE_STMT: { + using ODBC::ODBCConnection; + using ODBC::ODBCStatement; + + *result = SQL_NULL_HSTMT; + + ODBCConnection* connection = reinterpret_cast(parent); + + return ODBCConnection::ExecuteWithDiagnostics(connection, SQL_ERROR, [=]() { + std::shared_ptr statement = connection->createStatement(); + + if (statement) { + *result = reinterpret_cast(statement.get()); + + return SQL_SUCCESS; + } + + return SQL_ERROR; + }); + } + + case SQL_HANDLE_DESC: { + using ODBC::ODBCConnection; + using ODBC::ODBCDescriptor; + + *result = SQL_NULL_HDESC; + + ODBCConnection* connection = reinterpret_cast(parent); + + return ODBCConnection::ExecuteWithDiagnostics(connection, SQL_ERROR, [=]() { + std::shared_ptr descriptor = connection->createDescriptor(); + + if (descriptor) { + *result = reinterpret_cast(descriptor.get()); + + return SQL_SUCCESS; + } + + return SQL_ERROR; + }); + } + + default: + break; + } + + return SQL_ERROR; +} + +SQLRETURN SQLFreeHandle(SQLSMALLINT type, SQLHANDLE handle) { + LOG_DEBUG("SQLFreeHandle called with type: {}, handle: {}", type, handle); + + switch (type) { + case SQL_HANDLE_ENV: { + using ODBC::ODBCEnvironment; + + ODBCEnvironment* environment = reinterpret_cast(handle); + + if (!environment) { + return SQL_INVALID_HANDLE; + } + + delete environment; + + return SQL_SUCCESS; + } + + case SQL_HANDLE_DBC: { + using ODBC::ODBCConnection; + + ODBCConnection* conn = reinterpret_cast(handle); + + if (!conn) { + return SQL_INVALID_HANDLE; + } + + conn->releaseConnection(); + + return SQL_SUCCESS; + } + + case SQL_HANDLE_STMT: { + using ODBC::ODBCStatement; + + ODBCStatement* statement = reinterpret_cast(handle); + + if (!statement) { + return SQL_INVALID_HANDLE; + } + + statement->releaseStatement(); + + return SQL_SUCCESS; + } + + case SQL_HANDLE_DESC: { + using ODBC::ODBCDescriptor; + + ODBCDescriptor* descriptor = reinterpret_cast(handle); + + if (!descriptor) { + return SQL_INVALID_HANDLE; + } + + descriptor->ReleaseDescriptor(); + + return SQL_SUCCESS; + } + + default: + break; + } + + return SQL_ERROR; +} + +SQLRETURN SQLFreeStmt(SQLHSTMT handle, SQLUSMALLINT option) { + switch (option) { + case SQL_CLOSE: { + using ODBC::ODBCStatement; + + return ODBCStatement::ExecuteWithDiagnostics(handle, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(handle); + + // Close cursor with suppressErrors set to true + statement->closeCursor(true); + + return SQL_SUCCESS; + }); + } + + case SQL_DROP: { + return SQLFreeHandle(SQL_HANDLE_STMT, handle); + } + + case SQL_UNBIND: { + using ODBC::ODBCDescriptor; + using ODBC::ODBCStatement; + return ODBCStatement::ExecuteWithDiagnostics(handle, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(handle); + ODBCDescriptor* ard = statement->GetARD(); + // Unbind columns + ard->SetHeaderField(SQL_DESC_COUNT, (void*)0, 0); + return SQL_SUCCESS; + }); + } + + // SQLBindParameter is not supported + case SQL_RESET_PARAMS: { + return SQL_SUCCESS; + } + } + + return SQL_ERROR; +} + +inline bool IsValidStringFieldArgs(SQLPOINTER diagInfoPtr, SQLSMALLINT bufferLength, + SQLSMALLINT* stringLengthPtr, bool isUnicode) { + const SQLSMALLINT charSize = isUnicode ? GetSqlWCharSize() : sizeof(char); + const bool hasValidBuffer = + diagInfoPtr && bufferLength >= 0 && bufferLength % charSize == 0; + + // regardless of capacity return false if invalid + if (diagInfoPtr && !hasValidBuffer) { + return false; + } + + return hasValidBuffer || stringLengthPtr; +} + +SQLRETURN SQLGetDiagField(SQLSMALLINT handleType, SQLHANDLE handle, SQLSMALLINT recNumber, + SQLSMALLINT diagIdentifier, SQLPOINTER diagInfoPtr, + SQLSMALLINT bufferLength, SQLSMALLINT* stringLengthPtr) { + // TODO: Implement additional fields types + // https://github.com/apache/arrow/issues/46573 + LOG_DEBUG( + "SQLGetDiagFieldW called with handleType: {}, handle: {}, recNumber: {}, " + "diagIdentifier: {}, diagInfoPtr: {}, bufferLength: {}, stringLengthPtr: {}", + handleType, handle, recNumber, diagIdentifier, diagInfoPtr, bufferLength, + fmt::ptr(stringLengthPtr)); + + using driver::odbcabstraction::Diagnostics; + using ODBC::GetStringAttribute; + using ODBC::ODBCConnection; + using ODBC::ODBCDescriptor; + using ODBC::ODBCEnvironment; + using ODBC::ODBCStatement; + + if (!handle) { + return SQL_INVALID_HANDLE; + } + + if (!diagInfoPtr && !stringLengthPtr) { + return SQL_ERROR; + } + + // If buffer length derived from null terminated string + if (diagInfoPtr && bufferLength == SQL_NTS) { + const wchar_t* str = reinterpret_cast(diagInfoPtr); + bufferLength = wcslen(str) * driver::odbcabstraction::GetSqlWCharSize(); + } + + // Set character type to be Unicode by default + const bool isUnicode = true; + Diagnostics* diagnostics = nullptr; + + switch (handleType) { + case SQL_HANDLE_ENV: { + ODBCEnvironment* environment = reinterpret_cast(handle); + diagnostics = &environment->GetDiagnostics(); + break; + } + + case SQL_HANDLE_DBC: { + ODBCConnection* connection = reinterpret_cast(handle); + diagnostics = &connection->GetDiagnostics(); + break; + } + + case SQL_HANDLE_DESC: { + ODBCDescriptor* descriptor = reinterpret_cast(handle); + diagnostics = &descriptor->GetDiagnostics(); + break; + } + + case SQL_HANDLE_STMT: { + ODBCStatement* statement = reinterpret_cast(handle); + diagnostics = &statement->GetDiagnostics(); + break; + } + + default: + return SQL_ERROR; + } + + if (!diagnostics) { + return SQL_ERROR; + } + + // Retrieve and return if header level diagnostics + switch (diagIdentifier) { + case SQL_DIAG_NUMBER: { + if (diagInfoPtr) { + *static_cast(diagInfoPtr) = + static_cast(diagnostics->GetRecordCount()); + } + + if (stringLengthPtr) { + *stringLengthPtr = sizeof(SQLINTEGER); + } + + return SQL_SUCCESS; + } + + // TODO implement return code function + case SQL_DIAG_RETURNCODE: { + return SQL_SUCCESS; + } + + case SQL_DIAG_CURSOR_ROW_COUNT: { + if (handleType == SQL_HANDLE_STMT) { + if (diagInfoPtr) { + // Will always be 0 if only SELECT supported + *static_cast(diagInfoPtr) = 0; + } + + if (stringLengthPtr) { + *stringLengthPtr = sizeof(SQLLEN); + } + + return SQL_SUCCESS; + } + + return SQL_ERROR; + } + + // Not supported + case SQL_DIAG_DYNAMIC_FUNCTION: + case SQL_DIAG_DYNAMIC_FUNCTION_CODE: { + if (handleType == SQL_HANDLE_STMT) { + return SQL_SUCCESS; + } + + return SQL_ERROR; + } + + case SQL_DIAG_ROW_COUNT: { + if (handleType == SQL_HANDLE_STMT) { + if (diagInfoPtr) { + // Will always be 0 if only SELECT is supported + *static_cast(diagInfoPtr) = 0; + } + + if (stringLengthPtr) { + *stringLengthPtr = sizeof(SQLLEN); + } + + return SQL_SUCCESS; + } + + return SQL_ERROR; + } + } + + // If not a diagnostic header field then the record number must be 1 or greater + if (recNumber < 1) { + return SQL_ERROR; + } + + // Retrieve record level diagnostics from specified 1 based record + const uint32_t recordIndex = static_cast(recNumber - 1); + if (!diagnostics->HasRecord(recordIndex)) { + return SQL_NO_DATA; + } + + // Retrieve record field data + switch (diagIdentifier) { + case SQL_DIAG_MESSAGE_TEXT: { + if (IsValidStringFieldArgs(diagInfoPtr, bufferLength, stringLengthPtr, isUnicode)) { + const std::string& message = diagnostics->GetMessageText(recordIndex); + return GetStringAttribute(isUnicode, message, true, diagInfoPtr, bufferLength, + stringLengthPtr, *diagnostics); + } + + return SQL_ERROR; + } + + case SQL_DIAG_NATIVE: { + if (diagInfoPtr) { + *static_cast(diagInfoPtr) = diagnostics->GetNativeError(recordIndex); + } + + if (stringLengthPtr) { + *stringLengthPtr = sizeof(SQLINTEGER); + } + + return SQL_SUCCESS; + } + + case SQL_DIAG_SERVER_NAME: { + if (IsValidStringFieldArgs(diagInfoPtr, bufferLength, stringLengthPtr, isUnicode)) { + switch (handleType) { + case SQL_HANDLE_DBC: { + ODBCConnection* connection = reinterpret_cast(handle); + std::string dsn = connection->GetDSN(); + return GetStringAttribute(isUnicode, dsn, true, diagInfoPtr, bufferLength, + stringLengthPtr, *diagnostics); + } + + case SQL_HANDLE_DESC: { + ODBCDescriptor* descriptor = reinterpret_cast(handle); + ODBCConnection* connection = &descriptor->GetConnection(); + std::string dsn = connection->GetDSN(); + return GetStringAttribute(isUnicode, dsn, true, diagInfoPtr, bufferLength, + stringLengthPtr, *diagnostics); + break; + } + + case SQL_HANDLE_STMT: { + ODBCStatement* statement = reinterpret_cast(handle); + ODBCConnection* connection = &statement->GetConnection(); + std::string dsn = connection->GetDSN(); + return GetStringAttribute(isUnicode, dsn, true, diagInfoPtr, bufferLength, + stringLengthPtr, *diagnostics); + } + + default: + return SQL_ERROR; + } + } + + return SQL_ERROR; + } + + case SQL_DIAG_SQLSTATE: { + if (IsValidStringFieldArgs(diagInfoPtr, bufferLength, stringLengthPtr, isUnicode)) { + const std::string& state = diagnostics->GetSQLState(recordIndex); + return GetStringAttribute(isUnicode, state, true, diagInfoPtr, bufferLength, + stringLengthPtr, *diagnostics); + } + + return SQL_ERROR; + } + + // Return valid dummy variable for unimplemented field + case SQL_DIAG_COLUMN_NUMBER: { + if (diagInfoPtr) { + *static_cast(diagInfoPtr) = SQL_NO_COLUMN_NUMBER; + } + + if (stringLengthPtr) { + *stringLengthPtr = sizeof(SQLINTEGER); + } + + return SQL_SUCCESS; + } + + // Return empty string dummy variable for unimplemented fields + case SQL_DIAG_CLASS_ORIGIN: + case SQL_DIAG_CONNECTION_NAME: + case SQL_DIAG_SUBCLASS_ORIGIN: { + if (IsValidStringFieldArgs(diagInfoPtr, bufferLength, stringLengthPtr, isUnicode)) { + return GetStringAttribute(isUnicode, "", true, diagInfoPtr, bufferLength, + stringLengthPtr, *diagnostics); + } + + return SQL_ERROR; + } + + // Return valid dummy variable for unimplemented field + case SQL_DIAG_ROW_NUMBER: { + if (diagInfoPtr) { + *static_cast(diagInfoPtr) = SQL_NO_ROW_NUMBER; + } + + if (stringLengthPtr) { + *stringLengthPtr = sizeof(SQLLEN); + } + + return SQL_SUCCESS; + } + + default: { + return SQL_ERROR; + } + } + + return SQL_ERROR; +} + +SQLRETURN SQLGetDiagRec(SQLSMALLINT handleType, SQLHANDLE handle, SQLSMALLINT recNumber, + SQLWCHAR* sqlState, SQLINTEGER* nativeErrorPtr, + SQLWCHAR* messageText, SQLSMALLINT bufferLength, + SQLSMALLINT* textLengthPtr) { + LOG_DEBUG( + "SQLGetDiagRecW called with handleType: {}, handle: {}, recNumber: {}, " + "sqlState: {}, nativeErrorPtr: {}, messageText: {}, bufferLength: {}, " + "textLengthPtr: {}", + handleType, handle, recNumber, fmt::ptr(sqlState), fmt::ptr(nativeErrorPtr), + fmt::ptr(messageText), bufferLength, fmt::ptr(textLengthPtr)); + + using driver::odbcabstraction::Diagnostics; + using ODBC::GetStringAttribute; + using ODBC::ODBCConnection; + using ODBC::ODBCDescriptor; + using ODBC::ODBCEnvironment; + using ODBC::ODBCStatement; + + if (!handle) { + return SQL_INVALID_HANDLE; + } + + // Record number must be greater or equal to 1 + if (recNumber < 1 || bufferLength < 0) { + return SQL_ERROR; + } + + // Set character type to be Unicode by default + const bool isUnicode = true; + Diagnostics* diagnostics = nullptr; + + switch (handleType) { + case SQL_HANDLE_ENV: { + auto* environment = ODBCEnvironment::of(handle); + diagnostics = &environment->GetDiagnostics(); + break; + } + + case SQL_HANDLE_DBC: { + auto* connection = ODBCConnection::of(handle); + diagnostics = &connection->GetDiagnostics(); + break; + } + + case SQL_HANDLE_DESC: { + auto* descriptor = ODBCDescriptor::of(handle); + diagnostics = &descriptor->GetDiagnostics(); + break; + } + + case SQL_HANDLE_STMT: { + auto* statement = ODBCStatement::of(handle); + diagnostics = &statement->GetDiagnostics(); + break; + } + + default: + return SQL_INVALID_HANDLE; + } + + if (!diagnostics) { + return SQL_ERROR; + } + + // Convert from ODBC 1 based record number to internal diagnostics 0 indexed storage + const size_t recordIndex = static_cast(recNumber - 1); + if (!diagnostics->HasRecord(recordIndex)) { + return SQL_NO_DATA; + } + + if (sqlState) { + // The length of the sql state is always 5 characters plus null + SQLSMALLINT size = 6; + const std::string& state = diagnostics->GetSQLState(recordIndex); + GetStringAttribute(isUnicode, state, false, sqlState, size, &size, *diagnostics); + } + + if (nativeErrorPtr) { + *nativeErrorPtr = diagnostics->GetNativeError(recordIndex); + } + + if (messageText || textLengthPtr) { + const std::string& message = diagnostics->GetMessageText(recordIndex); + return GetStringAttribute(isUnicode, message, false, messageText, bufferLength, + textLengthPtr, *diagnostics); + } + + return SQL_SUCCESS; +} + +SQLRETURN SQLGetEnvAttr(SQLHENV env, SQLINTEGER attr, SQLPOINTER valuePtr, + SQLINTEGER bufferLength, SQLINTEGER* strLenPtr) { + LOG_DEBUG( + "SQLGetEnvAttr called with env: {}, attr: {}, valuePtr: {}, " + "bufferLength: {}, strLenPtr: {}", + env, attr, valuePtr, bufferLength, fmt::ptr(strLenPtr)); + + using driver::odbcabstraction::DriverException; + using ODBC::ODBCEnvironment; + + ODBCEnvironment* environment = reinterpret_cast(env); + + return ODBCEnvironment::ExecuteWithDiagnostics(environment, SQL_ERROR, [=]() { + switch (attr) { + case SQL_ATTR_ODBC_VERSION: { + if (!valuePtr && !strLenPtr) { + throw DriverException("Invalid null pointer for attribute.", "HY000"); + } + + if (valuePtr) { + SQLINTEGER* value = reinterpret_cast(valuePtr); + *value = static_cast(environment->getODBCVersion()); + } + + if (strLenPtr) { + *strLenPtr = sizeof(SQLINTEGER); + } + + return SQL_SUCCESS; + } + + case SQL_ATTR_OUTPUT_NTS: { + if (!valuePtr && !strLenPtr) { + throw DriverException("Invalid null pointer for attribute.", "HY000"); + } + + if (valuePtr) { + // output nts always returns SQL_TRUE + SQLINTEGER* value = reinterpret_cast(valuePtr); + *value = SQL_TRUE; + } + + if (strLenPtr) { + *strLenPtr = sizeof(SQLINTEGER); + } + + return SQL_SUCCESS; + } + + case SQL_ATTR_CONNECTION_POOLING: { + throw DriverException("Optional feature not supported.", "HYC00"); + } + + default: { + throw DriverException("Invalid attribute", "HYC00"); + } + } + }); +} + +SQLRETURN SQLSetEnvAttr(SQLHENV env, SQLINTEGER attr, SQLPOINTER valuePtr, + SQLINTEGER strLen) { + LOG_DEBUG( + "SQLSetEnvAttr called with env: {}, attr: {}, valuePtr: {}, " + "strLen: {}", + env, attr, valuePtr, strLen); + + using driver::odbcabstraction::DriverException; + using ODBC::ODBCEnvironment; + + ODBCEnvironment* environment = reinterpret_cast(env); + + return ODBCEnvironment::ExecuteWithDiagnostics(environment, SQL_ERROR, [=]() { + if (!valuePtr) { + throw DriverException("Invalid null pointer for attribute.", "HY024"); + } + + switch (attr) { + case SQL_ATTR_ODBC_VERSION: { + SQLINTEGER version = + static_cast(reinterpret_cast(valuePtr)); + if (version == SQL_OV_ODBC2 || version == SQL_OV_ODBC3) { + environment->setODBCVersion(version); + + return SQL_SUCCESS; + } else { + throw DriverException("Invalid value for attribute", "HY024"); + } + } + + case SQL_ATTR_OUTPUT_NTS: { + // output nts can not be set to SQL_FALSE, is always SQL_TRUE + SQLINTEGER value = static_cast(reinterpret_cast(valuePtr)); + if (value == SQL_TRUE) { + return SQL_SUCCESS; + } else { + throw DriverException("Invalid value for attribute", "HY024"); + } + } + + case SQL_ATTR_CONNECTION_POOLING: { + throw DriverException("Optional feature not supported.", "HYC00"); + } + + default: { + throw DriverException("Invalid attribute", "HY092"); + } + } + }); +} + +SQLRETURN SQLGetConnectAttr(SQLHDBC conn, SQLINTEGER attribute, SQLPOINTER valuePtr, + SQLINTEGER bufferLength, SQLINTEGER* stringLengthPtr) { + LOG_DEBUG( + "SQLGetConnectAttrW called with conn: {}, attribute: {}, valuePtr: {}, " + "bufferLength: {}, stringLengthPtr: {}", + conn, attribute, valuePtr, bufferLength, fmt::ptr(stringLengthPtr)); + + using driver::odbcabstraction::Connection; + using ODBC::ODBCConnection; + + return ODBCConnection::ExecuteWithDiagnostics(conn, SQL_ERROR, [=]() { + const bool isUnicode = true; + ODBCConnection* connection = reinterpret_cast(conn); + return connection->GetConnectAttr(attribute, valuePtr, bufferLength, stringLengthPtr, + isUnicode); + }); +} + +SQLRETURN SQLSetConnectAttr(SQLHDBC conn, SQLINTEGER attr, SQLPOINTER valuePtr, + SQLINTEGER valueLen) { + LOG_DEBUG( + "SQLSetConnectAttrW called with conn: {}, attr: {}, valuePtr: {}, valueLen: {}", + conn, attr, valuePtr, valueLen); + + using driver::odbcabstraction::Connection; + using ODBC::ODBCConnection; + + return ODBCConnection::ExecuteWithDiagnostics(conn, SQL_ERROR, [=]() { + const bool isUnicode = true; + ODBCConnection* connection = reinterpret_cast(conn); + connection->SetConnectAttr(attr, valuePtr, valueLen, isUnicode); + return SQL_SUCCESS; + }); +} + +SQLRETURN SQLDriverConnect(SQLHDBC conn, SQLHWND windowHandle, + SQLWCHAR* inConnectionString, + SQLSMALLINT inConnectionStringLen, + SQLWCHAR* outConnectionString, + SQLSMALLINT outConnectionStringBufferLen, + SQLSMALLINT* outConnectionStringLen, + SQLUSMALLINT driverCompletion) { + LOG_DEBUG( + "SQLDriverConnectW called with conn: {}, windowHandle: {}, inConnectionString: {}, " + "inConnectionStringLen: {}, outConnectionString: {}, outConnectionStringBufferLen: " + "{}, outConnectionStringLen: {}, driverCompletion: {}", + conn, fmt::ptr(windowHandle), fmt::ptr(inConnectionString), inConnectionStringLen, + fmt::ptr(outConnectionString), outConnectionStringBufferLen, + fmt::ptr(outConnectionStringLen), driverCompletion); + + // TODO: Implement FILEDSN and SAVEFILE keywords according to the spec + // https://github.com/apache/arrow/issues/46449 + + // TODO: Copy connection string properly in SQLDriverConnect according to the + // spec https://github.com/apache/arrow/issues/46560 + + using driver::odbcabstraction::Connection; + using driver::odbcabstraction::DriverException; + using ODBC::ODBCConnection; + + return ODBCConnection::ExecuteWithDiagnostics(conn, SQL_ERROR, [=]() { + ODBCConnection* connection = reinterpret_cast(conn); + std::string connection_string = + ODBC::SqlWcharToString(inConnectionString, inConnectionStringLen); + Connection::ConnPropertyMap properties; + std::string dsn = + ODBCConnection::getPropertiesFromConnString(connection_string, properties); + + std::vector missing_properties; + + // TODO: Implement SQL_DRIVER_COMPLETE_REQUIRED in SQLDriverConnect according to the + // spec https://github.com/apache/arrow/issues/46448 +#if defined _WIN32 || defined _WIN64 + // Load the DSN window according to driverCompletion + if (driverCompletion == SQL_DRIVER_PROMPT) { + // Load DSN window before first attempt to connect + driver::flight_sql::config::Configuration config; + if (!DisplayConnectionWindow(windowHandle, config, properties)) { + return static_cast(SQL_NO_DATA); + } + connection->connect(dsn, properties, missing_properties); + } else if (driverCompletion == SQL_DRIVER_COMPLETE || + driverCompletion == SQL_DRIVER_COMPLETE_REQUIRED) { + try { + connection->connect(dsn, properties, missing_properties); + } catch (const DriverException&) { + // If first connection fails due to missing attributes, load + // the DSN window and try to connect again + if (!missing_properties.empty()) { + driver::flight_sql::config::Configuration config; + missing_properties.clear(); + + if (!DisplayConnectionWindow(windowHandle, config, properties)) { + return static_cast(SQL_NO_DATA); + } + connection->connect(dsn, properties, missing_properties); + } else { + throw; + } + } + } else { + // Default case: attempt connection without showing DSN window + connection->connect(dsn, properties, missing_properties); + } +#else + // Attempt connection without loading DSN window on macOS/Linux + connection->connect(dsn, properties, missing_properties); +#endif + // Copy connection string to outConnectionString after connection attempt + return ODBC::GetStringAttribute(true, connection_string, false, outConnectionString, + outConnectionStringBufferLen, outConnectionStringLen, + connection->GetDiagnostics()); + }); +} + +SQLRETURN SQLConnect(SQLHDBC conn, SQLWCHAR* dsnName, SQLSMALLINT dsnNameLen, + SQLWCHAR* userName, SQLSMALLINT userNameLen, SQLWCHAR* password, + SQLSMALLINT passwordLen) { + LOG_DEBUG( + "SQLConnectW called with conn: {}, dsnName: {}, dsnNameLen: {}, userName: {}, " + "userNameLen: {}, password: {}, passwordLen: {}", + conn, fmt::ptr(dsnName), dsnNameLen, fmt::ptr(userName), userNameLen, + fmt::ptr(password), passwordLen); + + using driver::flight_sql::FlightSqlConnection; + using driver::flight_sql::config::Configuration; + using ODBC::ODBCConnection; + + using ODBC::SqlWcharToString; + + return ODBCConnection::ExecuteWithDiagnostics(conn, SQL_ERROR, [=]() { + ODBCConnection* connection = reinterpret_cast(conn); + std::string dsn = SqlWcharToString(dsnName, dsnNameLen); + + Configuration config; + config.LoadDsn(dsn); + + if (userName) { + std::string uid = SqlWcharToString(userName, userNameLen); + config.Emplace(FlightSqlConnection::UID, std::move(uid)); + } + + if (password) { + std::string pwd = SqlWcharToString(password, passwordLen); + config.Emplace(FlightSqlConnection::PWD, std::move(pwd)); + } + + std::vector missing_properties; + + connection->connect(dsn, config.GetProperties(), missing_properties); + + return SQL_SUCCESS; + }); +} + +SQLRETURN SQLDisconnect(SQLHDBC conn) { + LOG_DEBUG("SQLDisconnect called with conn: {}", conn); + + using ODBC::ODBCConnection; + + return ODBCConnection::ExecuteWithDiagnostics(conn, SQL_ERROR, [=]() { + ODBCConnection* connection = reinterpret_cast(conn); + + connection->disconnect(); + + return SQL_SUCCESS; + }); +} + +SQLRETURN SQLGetInfo(SQLHDBC conn, SQLUSMALLINT infoType, SQLPOINTER infoValuePtr, + SQLSMALLINT bufLen, SQLSMALLINT* stringLengthPtr) { + LOG_DEBUG( + "SQLGetInfo called with conn: {}, infoType: {}, infoValuePtr: {}, bufLen: {}, " + "stringLengthPtr: {}", + conn, infoType, infoValuePtr, bufLen, fmt::ptr(stringLengthPtr)); + + using ODBC::ODBCConnection; + + return ODBCConnection::ExecuteWithDiagnostics(conn, SQL_ERROR, [=]() { + ODBCConnection* connection = reinterpret_cast(conn); + + // Set character type to be Unicode by default + const bool isUnicode = true; + + if (!infoValuePtr && !stringLengthPtr) { + return static_cast SQL_ERROR; + } + + return connection->GetInfo(infoType, infoValuePtr, bufLen, stringLengthPtr, + isUnicode); + }); +} + +SQLRETURN SQLGetStmtAttr(SQLHSTMT stmt, SQLINTEGER attribute, SQLPOINTER valuePtr, + SQLINTEGER bufferLength, SQLINTEGER* stringLengthPtr) { + LOG_DEBUG( + "SQLGetStmtAttrW called with stmt: {}, attribute: {}, valuePtr: {}, " + "bufferLength: {}, stringLengthPtr: {}", + stmt, attribute, valuePtr, bufferLength, fmt::ptr(stringLengthPtr)); + using ODBC::ODBCStatement; + + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + + bool isUnicode = true; + + statement->GetStmtAttr(attribute, valuePtr, bufferLength, stringLengthPtr, isUnicode); + + return SQL_SUCCESS; + }); +} + +SQLRETURN SQLSetStmtAttr(SQLHSTMT stmt, SQLINTEGER attribute, SQLPOINTER valuePtr, + SQLINTEGER stringLength) { + LOG_DEBUG( + "SQLSetStmtAttrW called with stmt: {}, attribute: {}, valuePtr: {}, " + "stringLength: {}", + stmt, attribute, valuePtr, stringLength); + using ODBC::ODBCStatement; + + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + + bool isUnicode = true; + + statement->SetStmtAttr(attribute, valuePtr, stringLength, isUnicode); + + return SQL_SUCCESS; + }); +} + +SQLRETURN SQLExecDirect(SQLHSTMT stmt, SQLWCHAR* queryText, SQLINTEGER textLength) { + LOG_DEBUG("SQLExecDirectW called with stmt: {}, queryText: {}, textLength: {}", stmt, + fmt::ptr(queryText), textLength); + using ODBC::ODBCStatement; + // The driver is built to handle SELECT statements only. + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + std::string query = ODBC::SqlWcharToString(queryText, textLength); + + statement->Prepare(query); + statement->ExecutePrepared(); + + return SQL_SUCCESS; + }); +} + +SQLRETURN SQLPrepare(SQLHSTMT stmt, SQLWCHAR* queryText, SQLINTEGER textLength) { + LOG_DEBUG("SQLPrepareW called with stmt: {}, queryText: {}, textLength: {}", stmt, + fmt::ptr(queryText), textLength); + using ODBC::ODBCStatement; + // The driver is built to handle SELECT statements only. + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + std::string query = ODBC::SqlWcharToString(queryText, textLength); + + statement->Prepare(query); + + return SQL_SUCCESS; + }); +} + +SQLRETURN SQLExecute(SQLHSTMT stmt) { + LOG_DEBUG("SQLExecute called with stmt: {}", stmt); + + using ODBC::ODBCStatement; + // The driver is built to handle SELECT statements only. + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + + statement->ExecutePrepared(); + + return SQL_SUCCESS; + }); +} + +SQLRETURN SQLFetch(SQLHSTMT stmt) { + LOG_DEBUG("SQLFetch called with stmt: {}", stmt); + + using ODBC::ODBCDescriptor; + using ODBC::ODBCStatement; + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + + // The SQL_ATTR_ROW_ARRAY_SIZE statement attribute specifies the number of rows in the + // rowset. + ODBCDescriptor* ard = statement->GetARD(); + size_t rows = static_cast(ard->GetArraySize()); + + if (statement->Fetch(rows)) { + return SQL_SUCCESS; + } else { + // Reached the end of rowset + return SQL_NO_DATA; + } + }); +} + +SQLRETURN SQLExtendedFetch(SQLHSTMT stmt, SQLUSMALLINT fetchOrientation, + SQLLEN fetchOffset, SQLULEN* rowCountPtr, + SQLUSMALLINT* rowStatusArray) { + // GH-47110: SQLExtendedFetch should return SQL_SUCCESS_WITH_INFO for certain diag + // states + LOG_DEBUG( + "SQLExtendedFetch called with stmt: {}, fetchOrientation: {}, fetchOffset: {}, " + "rowCountPtr: {}, rowStatusArray: {}", + stmt, fetchOrientation, fetchOffset, fmt::ptr(rowCountPtr), + fmt::ptr(rowStatusArray)); + using ODBC::ODBCDescriptor; + using ODBC::ODBCStatement; + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + if (fetchOrientation != SQL_FETCH_NEXT) { + throw DriverException("Optional feature not supported.", "HYC00"); + } + // fetchOffset is ignored as only SQL_FETCH_NEXT is supported + + ODBCStatement* statement = reinterpret_cast(stmt); + + // The SQL_ROWSET_SIZE statement attribute specifies the number of rows in the + // rowset. + SQLULEN rowSetSize = statement->GetRowsetSize(); + LOG_DEBUG("SQL_ROWSET_SIZE value for SQLExtendedFetch: {}", rowSetSize); + if (statement->Fetch(static_cast(rowSetSize), rowCountPtr, rowStatusArray)) { + return SQL_SUCCESS; + } else { + // Reached the end of rowset + return SQL_NO_DATA; + } + }); +} + +SQLRETURN SQLFetchScroll(SQLHSTMT stmt, SQLSMALLINT fetchOrientation, + SQLLEN fetchOffset) { + LOG_DEBUG("SQLFetchScroll called with stmt: {}, fetchOrientation: {}, fetchOffset: {}", + stmt, fetchOrientation, fetchOffset); + using ODBC::ODBCDescriptor; + using ODBC::ODBCStatement; + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + if (fetchOrientation != SQL_FETCH_NEXT) { + throw DriverException("Optional feature not supported.", "HYC00"); + } + // fetchOffset is ignored as only SQL_FETCH_NEXT is supported + + ODBCStatement* statement = reinterpret_cast(stmt); + + // The SQL_ATTR_ROW_ARRAY_SIZE statement attribute specifies the number of rows in the + // rowset. + ODBCDescriptor* ard = statement->GetARD(); + size_t rows = static_cast(ard->GetArraySize()); + if (statement->Fetch(rows)) { + return SQL_SUCCESS; + } else { + // Reached the end of rowset + return SQL_NO_DATA; + } + }); +} + +SQLRETURN SQLBindCol(SQLHSTMT stmt, SQLUSMALLINT recordNumber, SQLSMALLINT cType, + SQLPOINTER dataPtr, SQLLEN bufferLength, SQLLEN* indicatorPtr) { + LOG_DEBUG( + "SQLBindCol called with stmt: {}, recordNumber: {}, cType: {}, " + "dataPtr: {}, bufferLength: {}, strLen_or_IndPtr: {}", + stmt, recordNumber, cType, dataPtr, bufferLength, fmt::ptr(indicatorPtr)); + using ODBC::ODBCDescriptor; + using ODBC::ODBCStatement; + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + // GH-47021: implement driver to return indicator value when data pointer is null + ODBCStatement* statement = reinterpret_cast(stmt); + ODBCDescriptor* ard = statement->GetARD(); + ard->BindCol(recordNumber, cType, dataPtr, bufferLength, indicatorPtr); + return SQL_SUCCESS; + }); +} + +SQLRETURN SQLCloseCursor(SQLHSTMT stmt) { + LOG_DEBUG("SQLCloseCursor called with stmt: {}", stmt); + using ODBC::ODBCStatement; + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + + // Close cursor with suppressErrors set to false + statement->closeCursor(false); + + return SQL_SUCCESS; + }); +} + +SQLRETURN SQLGetData(SQLHSTMT stmt, SQLUSMALLINT recordNumber, SQLSMALLINT cType, + SQLPOINTER dataPtr, SQLLEN bufferLength, SQLLEN* indicatorPtr) { + // GH-46979: support SQL_C_GUID data type + // GH-46980: support Interval data types + // GH-46985: return warning message instead of error on float truncation case + LOG_DEBUG( + "SQLGetData called with stmt: {}, recordNumber: {}, cType: {}, " + "dataPtr: {}, bufferLength: {}, indicatorPtr: {}", + stmt, recordNumber, cType, dataPtr, bufferLength, fmt::ptr(indicatorPtr)); + + using ODBC::ODBCStatement; + + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + return statement->GetData(recordNumber, cType, dataPtr, bufferLength, indicatorPtr); + }); +} + +SQLRETURN SQLMoreResults(SQLHSTMT stmt) { + LOG_DEBUG("SQLMoreResults called with stmt: {}", stmt); + using ODBC::ODBCStatement; + // Multiple result sets not supported. Return SQL_NO_DATA by default. + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + return statement->getMoreResults(); + }); +} + +SQLRETURN SQLNumResultCols(SQLHSTMT stmt, SQLSMALLINT* columnCountPtr) { + LOG_DEBUG("SQLNumResultCols called with stmt: {}, columnCountPtr: {}", stmt, + fmt::ptr(columnCountPtr)); + using ODBC::ODBCStatement; + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + statement->getColumnCount(columnCountPtr); + return SQL_SUCCESS; + }); +} + +SQLRETURN SQLRowCount(SQLHSTMT stmt, SQLLEN* rowCountPtr) { + LOG_DEBUG("SQLRowCount called with stmt: {}, columnCountPtr: {}", stmt, + fmt::ptr(rowCountPtr)); + using ODBC::ODBCStatement; + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + statement->getRowCount(rowCountPtr); + return SQL_SUCCESS; + }); +} + +SQLRETURN SQLTables(SQLHSTMT stmt, SQLWCHAR* catalogName, SQLSMALLINT catalogNameLength, + SQLWCHAR* schemaName, SQLSMALLINT schemaNameLength, + SQLWCHAR* tableName, SQLSMALLINT tableNameLength, SQLWCHAR* tableType, + SQLSMALLINT tableTypeLength) { + LOG_DEBUG( + "SQLTables called with stmt: {}, catalogName: {}, catalogNameLength: " + "{}, " + "schemaName: {}, schemaNameLength: {}, tableName: {}, tableNameLength: {}, " + "tableType: {}, " + "tableTypeLength: {}", + stmt, fmt::ptr(catalogName), catalogNameLength, fmt::ptr(schemaName), + schemaNameLength, fmt::ptr(tableName), tableNameLength, fmt::ptr(tableType), + tableTypeLength); + using ODBC::ODBCStatement; + using ODBC::SqlWcharToString; + + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + + std::string catalog = SqlWcharToString(catalogName, catalogNameLength); + std::string schema = SqlWcharToString(schemaName, schemaNameLength); + std::string table = SqlWcharToString(tableName, tableNameLength); + std::string type = SqlWcharToString(tableType, tableTypeLength); + + statement->GetTables(catalogName ? &catalog : nullptr, schemaName ? &schema : nullptr, + tableName ? &table : nullptr, tableType ? &type : nullptr); + + return SQL_SUCCESS; + }); +} + +SQLRETURN SQLColumns(SQLHSTMT stmt, SQLWCHAR* catalogName, SQLSMALLINT catalogNameLength, + SQLWCHAR* schemaName, SQLSMALLINT schemaNameLength, + SQLWCHAR* tableName, SQLSMALLINT tableNameLength, + SQLWCHAR* columnName, SQLSMALLINT columnNameLength) { + // GH-47159: Return NUM_PREC_RADIX based on whether COLUMN_SIZE contains number of + // digits or bits + LOG_DEBUG( + "SQLColumnsW called with stmt: {}, catalogName: {}, catalogNameLength: " + "{}, " + "schemaName: {}, schemaNameLength: {}, tableName: {}, tableNameLength: {}, " + "columnName: {}, " + "columnNameLength: {}", + stmt, fmt::ptr(catalogName), catalogNameLength, fmt::ptr(schemaName), + schemaNameLength, fmt::ptr(tableName), tableNameLength, fmt::ptr(columnName), + columnNameLength); + + using ODBC::ODBCStatement; + using ODBC::SqlWcharToString; + + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + + std::string catalog = SqlWcharToString(catalogName, catalogNameLength); + std::string schema = SqlWcharToString(schemaName, schemaNameLength); + std::string table = SqlWcharToString(tableName, tableNameLength); + std::string column = SqlWcharToString(columnName, columnNameLength); + + statement->GetColumns(catalogName ? &catalog : nullptr, + schemaName ? &schema : nullptr, tableName ? &table : nullptr, + columnName ? &column : nullptr); + + return SQL_SUCCESS; + }); +} + +SQLRETURN SQLColAttribute(SQLHSTMT stmt, SQLUSMALLINT recordNumber, + SQLUSMALLINT fieldIdentifier, SQLPOINTER characterAttributePtr, + SQLSMALLINT bufferLength, SQLSMALLINT* outputLength, + SQLLEN* numericAttributePtr) { + LOG_DEBUG( + "SQLColAttributeW called with stmt: {}, recordNumber: {}, " + "fieldIdentifier: {}, characterAttributePtr: {}, bufferLength: {}, " + "outputLength: {}, numericAttributePtr: {}", + stmt, recordNumber, fieldIdentifier, characterAttributePtr, bufferLength, + fmt::ptr(outputLength), fmt::ptr(numericAttributePtr)); + using ODBC::ODBCDescriptor; + using ODBC::ODBCStatement; + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + ODBCDescriptor* ird = statement->GetIRD(); + SQLINTEGER outputLengthInt; + switch (fieldIdentifier) { + // Numeric attributes + // internal is SQLLEN, no conversion is needed + case SQL_DESC_DISPLAY_SIZE: + case SQL_DESC_OCTET_LENGTH: { + ird->GetField(recordNumber, fieldIdentifier, numericAttributePtr, bufferLength, + &outputLengthInt); + break; + } + // internal is SQLULEN, conversion is needed. + case SQL_COLUMN_LENGTH: // ODBC 2.0 + case SQL_DESC_LENGTH: { + SQLULEN temp; + ird->GetField(recordNumber, fieldIdentifier, &temp, bufferLength, + &outputLengthInt); + if (numericAttributePtr) { + *numericAttributePtr = static_cast(temp); + } + break; + } + // internal is SQLINTEGER, conversion is needed. + case SQL_DESC_AUTO_UNIQUE_VALUE: + case SQL_DESC_CASE_SENSITIVE: + case SQL_DESC_NUM_PREC_RADIX: { + SQLINTEGER temp; + ird->GetField(recordNumber, fieldIdentifier, &temp, bufferLength, + &outputLengthInt); + if (numericAttributePtr) { + *numericAttributePtr = static_cast(temp); + } + break; + } + // internal is SQLSMALLINT, conversion is needed. + case SQL_DESC_CONCISE_TYPE: + case SQL_DESC_COUNT: + case SQL_DESC_FIXED_PREC_SCALE: + case SQL_DESC_TYPE: + case SQL_DESC_NULLABLE: + case SQL_COLUMN_PRECISION: // ODBC 2.0 + case SQL_DESC_PRECISION: + case SQL_COLUMN_SCALE: // ODBC 2.0 + case SQL_DESC_SCALE: + case SQL_DESC_SEARCHABLE: + case SQL_DESC_UNNAMED: + case SQL_DESC_UNSIGNED: + case SQL_DESC_UPDATABLE: { + SQLSMALLINT temp; + ird->GetField(recordNumber, fieldIdentifier, &temp, bufferLength, + &outputLengthInt); + if (numericAttributePtr) { + *numericAttributePtr = static_cast(temp); + } + break; + } + // Character attributes + case SQL_DESC_BASE_COLUMN_NAME: + case SQL_DESC_BASE_TABLE_NAME: + case SQL_DESC_CATALOG_NAME: + case SQL_DESC_LABEL: + case SQL_DESC_LITERAL_PREFIX: + case SQL_DESC_LITERAL_SUFFIX: + case SQL_DESC_LOCAL_TYPE_NAME: + case SQL_DESC_NAME: + case SQL_DESC_SCHEMA_NAME: + case SQL_DESC_TABLE_NAME: + case SQL_DESC_TYPE_NAME: + ird->GetField(recordNumber, fieldIdentifier, characterAttributePtr, bufferLength, + &outputLengthInt); + break; + default: + throw DriverException("Invalid descriptor field", "HY091"); + } + if (outputLength) { + *outputLength = static_cast(outputLengthInt); + } + return SQL_SUCCESS; + }); +} + +SQLRETURN SQLGetTypeInfo(SQLHSTMT stmt, SQLSMALLINT dataType) { + // GH-47237 return SQL_PRED_CHAR and SQL_PRED_BASIC for + // appropriate data types in `SEARCHABLE` field + LOG_DEBUG("SQLGetTypeInfoW called with stmt: {} dataType: {}", stmt, dataType); + using ODBC::ODBCStatement; + return ODBC::ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + + switch (dataType) { + case SQL_ALL_TYPES: + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: + case SQL_BIT: + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: + case SQL_TINYINT: + case SQL_SMALLINT: + case SQL_INTEGER: + case SQL_BIGINT: + case SQL_NUMERIC: + case SQL_DECIMAL: + case SQL_FLOAT: + case SQL_REAL: + case SQL_DOUBLE: + case SQL_GUID: + case SQL_DATE: + case SQL_TYPE_DATE: + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_INTERVAL_DAY: + case SQL_INTERVAL_DAY_TO_HOUR: + case SQL_INTERVAL_DAY_TO_MINUTE: + case SQL_INTERVAL_DAY_TO_SECOND: + case SQL_INTERVAL_HOUR: + case SQL_INTERVAL_HOUR_TO_MINUTE: + case SQL_INTERVAL_HOUR_TO_SECOND: + case SQL_INTERVAL_MINUTE: + case SQL_INTERVAL_MINUTE_TO_SECOND: + case SQL_INTERVAL_SECOND: + case SQL_INTERVAL_YEAR: + case SQL_INTERVAL_YEAR_TO_MONTH: + case SQL_INTERVAL_MONTH: + statement->GetTypeInfo(dataType); + break; + default: + throw DriverException("Invalid SQL data type", "HY004"); + } + + return SQL_SUCCESS; + }); +} + +SQLRETURN SQLNativeSql(SQLHDBC connectionHandle, SQLWCHAR* inStatementText, + SQLINTEGER inStatementTextLength, SQLWCHAR* outStatementText, + SQLINTEGER bufferLength, SQLINTEGER* outStatementTextLength) { + LOG_DEBUG( + "SQLNativeSqlW called with connectionHandle: {}, inStatementText: {}, " + "inStatementTextLength: {}, outStatementText: {}, bufferLength: {}, " + "outStatementTextLength: {}", + connectionHandle, fmt::ptr(inStatementText), inStatementTextLength, + fmt::ptr(outStatementText), bufferLength, fmt::ptr(outStatementTextLength)); + + using driver::odbcabstraction::Diagnostics; + using ODBC::GetAttributeSQLWCHAR; + using ODBC::ODBCConnection; + using ODBC::SqlWcharToString; + + return ODBCConnection::ExecuteWithDiagnostics(connectionHandle, SQL_ERROR, [=]() { + const bool isLengthInBytes = false; + + ODBCConnection* connection = reinterpret_cast(connectionHandle); + Diagnostics& diagnostics = connection->GetDiagnostics(); + + std::string inStatementStr = SqlWcharToString(inStatementText, inStatementTextLength); + + return GetAttributeSQLWCHAR(inStatementStr, isLengthInBytes, outStatementText, + bufferLength, outStatementTextLength, diagnostics); + }); +} + +SQLRETURN SQLDescribeCol(SQLHSTMT stmt, SQLUSMALLINT columnNumber, SQLWCHAR* columnName, + SQLSMALLINT bufferLength, SQLSMALLINT* nameLengthPtr, + SQLSMALLINT* dataTypePtr, SQLULEN* columnSizePtr, + SQLSMALLINT* decimalDigitsPtr, SQLSMALLINT* nullablePtr) { + LOG_DEBUG( + "SQLDescribeColW called with stmt: {}, columnNumber: {}, " + "columnName: {}, bufferLength: {}, nameLengthPtr: {}, dataTypePtr: {}, " + "columnSizePtr: {}, decimalDigitsPtr: {}, nullablePtr: {}", + stmt, columnNumber, fmt::ptr(columnName), bufferLength, fmt::ptr(nameLengthPtr), + fmt::ptr(dataTypePtr), fmt::ptr(columnSizePtr), fmt::ptr(decimalDigitsPtr), + fmt::ptr(nullablePtr)); + using ODBC::ODBCDescriptor; + using ODBC::ODBCStatement; + + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + ODBCDescriptor* ird = statement->GetIRD(); + SQLINTEGER outputLengthInt; + SQLSMALLINT sqlType; + + // Column SQL Type + ird->GetField(columnNumber, SQL_DESC_CONCISE_TYPE, &sqlType, sizeof(SQLSMALLINT), + nullptr); + if (dataTypePtr) { + *dataTypePtr = sqlType; + } + + // Column Name + if (columnName || nameLengthPtr) { + ird->GetField(columnNumber, SQL_DESC_NAME, columnName, bufferLength, + &outputLengthInt); + if (nameLengthPtr) { + // returned length should be in characters + *nameLengthPtr = static_cast(outputLengthInt / GetSqlWCharSize()); + } + } + + // Column Size + if (columnSizePtr) { + switch (sqlType) { + // All numeric types + case SQL_DECIMAL: + case SQL_NUMERIC: + case SQL_TINYINT: + case SQL_SMALLINT: + case SQL_INTEGER: + case SQL_BIGINT: + case SQL_REAL: + case SQL_FLOAT: + case SQL_DOUBLE: { + ird->GetField(columnNumber, SQL_DESC_PRECISION, columnSizePtr, sizeof(SQLULEN), + nullptr); + break; + } + + default: { + ird->GetField(columnNumber, SQL_DESC_LENGTH, columnSizePtr, sizeof(SQLULEN), + nullptr); + } + } + } + + // Column Decimal Digits + if (decimalDigitsPtr) { + switch (sqlType) { + // All exact numeric types + case SQL_TINYINT: + case SQL_SMALLINT: + case SQL_INTEGER: + case SQL_BIGINT: + case SQL_DECIMAL: + case SQL_NUMERIC: { + ird->GetField(columnNumber, SQL_DESC_SCALE, decimalDigitsPtr, sizeof(SQLULEN), + nullptr); + break; + } + + // All datetime types (ODBC2) + case SQL_DATE: + case SQL_TIME: + case SQL_TIMESTAMP: + // All datetime types (ODBC3) + case SQL_TYPE_DATE: + case SQL_TYPE_TIME: + case SQL_TYPE_TIMESTAMP: + // All interval types with a seconds component + case SQL_INTERVAL_SECOND: + case SQL_INTERVAL_MINUTE_TO_SECOND: + case SQL_INTERVAL_HOUR_TO_SECOND: + case SQL_INTERVAL_DAY_TO_SECOND: { + ird->GetField(columnNumber, SQL_DESC_PRECISION, decimalDigitsPtr, + sizeof(SQLULEN), nullptr); + break; + } + + default: { + // All character and binary types + // SQL_BIT + // All approximate numeric types + // All interval types with no seconds component + *decimalDigitsPtr = static_cast(0); + } + } + } + + // Column Nullable + if (nullablePtr) { + ird->GetField(columnNumber, SQL_DESC_NULLABLE, nullablePtr, sizeof(SQLSMALLINT), + nullptr); + } + + return SQL_SUCCESS; + }); +} + +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_api.h b/cpp/src/arrow/flight/sql/odbc/odbc_api.h new file mode 100644 index 00000000000..94a7dc0ec3e --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbc_api.h @@ -0,0 +1,104 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#ifdef _WIN32 +# include +#endif + +#include +#include +#include + +// @file odbc_api.h +// +// Define internal ODBC API function headers. +namespace arrow { +SQLRETURN SQLAllocHandle(SQLSMALLINT type, SQLHANDLE parent, SQLHANDLE* result); +SQLRETURN SQLFreeHandle(SQLSMALLINT type, SQLHANDLE handle); +SQLRETURN SQLFreeStmt(SQLHSTMT stmt, SQLUSMALLINT option); +SQLRETURN SQLGetDiagField(SQLSMALLINT handleType, SQLHANDLE handle, SQLSMALLINT recNumber, + SQLSMALLINT diagIdentifier, SQLPOINTER diagInfoPtr, + SQLSMALLINT bufferLength, SQLSMALLINT* stringLengthPtr); +SQLRETURN SQLGetDiagRec(SQLSMALLINT handleType, SQLHANDLE handle, SQLSMALLINT recNumber, + SQLWCHAR* sqlState, SQLINTEGER* nativeErrorPtr, + SQLWCHAR* messageText, SQLSMALLINT bufferLength, + SQLSMALLINT* textLengthPtr); +SQLRETURN SQLGetEnvAttr(SQLHENV env, SQLINTEGER attr, SQLPOINTER valuePtr, + SQLINTEGER bufferLen, SQLINTEGER* strLenPtr); +SQLRETURN SQLSetEnvAttr(SQLHENV env, SQLINTEGER attr, SQLPOINTER valuePtr, + SQLINTEGER strLen); +SQLRETURN SQLGetConnectAttr(SQLHDBC conn, SQLINTEGER attribute, SQLPOINTER valuePtr, + SQLINTEGER bufferLength, SQLINTEGER* stringLengthPtr); +SQLRETURN SQLSetConnectAttr(SQLHDBC conn, SQLINTEGER attr, SQLPOINTER value, + SQLINTEGER valueLen); +SQLRETURN SQLDriverConnect(SQLHDBC conn, SQLHWND windowHandle, + SQLWCHAR* inConnectionString, + SQLSMALLINT inConnectionStringLen, + SQLWCHAR* outConnectionString, + SQLSMALLINT outConnectionStringBufferLen, + SQLSMALLINT* outConnectionStringLen, + SQLUSMALLINT driverCompletion); +SQLRETURN SQLConnect(SQLHDBC conn, SQLWCHAR* dsnName, SQLSMALLINT dsnNameLen, + SQLWCHAR* userName, SQLSMALLINT userNameLen, SQLWCHAR* password, + SQLSMALLINT passwordLen); +SQLRETURN SQLDisconnect(SQLHDBC conn); +SQLRETURN SQLGetInfo(SQLHDBC conn, SQLUSMALLINT infoType, SQLPOINTER infoValuePtr, + SQLSMALLINT bufLen, SQLSMALLINT* length); +SQLRETURN SQLGetStmtAttr(SQLHSTMT stmt, SQLINTEGER attribute, SQLPOINTER valuePtr, + SQLINTEGER bufferLength, SQLINTEGER* stringLengthPtr); +SQLRETURN SQLSetStmtAttr(SQLHSTMT stmt, SQLINTEGER attribute, SQLPOINTER valuePtr, + SQLINTEGER stringLength); +SQLRETURN SQLExecDirect(SQLHSTMT stmt, SQLWCHAR* queryText, SQLINTEGER textLength); +SQLRETURN SQLPrepare(SQLHSTMT stmt, SQLWCHAR* queryText, SQLINTEGER textLength); +SQLRETURN SQLExecute(SQLHSTMT stmt); +SQLRETURN SQLFetch(SQLHSTMT stmt); +SQLRETURN SQLExtendedFetch(SQLHSTMT stmt, SQLUSMALLINT fetchOrientation, + SQLLEN fetchOffset, SQLULEN* rowCountPtr, + SQLUSMALLINT* rowStatusArray); +SQLRETURN SQLFetchScroll(SQLHSTMT stmt, SQLSMALLINT fetchOrientation, SQLLEN fetchOffset); +SQLRETURN SQLBindCol(SQLHSTMT stmt, SQLUSMALLINT recordNumber, SQLSMALLINT cType, + SQLPOINTER dataPtr, SQLLEN bufferLength, SQLLEN* indicatorPtr); +SQLRETURN SQLCloseCursor(SQLHSTMT stmt); +SQLRETURN SQLGetData(SQLHSTMT stmt, SQLUSMALLINT recordNumber, SQLSMALLINT cType, + SQLPOINTER dataPtr, SQLLEN bufferLength, SQLLEN* indicatorPtr); +SQLRETURN SQLMoreResults(SQLHSTMT stmt); +SQLRETURN SQLNumResultCols(SQLHSTMT stmt, SQLSMALLINT* columnCountPtr); +SQLRETURN SQLRowCount(SQLHSTMT stmt, SQLLEN* rowCountPtr); +SQLRETURN SQLTables(SQLHSTMT stmt, SQLWCHAR* catalogName, SQLSMALLINT catalogNameLength, + SQLWCHAR* schemaName, SQLSMALLINT schemaNameLength, + SQLWCHAR* tableName, SQLSMALLINT tableNameLength, SQLWCHAR* tableType, + SQLSMALLINT tableTypeLength); +SQLRETURN SQLColumns(SQLHSTMT stmt, SQLWCHAR* catalogName, SQLSMALLINT catalogNameLength, + SQLWCHAR* schemaName, SQLSMALLINT schemaNameLength, + SQLWCHAR* tableName, SQLSMALLINT tableNameLength, + SQLWCHAR* columnName, SQLSMALLINT columnNameLength); +SQLRETURN SQLColAttribute(SQLHSTMT stmt, SQLUSMALLINT recordNumber, + SQLUSMALLINT fieldIdentifier, SQLPOINTER characterAttributePtr, + SQLSMALLINT bufferLength, SQLSMALLINT* outputLength, + SQLLEN* numericAttributePtr); +SQLRETURN SQLGetTypeInfo(SQLHSTMT stmt, SQLSMALLINT dataType); +SQLRETURN SQLNativeSql(SQLHDBC connectionHandle, SQLWCHAR* inStatementText, + SQLINTEGER inStatementTextLength, SQLWCHAR* outStatementText, + SQLINTEGER bufferLength, SQLINTEGER* outStatementTextLength); +SQLRETURN SQLDescribeCol(SQLHSTMT statementHandle, SQLUSMALLINT columnNumber, + SQLWCHAR* columnName, SQLSMALLINT bufferLength, + SQLSMALLINT* nameLengthPtr, SQLSMALLINT* dataTypePtr, + SQLULEN* columnSizePtr, SQLSMALLINT* decimalDigitsPtr, + SQLSMALLINT* nullablePtr); +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/CMakeLists.txt b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/CMakeLists.txt index c9614b88a5b..dd8b6dd2f1e 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/CMakeLists.txt @@ -17,9 +17,6 @@ include_directories(include) -# Ensure fmt is loaded as header only -add_compile_definitions(FMT_HEADER_ONLY) - add_library(odbcabstraction include/odbcabstraction/calendar_utils.h include/odbcabstraction/diagnostics.h @@ -66,19 +63,4 @@ set_target_properties(odbcabstraction RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/lib) -include(FetchContent) -fetchcontent_declare(spdlog - URL https://github.com/gabime/spdlog/archive/76fb40d95455f249bd70824ecfcae7a8f0930fa3.zip - CONFIGURE_COMMAND - "" - BUILD_COMMAND - "") -fetchcontent_getproperties(spdlog) -if(NOT spdlog_POPULATED) - fetchcontent_populate(spdlog) -endif() - -add_library(spdlog INTERFACE) -target_include_directories(spdlog INTERFACE ${spdlog_SOURCE_DIR}/include) - -target_link_libraries(odbcabstraction PUBLIC spdlog) +target_link_libraries(odbcabstraction PUBLIC spdlog::spdlog) diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/diagnostics.cc b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/diagnostics.cc index 8c94978ef99..78ca45ea2fe 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/diagnostics.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/diagnostics.cc @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -#include -#include -#include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/diagnostics.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/types.h" #include diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/encoding.cc b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/encoding.cc index 95dc920da78..00718cdbbe5 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/encoding.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/encoding.cc @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -#include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/encoding.h" #if defined(__APPLE__) # include diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/exceptions.cc b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/exceptions.cc index fcd8163a500..242c85e5a28 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/exceptions.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/exceptions.cc @@ -15,8 +15,9 @@ // specific language governing permissions and limitations // under the License. -#include -#include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/exceptions.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h" + #include namespace driver { diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/diagnostics.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/diagnostics.h index f1c6efe4982..473411efd4f 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/diagnostics.h +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/diagnostics.h @@ -21,8 +21,8 @@ #include #include -#include -#include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/exceptions.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/types.h" namespace driver { namespace odbcabstraction { diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/exceptions.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/exceptions.h index 48a773e4f4d..82ffebedff6 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/exceptions.h +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/exceptions.h @@ -17,10 +17,10 @@ #pragma once -#include #include #include #include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/error_codes.h" namespace driver { namespace odbcabstraction { diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/logger.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/logger.h index 5f8619cbb92..6249df98834 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/logger.h +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/logger.h @@ -18,9 +18,13 @@ #pragma once #include +#include #include -#include +#include + +// The logger using spdlog is deprecated and will be replaced. +// TODO: mirgate logging to use Arrow's internal logging system #define __LAZY_LOG(LEVEL, ...) \ do { \ diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/attribute_utils.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/attribute_utils.h index 9163e942ceb..d194ace237f 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/attribute_utils.h +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/attribute_utils.h @@ -17,16 +17,16 @@ #pragma once -#include -#include -#include #include #include #include #include #include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/diagnostics.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/exceptions.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h" -#include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/encoding_utils.h" namespace ODBC { using driver::odbcabstraction::WcsToUtf8; @@ -45,12 +45,12 @@ inline void GetAttribute(T attributeValue, SQLPOINTER output, O outputSize, } template -inline SQLRETURN GetAttributeUTF8(const std::string& attributeValue, SQLPOINTER output, - O outputSize, O* outputLenPtr) { +inline SQLRETURN GetAttributeUTF8(const std::string_view& attributeValue, + SQLPOINTER output, O outputSize, O* outputLenPtr) { if (output) { size_t outputLenBeforeNul = std::min(static_cast(attributeValue.size()), static_cast(outputSize - 1)); - memcpy(output, attributeValue.c_str(), outputLenBeforeNul); + memcpy(output, attributeValue.data(), outputLenBeforeNul); reinterpret_cast(output)[outputLenBeforeNul] = '\0'; } @@ -65,8 +65,8 @@ inline SQLRETURN GetAttributeUTF8(const std::string& attributeValue, SQLPOINTER } template -inline SQLRETURN GetAttributeUTF8(const std::string& attributeValue, SQLPOINTER output, - O outputSize, O* outputLenPtr, +inline SQLRETURN GetAttributeUTF8(const std::string_view& attributeValue, + SQLPOINTER output, O outputSize, O* outputLenPtr, driver::odbcabstraction::Diagnostics& diagnostics) { SQLRETURN result = GetAttributeUTF8(attributeValue, output, outputSize, outputLenPtr); if (SQL_SUCCESS_WITH_INFO == result) { @@ -76,26 +76,30 @@ inline SQLRETURN GetAttributeUTF8(const std::string& attributeValue, SQLPOINTER } template -inline SQLRETURN GetAttributeSQLWCHAR(const std::string& attributeValue, +inline SQLRETURN GetAttributeSQLWCHAR(const std::string_view& attributeValue, bool isLengthInBytes, SQLPOINTER output, O outputSize, O* outputLenPtr) { - size_t result = + size_t length = ConvertToSqlWChar(attributeValue, reinterpret_cast(output), isLengthInBytes ? outputSize : outputSize * GetSqlWCharSize()); + if (!isLengthInBytes) { + length = length / GetSqlWCharSize(); + } + if (outputLenPtr) { - *outputLenPtr = static_cast(isLengthInBytes ? result : result / GetSqlWCharSize()); + *outputLenPtr = static_cast(length); } if (output && - outputSize < static_cast(result + (isLengthInBytes ? GetSqlWCharSize() : 1))) { + outputSize < static_cast(length + (isLengthInBytes ? GetSqlWCharSize() : 1))) { return SQL_SUCCESS_WITH_INFO; } return SQL_SUCCESS; } template -inline SQLRETURN GetAttributeSQLWCHAR(const std::string& attributeValue, +inline SQLRETURN GetAttributeSQLWCHAR(const std::string_view& attributeValue, bool isLengthInBytes, SQLPOINTER output, O outputSize, O* outputLenPtr, driver::odbcabstraction::Diagnostics& diagnostics) { @@ -108,7 +112,8 @@ inline SQLRETURN GetAttributeSQLWCHAR(const std::string& attributeValue, } template -inline SQLRETURN GetStringAttribute(bool isUnicode, const std::string& attributeValue, +inline SQLRETURN GetStringAttribute(bool isUnicode, + const std::string_view& attributeValue, bool isLengthInBytes, SQLPOINTER output, O outputSize, O* outputLenPtr, driver::odbcabstraction::Diagnostics& diagnostics) { diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/encoding_utils.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/encoding_utils.h index 25619bb5555..94f4569ba89 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/encoding_utils.h +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/encoding_utils.h @@ -16,9 +16,9 @@ // under the License. #pragma once +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/encoding.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h" -#include -#include #include #include #include @@ -34,10 +34,11 @@ namespace ODBC { using driver::odbcabstraction::DriverException; using driver::odbcabstraction::GetSqlWCharSize; using driver::odbcabstraction::Utf8ToWcs; +using driver::odbcabstraction::WcsToUtf8; // Return the number of bytes required for the conversion. template -inline size_t ConvertToSqlWChar(const std::string& str, SQLWCHAR* buffer, +inline size_t ConvertToSqlWChar(const std::string_view& str, SQLWCHAR* buffer, SQLLEN bufferSizeInBytes) { thread_local std::vector wstr; Utf8ToWcs(str.data(), str.size(), &wstr); @@ -63,7 +64,7 @@ inline size_t ConvertToSqlWChar(const std::string& str, SQLWCHAR* buffer, return valueLengthInBytes; } -inline size_t ConvertToSqlWChar(const std::string& str, SQLWCHAR* buffer, +inline size_t ConvertToSqlWChar(const std::string_view& str, SQLWCHAR* buffer, SQLLEN bufferSizeInBytes) { switch (GetSqlWCharSize()) { case sizeof(char16_t): @@ -77,4 +78,39 @@ inline size_t ConvertToSqlWChar(const std::string& str, SQLWCHAR* buffer, } } +/// \brief Convert buffer of SqlWchar to standard string +/// \param[in] wchar_msg SqlWchar to convert +/// \param[in] msg_len Number of characters in wchar_msg +/// \return wchar_msg in std::string format +inline std::string SqlWcharToString(SQLWCHAR* wchar_msg, SQLINTEGER msg_len = SQL_NTS) { + if (!wchar_msg || wchar_msg[0] == 0 || msg_len == 0) { + return std::string(); + } + + thread_local std::vector utf8_str; + + if (msg_len == SQL_NTS) { + WcsToUtf8((void*)wchar_msg, &utf8_str); + } else { + WcsToUtf8((void*)wchar_msg, msg_len, &utf8_str); + } + + return std::string(utf8_str.begin(), utf8_str.end()); +} + +inline std::string SqlStringToString(const unsigned char* sqlStr, + int32_t sqlStrLen = SQL_NTS) { + std::string res; + + const char* sqlStrC = reinterpret_cast(sqlStr); + + if (!sqlStr) return res; + + if (sqlStrLen == SQL_NTS) + res.assign(sqlStrC); + else if (sqlStrLen > 0) + res.assign(sqlStrC, sqlStrLen); + + return res; +} } // namespace ODBC diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_connection.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_connection.h index 6a01fe128d9..0e9498bcb8a 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_connection.h +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_connection.h @@ -17,9 +17,9 @@ #pragma once -#include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_handle.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/connection.h" -#include #include #include #include @@ -41,6 +41,9 @@ class ODBCConnection : public ODBCHandle { ODBCConnection(const ODBCConnection&) = delete; ODBCConnection& operator=(const ODBCConnection&) = delete; + /// \brief Constructor for ODBCConnection. + /// \param[in] environment the parent environment. + /// \param[in] spiConnection the underlying spi connection. ODBCConnection(ODBCEnvironment& environment, std::shared_ptr spiConnection); @@ -48,16 +51,22 @@ class ODBCConnection : public ODBCHandle { const std::string& GetDSN() const; bool isConnected() const; + + /// \brief Connect to Arrow Flight SQL server. + /// \param[in] dsn the dsn name. + /// \param[in] properties the connection property map extracted from connection string. + /// \param[out] missing_properties report the properties that are missing void connect(std::string dsn, const driver::odbcabstraction::Connection::ConnPropertyMap& properties, std::vector& missing_properties); - void GetInfo(SQLUSMALLINT infoType, SQLPOINTER value, SQLSMALLINT bufferLength, - SQLSMALLINT* outputLength, bool isUnicode); + SQLRETURN GetInfo(SQLUSMALLINT infoType, SQLPOINTER value, SQLSMALLINT bufferLength, + SQLSMALLINT* outputLength, bool isUnicode); void SetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER stringLength, bool isUnicode); - void GetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER bufferLength, - SQLINTEGER* outputLength, bool isUnicode); + SQLRETURN GetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, + SQLINTEGER bufferLength, SQLINTEGER* outputLength, + bool isUnicode); ~ODBCConnection() = default; diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_descriptor.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_descriptor.h index 092483f4719..e7656082c5c 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_descriptor.h +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_descriptor.h @@ -17,7 +17,7 @@ #pragma once -#include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_handle.h" #include #include diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_handle.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_handle.h index c2428df394d..64257541a87 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_handle.h +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_handle.h @@ -17,12 +17,14 @@ #pragma once -#include -#include +// platform.h includes windows.h, so it needs to be included first +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h" + #include #include #include #include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/diagnostics.h" /** * @brief An abstraction over a generic ODBC handle. @@ -47,7 +49,7 @@ class ODBCHandle { rc = function(); } catch (const driver::odbcabstraction::DriverException& ex) { GetDiagnostics().AddError(ex); - } catch (const std::bad_alloc& ex) { + } catch (const std::bad_alloc&) { GetDiagnostics().AddError(driver::odbcabstraction::DriverException( "A memory allocation error occurred.", "HY001")); } catch (const std::exception& ex) { diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_statement.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_statement.h index bbddfac4185..7fb8d5c5741 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_statement.h +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_statement.h @@ -17,9 +17,11 @@ #pragma once -#include +// platform.h platform.h includes windows.h so it needs to be included first +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h" + +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_handle.h" -#include #include #include #include @@ -64,8 +66,10 @@ class ODBCStatement : public ODBCHandle { /** * @brief Returns true if the number of rows fetch was greater than zero. + * rowCountPtr and rowStatusArray are optional arguments, they are only needed for + * SQLExtendedFetch */ - bool Fetch(size_t rows); + bool Fetch(size_t rows, SQLULEN* rowCountPtr = 0, SQLUSMALLINT* rowStatusArray = 0); bool isPrepared() const; void GetStmtAttr(SQLINTEGER statementAttribute, SQLPOINTER output, @@ -73,6 +77,11 @@ class ODBCStatement : public ODBCHandle { void SetStmtAttr(SQLINTEGER statementAttribute, SQLPOINTER value, SQLINTEGER bufferSize, bool isUnicode); + /** + * @brief Revert back to implicitly allocated internal descriptors. + * isApd as True indicates APD descritor is to be reverted. + * isApd as False indicates ARD descritor is to be reverted. + */ void RevertAppDescriptor(bool isApd); inline ODBCDescriptor* GetIRD() { return m_ird.get(); } @@ -81,8 +90,20 @@ class ODBCStatement : public ODBCHandle { inline SQLULEN GetRowsetSize() { return m_rowsetSize; } - bool GetData(SQLSMALLINT recordNumber, SQLSMALLINT cType, SQLPOINTER dataPtr, - SQLLEN bufferLength, SQLLEN* indicatorPtr); + SQLRETURN GetData(SQLSMALLINT recordNumber, SQLSMALLINT cType, SQLPOINTER dataPtr, + SQLLEN bufferLength, SQLLEN* indicatorPtr); + + SQLRETURN getMoreResults(); + + /** + * @brief Get number of columns from data set + */ + void getColumnCount(SQLSMALLINT* columnCountPtr); + + /** + * @brief Get number of rows affected by an UPDATE, INSERT, or DELETE statement + */ + void getRowCount(SQLLEN* rowCountPtr); /** * @brief Closes the cursor. This does _not_ un-prepare the statement or change diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/connection.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/connection.h index 792a52c1fad..ce86882c952 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/connection.h +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/connection.h @@ -25,21 +25,21 @@ #include #include -#include -#include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/diagnostics.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/types.h" namespace driver { namespace odbcabstraction { /// \brief Case insensitive comparator struct CaseInsensitiveComparator { - bool operator()(const std::string_view& s1, const std::string_view& s2) const { + bool operator()(const std::string& s1, const std::string& s2) const { return boost::lexicographical_compare(s1, s2, boost::is_iless()); } }; // PropertyMap is case-insensitive for keys. -typedef std::map PropertyMap; +typedef std::map PropertyMap; class Statement; diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/driver.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/driver.h index f13371bf2d5..61d570574c7 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/driver.h +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/driver.h @@ -19,8 +19,8 @@ #include -#include -#include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/diagnostics.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/types.h" namespace driver { namespace odbcabstraction { diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/result_set.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/result_set.h index 1b3f8eb96d8..c24c6424860 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/result_set.h +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/result_set.h @@ -20,9 +20,11 @@ #include #include -#include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h" -#include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/types.h" + +#include namespace driver { namespace odbcabstraction { @@ -88,10 +90,10 @@ class ResultSet { /// \param buffer Target buffer to be populated. /// \param buffer_length Target buffer length. /// \param strlen_buffer Buffer that holds the length of value being fetched. - /// \returns true if there is more data to fetch from the current cell; - /// false if the whole value was already fetched. - virtual bool GetData(int column, int16_t target_type, int precision, int scale, - void* buffer, size_t buffer_length, ssize_t* strlen_buffer) = 0; + /// \returns SQLRETURN for SQLGetData. + virtual SQLRETURN GetData(int column, int16_t target_type, int precision, int scale, + void* buffer, size_t buffer_length, + ssize_t* strlen_buffer) = 0; }; } // namespace odbcabstraction diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/result_set_metadata.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/result_set_metadata.h index f625a2598c1..636dce21e4a 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/result_set_metadata.h +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/result_set_metadata.h @@ -17,8 +17,8 @@ #pragma once -#include #include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/types.h" namespace driver { namespace odbcabstraction { diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/types.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/types.h index e5d206a2ca7..8f16000daaa 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/types.h +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/types.h @@ -17,8 +17,8 @@ #pragma once -#include #include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/platform.h" namespace driver { namespace odbcabstraction { diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/utils.h b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/utils.h index cc848baa0fd..6e1fe5739be 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/utils.h +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/utils.h @@ -17,10 +17,10 @@ #pragma once -#include -#include #include #include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/logger.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/connection.h" namespace driver { namespace odbcabstraction { @@ -52,7 +52,8 @@ boost::optional AsInt32(int32_t min_value, const Connection::ConnPropertyMap& connPropertyMap, const std::string_view& property_name); -void ReadConfigFile(PropertyMap& properties, const std::string& configFileName); +void ReadConfigFile(PropertyMap& properties, const std::string& configPath, + const std::string& configFileName); } // namespace odbcabstraction } // namespace driver diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/logger.cc b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/logger.cc index edace64cf6a..8b105a2f0b6 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/logger.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/logger.cc @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -#include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/logger.h" namespace driver { namespace odbcabstraction { diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/odbc_connection.cc b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/odbc_connection.cc index 0143976bb48..337951ede3a 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/odbc_connection.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/odbc_connection.cc @@ -17,6 +17,10 @@ #include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_connection.h" +#include "arrow/result.h" +#include "arrow/util/utf8.h" + +#include "arrow/flight/sql/odbc/flight_sql/include/flight_sql/config/configuration.h" #include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/exceptions.h" #include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/attribute_utils.h" #include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_descriptor.h" @@ -49,48 +53,16 @@ namespace { // characters such as semi-colons and equals signs. NOTE: This can be optimized to be // built statically. const boost::xpressive::sregex CONNECTION_STR_REGEX( - boost::xpressive::sregex::compile("([^=;]+)=({.+}|[^=;]+|[^;])")); + boost::xpressive::sregex::compile("([^=;]+)=({.+}|[^;]+|[^;])")); // Load properties from the given DSN. The properties loaded do _not_ overwrite existing // entries in the properties. void loadPropertiesFromDSN(const std::string& dsn, Connection::ConnPropertyMap& properties) { - const size_t BUFFER_SIZE = 1024 * 10; - std::vector outputBuffer; - outputBuffer.resize(BUFFER_SIZE, '\0'); - SQLSetConfigMode(ODBC_BOTH_DSN); - - SQLGetPrivateProfileString(dsn.c_str(), NULL, "", &outputBuffer[0], BUFFER_SIZE, - "odbc.ini"); - - // The output buffer holds the list of keys in a series of NUL-terminated strings. - // The series is terminated with an empty string (eg a NUL-terminator terminating the - // last key followed by a NUL terminator after). - std::vector keys; - size_t pos = 0; - while (pos < BUFFER_SIZE) { - std::string key(&outputBuffer[pos]); - if (key.empty()) { - break; - } - size_t len = key.size(); - - // Skip over Driver or DSN keys. - if (!boost::iequals(key, "DSN") && !boost::iequals(key, "Driver")) { - keys.emplace_back(std::move(key)); - } - pos += len + 1; - } - - for (auto& key : keys) { - outputBuffer.clear(); - outputBuffer.resize(BUFFER_SIZE, '\0'); - - std::string key_str = std::string(key); - SQLGetPrivateProfileString(dsn.c_str(), key_str.c_str(), "", &outputBuffer[0], - BUFFER_SIZE, "odbc.ini"); - - std::string value = std::string(&outputBuffer[0]); + driver::flight_sql::config::Configuration config; + config.LoadDsn(dsn); + Connection::ConnPropertyMap dsnProperties = config.GetProperties(); + for (auto& [key, value] : dsnProperties) { auto propIter = properties.find(key); if (propIter == properties.end()) { properties.emplace(std::make_pair(std::move(key), std::move(value))); @@ -131,152 +103,144 @@ void ODBCConnection::connect(std::string dsn, m_attributeTrackingStatement = std::make_shared(*this, spiStatement); } -void ODBCConnection::GetInfo(SQLUSMALLINT infoType, SQLPOINTER value, - SQLSMALLINT bufferLength, SQLSMALLINT* outputLength, - bool isUnicode) { +SQLRETURN ODBCConnection::GetInfo(SQLUSMALLINT infoType, SQLPOINTER value, + SQLSMALLINT bufferLength, SQLSMALLINT* outputLength, + bool isUnicode) { switch (infoType) { case SQL_ACTIVE_ENVIRONMENTS: GetAttribute(static_cast(0), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; #ifdef SQL_ASYNC_DBC_FUNCTIONS case SQL_ASYNC_DBC_FUNCTIONS: GetAttribute(static_cast(SQL_ASYNC_DBC_NOT_CAPABLE), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; #endif case SQL_ASYNC_MODE: GetAttribute(static_cast(SQL_AM_NONE), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; #ifdef SQL_ASYNC_NOTIFICATION case SQL_ASYNC_NOTIFICATION: GetAttribute(static_cast(SQL_ASYNC_NOTIFICATION_NOT_CAPABLE), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; #endif case SQL_BATCH_ROW_COUNT: GetAttribute(static_cast(0), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_BATCH_SUPPORT: GetAttribute(static_cast(0), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_DATA_SOURCE_NAME: - GetStringAttribute(isUnicode, m_dsn, true, value, bufferLength, outputLength, - GetDiagnostics()); - break; + return GetStringAttribute(isUnicode, m_dsn, true, value, bufferLength, outputLength, + GetDiagnostics()); case SQL_DRIVER_ODBC_VER: - GetStringAttribute(isUnicode, "03.80", true, value, bufferLength, outputLength, - GetDiagnostics()); - break; + return GetStringAttribute(isUnicode, "03.80", true, value, bufferLength, + outputLength, GetDiagnostics()); case SQL_DYNAMIC_CURSOR_ATTRIBUTES1: GetAttribute(static_cast(0), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_DYNAMIC_CURSOR_ATTRIBUTES2: GetAttribute(static_cast(0), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_FORWARD_ONLY_CURSOR_ATTRIBUTES1: GetAttribute(static_cast(SQL_CA1_NEXT), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_FORWARD_ONLY_CURSOR_ATTRIBUTES2: GetAttribute(static_cast(SQL_CA2_READ_ONLY_CONCURRENCY), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_FILE_USAGE: GetAttribute(static_cast(SQL_FILE_NOT_SUPPORTED), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_KEYSET_CURSOR_ATTRIBUTES1: GetAttribute(static_cast(0), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_KEYSET_CURSOR_ATTRIBUTES2: GetAttribute(static_cast(0), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_MAX_ASYNC_CONCURRENT_STATEMENTS: GetAttribute(static_cast(0), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_ODBC_INTERFACE_CONFORMANCE: GetAttribute(static_cast(SQL_OIC_CORE), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; // case SQL_ODBC_STANDARD_CLI_CONFORMANCE: - mentioned in SQLGetInfo spec with no // description and there is no constant for this. case SQL_PARAM_ARRAY_ROW_COUNTS: GetAttribute(static_cast(SQL_PARC_NO_BATCH), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_PARAM_ARRAY_SELECTS: GetAttribute(static_cast(SQL_PAS_NO_SELECT), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_ROW_UPDATES: - GetStringAttribute(isUnicode, "N", true, value, bufferLength, outputLength, - GetDiagnostics()); - break; + return GetStringAttribute(isUnicode, "N", true, value, bufferLength, outputLength, + GetDiagnostics()); case SQL_SCROLL_OPTIONS: GetAttribute(static_cast(SQL_SO_FORWARD_ONLY), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_STATIC_CURSOR_ATTRIBUTES1: GetAttribute(static_cast(0), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_STATIC_CURSOR_ATTRIBUTES2: GetAttribute(static_cast(0), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_BOOKMARK_PERSISTENCE: GetAttribute(static_cast(0), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_DESCRIBE_PARAMETER: - GetStringAttribute(isUnicode, "N", true, value, bufferLength, outputLength, - GetDiagnostics()); - break; + return GetStringAttribute(isUnicode, "N", true, value, bufferLength, outputLength, + GetDiagnostics()); case SQL_MULT_RESULT_SETS: - GetStringAttribute(isUnicode, "N", true, value, bufferLength, outputLength, - GetDiagnostics()); - break; + return GetStringAttribute(isUnicode, "N", true, value, bufferLength, outputLength, + GetDiagnostics()); case SQL_MULTIPLE_ACTIVE_TXN: - GetStringAttribute(isUnicode, "N", true, value, bufferLength, outputLength, - GetDiagnostics()); - break; + return GetStringAttribute(isUnicode, "N", true, value, bufferLength, outputLength, + GetDiagnostics()); case SQL_NEED_LONG_DATA_LEN: - GetStringAttribute(isUnicode, "N", true, value, bufferLength, outputLength, - GetDiagnostics()); - break; + return GetStringAttribute(isUnicode, "N", true, value, bufferLength, outputLength, + GetDiagnostics()); case SQL_TXN_CAPABLE: GetAttribute(static_cast(SQL_TC_NONE), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_TXN_ISOLATION_OPTION: GetAttribute(static_cast(0), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_TABLE_TERM: - GetStringAttribute(isUnicode, "table", true, value, bufferLength, outputLength, - GetDiagnostics()); - break; + return GetStringAttribute(isUnicode, "table", true, value, bufferLength, + outputLength, GetDiagnostics()); // Deprecated ODBC 2.x fields required for backwards compatibility. case SQL_ODBC_API_CONFORMANCE: GetAttribute(static_cast(SQL_OAC_LEVEL1), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_FETCH_DIRECTION: GetAttribute(static_cast(SQL_FETCH_NEXT), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_LOCK_TYPES: GetAttribute(static_cast(0), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_POS_OPERATIONS: GetAttribute(static_cast(0), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_POSITIONED_STATEMENTS: GetAttribute(static_cast(0), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_SCROLL_CONCURRENCY: GetAttribute(static_cast(0), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; case SQL_STATIC_SENSITIVITY: GetAttribute(static_cast(0), value, bufferLength, outputLength); - break; + return SQL_SUCCESS; // Driver-level string properties. case SQL_USER_NAME: @@ -311,9 +275,8 @@ void ODBCConnection::GetInfo(SQLUSMALLINT infoType, SQLPOINTER value, case SQL_XOPEN_CLI_YEAR: { const auto& info = m_spiConnection->GetInfo(infoType); const std::string& infoValue = boost::get(info); - GetStringAttribute(isUnicode, infoValue, true, value, bufferLength, outputLength, - GetDiagnostics()); - break; + return GetStringAttribute(isUnicode, infoValue, true, value, bufferLength, + outputLength, GetDiagnostics()); } // Driver-level 32-bit integer properties. @@ -403,7 +366,7 @@ void ODBCConnection::GetInfo(SQLUSMALLINT infoType, SQLPOINTER value, const auto& info = m_spiConnection->GetInfo(infoType); uint32_t infoValue = boost::get(info); GetAttribute(infoValue, value, bufferLength, outputLength); - break; + return SQL_SUCCESS; } // Driver-level 16-bit integer properties. @@ -438,7 +401,7 @@ void ODBCConnection::GetInfo(SQLUSMALLINT infoType, SQLPOINTER value, const auto& info = m_spiConnection->GetInfo(infoType); uint16_t infoValue = boost::get(info); GetAttribute(infoValue, value, bufferLength, outputLength); - break; + return SQL_SUCCESS; } // Special case - SQL_DATABASE_NAME is an alias for SQL_ATTR_CURRENT_CATALOG. @@ -448,13 +411,15 @@ void ODBCConnection::GetInfo(SQLUSMALLINT infoType, SQLPOINTER value, throw DriverException("Optional feature not supported.", "HYC00"); } const std::string& infoValue = boost::get(*attr); - GetStringAttribute(isUnicode, infoValue, true, value, bufferLength, outputLength, - GetDiagnostics()); - break; + return GetStringAttribute(isUnicode, infoValue, true, value, bufferLength, + outputLength, GetDiagnostics()); } default: - throw DriverException("Unknown SQLGetInfo type: " + std::to_string(infoType)); + throw DriverException("Unknown SQLGetInfo type: " + std::to_string(infoType), + "HY096"); } + + return SQL_ERROR; } void ODBCConnection::SetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, @@ -463,7 +428,7 @@ void ODBCConnection::SetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, bool successfully_written = false; switch (attribute) { // Internal connection attributes -#ifdef SQL_ATR_ASYNC_DBC_EVENT +#ifdef SQL_ATTR_ASYNC_DBC_EVENT case SQL_ATTR_ASYNC_DBC_EVENT: throw DriverException("Optional feature not supported.", "HYC00"); #endif @@ -471,7 +436,7 @@ void ODBCConnection::SetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, case SQL_ATTR_ASYNC_DBC_FUNCTIONS_ENABLE: throw DriverException("Optional feature not supported.", "HYC00"); #endif -#ifdef SQL_ATTR_ASYNC_PCALLBACK +#ifdef SQL_ATTR_ASYNC_DBC_PCALLBACK case SQL_ATTR_ASYNC_DBC_PCALLBACK: throw DriverException("Optional feature not supported.", "HYC00"); #endif @@ -499,7 +464,7 @@ void ODBCConnection::SetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, throw DriverException("Cannot set read-only attribute", "HY092"); case SQL_ATTR_TRACE: // DM-only throw DriverException("Cannot set read-only attribute", "HY092"); - case SQL_ATTR_TRACEFILE: + case SQL_ATTR_TRACEFILE: // DM-only throw DriverException("Optional feature not supported.", "HYC00"); case SQL_ATTR_TRANSLATE_LIB: throw DriverException("Optional feature not supported.", "HYC00"); @@ -573,59 +538,59 @@ void ODBCConnection::SetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, } } -void ODBCConnection::GetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, - SQLINTEGER bufferLength, SQLINTEGER* outputLength, - bool isUnicode) { +SQLRETURN ODBCConnection::GetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* outputLength, bool isUnicode) { using driver::odbcabstraction::Connection; boost::optional spiAttribute; switch (attribute) { // Internal connection attributes -#ifdef SQL_ATR_ASYNC_DBC_EVENT +#ifdef SQL_ATTR_ASYNC_DBC_EVENT case SQL_ATTR_ASYNC_DBC_EVENT: GetAttribute(static_cast(NULL), value, bufferLength, outputLength); - return; + return SQL_SUCCESS; #endif #ifdef SQL_ATTR_ASYNC_DBC_FUNCTIONS_ENABLE case SQL_ATTR_ASYNC_DBC_FUNCTIONS_ENABLE: GetAttribute(static_cast(SQL_ASYNC_DBC_ENABLE_OFF), value, bufferLength, outputLength); - return; + return SQL_SUCCESS; #endif -#ifdef SQL_ATTR_ASYNC_PCALLBACK +#ifdef SQL_ATTR_ASYNC_DBC_PCALLBACK case SQL_ATTR_ASYNC_DBC_PCALLBACK: GetAttribute(static_cast(NULL), value, bufferLength, outputLength); - return; + return SQL_SUCCESS; #endif #ifdef SQL_ATTR_ASYNC_DBC_PCONTEXT case SQL_ATTR_ASYNC_DBC_PCONTEXT: GetAttribute(static_cast(NULL), value, bufferLength, outputLength); - return; + return SQL_SUCCESS; #endif case SQL_ATTR_ASYNC_ENABLE: GetAttribute(static_cast(SQL_ASYNC_ENABLE_OFF), value, bufferLength, outputLength); - return; + return SQL_SUCCESS; case SQL_ATTR_AUTO_IPD: GetAttribute(static_cast(SQL_FALSE), value, bufferLength, outputLength); - return; + return SQL_SUCCESS; case SQL_ATTR_AUTOCOMMIT: GetAttribute(static_cast(SQL_AUTOCOMMIT_ON), value, bufferLength, outputLength); - return; + return SQL_SUCCESS; #ifdef SQL_ATTR_DBC_INFO_TOKEN case SQL_ATTR_DBC_INFO_TOKEN: throw DriverException("Cannot read set-only attribute", "HY092"); #endif case SQL_ATTR_ENLIST_IN_DTC: GetAttribute(static_cast(NULL), value, bufferLength, outputLength); - return; + return SQL_SUCCESS; case SQL_ATTR_ODBC_CURSORS: // DM-only. throw DriverException("Invalid attribute", "HY092"); case SQL_ATTR_QUIET_MODE: GetAttribute(static_cast(NULL), value, bufferLength, outputLength); - return; + return SQL_SUCCESS; case SQL_ATTR_TRACE: // DM-only throw DriverException("Invalid attribute", "HY092"); case SQL_ATTR_TRACEFILE: @@ -635,7 +600,7 @@ void ODBCConnection::GetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, case SQL_ATTR_TRANSLATE_OPTION: throw DriverException("Optional feature not supported.", "HYC00"); case SQL_ATTR_TXN_ISOLATION: - throw DriverException("Optional feature not supported.", "HCY00"); + throw DriverException("Optional feature not supported.", "HYC00"); // ODBCAbstraction-level connection attributes. case SQL_ATTR_CURRENT_CATALOG: { @@ -644,9 +609,8 @@ void ODBCConnection::GetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, throw DriverException("Optional feature not supported.", "HYC00"); } const std::string& infoValue = boost::get(*catalog); - GetStringAttribute(isUnicode, infoValue, true, value, bufferLength, outputLength, - GetDiagnostics()); - return; + return GetStringAttribute(isUnicode, infoValue, true, value, bufferLength, + outputLength, GetDiagnostics()); } // These all are uint32_t attributes. @@ -675,6 +639,7 @@ void ODBCConnection::GetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, GetAttribute(static_cast(boost::get(*spiAttribute)), value, bufferLength, outputLength); + return SQL_SUCCESS; } void ODBCConnection::disconnect() { @@ -761,7 +726,6 @@ std::string ODBCConnection::getPropertiesFromConnString( if (!isDsnFirst) { isDriverFirst = true; } - continue; } // Strip wrapping curly braces. diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/odbc_descriptor.cc b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/odbc_descriptor.cc index b578bea3609..97b31bb550e 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/odbc_descriptor.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/odbc_descriptor.cc @@ -275,7 +275,9 @@ void ODBCDescriptor::GetHeaderField(SQLSMALLINT fieldIdentifier, SQLPOINTER valu GetAttribute(m_rowsProccessedPtr, value, bufferLength, outputLength); break; case SQL_DESC_COUNT: { - GetAttribute(m_highestOneBasedBoundRecord, value, bufferLength, outputLength); + // m_highestOneBasedBoundRecord equals number of records + 1 + GetAttribute(static_cast(m_highestOneBasedBoundRecord - 1), value, + bufferLength, outputLength); break; } default: @@ -311,52 +313,53 @@ void ODBCDescriptor::GetField(SQLSMALLINT recordNumber, SQLSMALLINT fieldIdentif // TODO: Restrict fields based on AppDescriptor IPD, and IRD. + bool lengthInBytes = true; SQLSMALLINT zeroBasedRecord = recordNumber - 1; const DescriptorRecord& record = m_records[zeroBasedRecord]; switch (fieldIdentifier) { case SQL_DESC_BASE_COLUMN_NAME: - GetAttributeUTF8(record.m_baseColumnName, value, bufferLength, outputLength, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.m_baseColumnName, lengthInBytes, value, bufferLength, + outputLength, GetDiagnostics()); break; case SQL_DESC_BASE_TABLE_NAME: - GetAttributeUTF8(record.m_baseTableName, value, bufferLength, outputLength, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.m_baseTableName, lengthInBytes, value, bufferLength, + outputLength, GetDiagnostics()); break; case SQL_DESC_CATALOG_NAME: - GetAttributeUTF8(record.m_catalogName, value, bufferLength, outputLength, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.m_catalogName, lengthInBytes, value, bufferLength, + outputLength, GetDiagnostics()); break; case SQL_DESC_LABEL: - GetAttributeUTF8(record.m_label, value, bufferLength, outputLength, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.m_label, lengthInBytes, value, bufferLength, + outputLength, GetDiagnostics()); break; case SQL_DESC_LITERAL_PREFIX: - GetAttributeUTF8(record.m_literalPrefix, value, bufferLength, outputLength, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.m_literalPrefix, lengthInBytes, value, bufferLength, + outputLength, GetDiagnostics()); break; case SQL_DESC_LITERAL_SUFFIX: - GetAttributeUTF8(record.m_literalSuffix, value, bufferLength, outputLength, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.m_literalSuffix, lengthInBytes, value, bufferLength, + outputLength, GetDiagnostics()); break; case SQL_DESC_LOCAL_TYPE_NAME: - GetAttributeUTF8(record.m_localTypeName, value, bufferLength, outputLength, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.m_localTypeName, lengthInBytes, value, bufferLength, + outputLength, GetDiagnostics()); break; case SQL_DESC_NAME: - GetAttributeUTF8(record.m_name, value, bufferLength, outputLength, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.m_name, lengthInBytes, value, bufferLength, + outputLength, GetDiagnostics()); break; case SQL_DESC_SCHEMA_NAME: - GetAttributeUTF8(record.m_schemaName, value, bufferLength, outputLength, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.m_schemaName, lengthInBytes, value, bufferLength, + outputLength, GetDiagnostics()); break; case SQL_DESC_TABLE_NAME: - GetAttributeUTF8(record.m_tableName, value, bufferLength, outputLength, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.m_tableName, lengthInBytes, value, bufferLength, + outputLength, GetDiagnostics()); break; case SQL_DESC_TYPE_NAME: - GetAttributeUTF8(record.m_typeName, value, bufferLength, outputLength, - GetDiagnostics()); + GetAttributeSQLWCHAR(record.m_typeName, lengthInBytes, value, bufferLength, + outputLength, GetDiagnostics()); break; case SQL_DESC_DATA_PTR: @@ -366,7 +369,7 @@ void ODBCDescriptor::GetField(SQLSMALLINT recordNumber, SQLSMALLINT fieldIdentif case SQL_DESC_OCTET_LENGTH_PTR: GetAttribute(record.m_indicatorPtr, value, bufferLength, outputLength); break; - + case SQL_COLUMN_LENGTH: // ODBC 2.0 case SQL_DESC_LENGTH: GetAttribute(record.m_length, value, bufferLength, outputLength); break; @@ -405,12 +408,14 @@ void ODBCDescriptor::GetField(SQLSMALLINT recordNumber, SQLSMALLINT fieldIdentif case SQL_DESC_PARAMETER_TYPE: GetAttribute(record.m_paramType, value, bufferLength, outputLength); break; + case SQL_COLUMN_PRECISION: // ODBC 2.0 case SQL_DESC_PRECISION: GetAttribute(record.m_precision, value, bufferLength, outputLength); break; case SQL_DESC_ROWVER: GetAttribute(record.m_rowVer, value, bufferLength, outputLength); break; + case SQL_COLUMN_SCALE: // ODBC 2.0 case SQL_DESC_SCALE: GetAttribute(record.m_scale, value, bufferLength, outputLength); break; @@ -500,7 +505,8 @@ void ODBCDescriptor::PopulateFromResultSetMetadata(ResultSetMetadata* rsmd) { m_records[i].m_caseSensitive = rsmd->IsCaseSensitive(oneBasedIndex) ? SQL_TRUE : SQL_FALSE; m_records[i].m_datetimeIntervalPrecision; // TODO - update when rsmd adds this - m_records[i].m_numPrecRadix = rsmd->GetNumPrecRadix(oneBasedIndex); + SQLINTEGER numPrecRadix = rsmd->GetNumPrecRadix(oneBasedIndex); + m_records[i].m_numPrecRadix = numPrecRadix > 0 ? numPrecRadix : 0; m_records[i].m_datetimeIntervalCode; // TODO m_records[i].m_fixedPrecScale = rsmd->IsFixedPrecScale(oneBasedIndex) ? SQL_TRUE : SQL_FALSE; @@ -510,8 +516,7 @@ void ODBCDescriptor::PopulateFromResultSetMetadata(ResultSetMetadata* rsmd) { m_records[i].m_rowVer = SQL_FALSE; m_records[i].m_scale = rsmd->GetScale(oneBasedIndex); m_records[i].m_searchable = rsmd->IsSearchable(oneBasedIndex); - m_records[i].m_type = - GetSqlTypeForODBCVersion(rsmd->GetDataType(oneBasedIndex), m_is2xConnection); + m_records[i].m_type = rsmd->GetDataType(oneBasedIndex); m_records[i].m_unnamed = m_records[i].m_name.empty() ? SQL_TRUE : SQL_FALSE; m_records[i].m_unsigned = rsmd->IsUnsigned(oneBasedIndex) ? SQL_TRUE : SQL_FALSE; m_records[i].m_updatable = rsmd->GetUpdatable(oneBasedIndex); diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/odbc_environment.cc b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/odbc_environment.cc index 7781235688f..9d7a8223591 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/odbc_environment.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/odbc_environment.cc @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. -#include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_environment.h" + +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_connection.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/connection.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/driver.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/types.h" -#include -#include -#include -#include #include #include #include diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/odbc_statement.cc b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/odbc_statement.cc index a5db0cc25dd..bda30f4466c 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/odbc_statement.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/odbc_impl/odbc_statement.cc @@ -15,16 +15,17 @@ // specific language governing permissions and limitations // under the License. -#include - -#include -#include -#include -#include -#include -#include -#include -#include +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_statement.h" + +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/exceptions.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/attribute_utils.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_connection.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_descriptor.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/result_set.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/result_set_metadata.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/statement.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/types.h" + #include #include #include @@ -128,6 +129,9 @@ SQLSMALLINT getCTypeForSQLType(const DescriptorRecord& record) { case SQL_WLONGVARCHAR: return SQL_C_WCHAR; + case SQL_BIT: + return SQL_C_BIT; + case SQL_BINARY: case SQL_VARBINARY: case SQL_LONGVARBINARY: @@ -145,13 +149,20 @@ SQLSMALLINT getCTypeForSQLType(const DescriptorRecord& record) { case SQL_BIGINT: return record.m_unsigned ? SQL_C_UBIGINT : SQL_C_SBIGINT; + case SQL_NUMERIC: + case SQL_DECIMAL: + return SQL_C_NUMERIC; + + case SQL_FLOAT: case SQL_REAL: return SQL_C_FLOAT; - case SQL_FLOAT: case SQL_DOUBLE: return SQL_C_DOUBLE; + case SQL_GUID: + return SQL_C_GUID; + case SQL_DATE: case SQL_TYPE_DATE: return SQL_C_TYPE_DATE; @@ -164,32 +175,32 @@ SQLSMALLINT getCTypeForSQLType(const DescriptorRecord& record) { case SQL_TYPE_TIMESTAMP: return SQL_C_TYPE_TIMESTAMP; - case SQL_C_INTERVAL_DAY: - return SQL_INTERVAL_DAY; - case SQL_C_INTERVAL_DAY_TO_HOUR: - return SQL_INTERVAL_DAY_TO_HOUR; - case SQL_C_INTERVAL_DAY_TO_MINUTE: - return SQL_INTERVAL_DAY_TO_MINUTE; - case SQL_C_INTERVAL_DAY_TO_SECOND: - return SQL_INTERVAL_DAY_TO_SECOND; - case SQL_C_INTERVAL_HOUR: - return SQL_INTERVAL_HOUR; - case SQL_C_INTERVAL_HOUR_TO_MINUTE: - return SQL_INTERVAL_HOUR_TO_MINUTE; - case SQL_C_INTERVAL_HOUR_TO_SECOND: - return SQL_INTERVAL_HOUR_TO_SECOND; - case SQL_C_INTERVAL_MINUTE: - return SQL_INTERVAL_MINUTE; - case SQL_C_INTERVAL_MINUTE_TO_SECOND: - return SQL_INTERVAL_MINUTE_TO_SECOND; - case SQL_C_INTERVAL_SECOND: - return SQL_INTERVAL_SECOND; - case SQL_C_INTERVAL_YEAR: - return SQL_INTERVAL_YEAR; - case SQL_C_INTERVAL_YEAR_TO_MONTH: - return SQL_INTERVAL_YEAR_TO_MONTH; - case SQL_C_INTERVAL_MONTH: - return SQL_INTERVAL_MONTH; + case SQL_INTERVAL_DAY: + return SQL_C_INTERVAL_DAY; + case SQL_INTERVAL_DAY_TO_HOUR: + return SQL_C_INTERVAL_DAY_TO_HOUR; + case SQL_INTERVAL_DAY_TO_MINUTE: + return SQL_C_INTERVAL_DAY_TO_MINUTE; + case SQL_INTERVAL_DAY_TO_SECOND: + return SQL_C_INTERVAL_DAY_TO_SECOND; + case SQL_INTERVAL_HOUR: + return SQL_C_INTERVAL_HOUR; + case SQL_INTERVAL_HOUR_TO_MINUTE: + return SQL_C_INTERVAL_HOUR_TO_MINUTE; + case SQL_INTERVAL_HOUR_TO_SECOND: + return SQL_C_INTERVAL_HOUR_TO_SECOND; + case SQL_INTERVAL_MINUTE: + return SQL_C_INTERVAL_MINUTE; + case SQL_INTERVAL_MINUTE_TO_SECOND: + return SQL_C_INTERVAL_MINUTE_TO_SECOND; + case SQL_INTERVAL_SECOND: + return SQL_C_INTERVAL_SECOND; + case SQL_INTERVAL_YEAR: + return SQL_C_INTERVAL_YEAR; + case SQL_INTERVAL_YEAR_TO_MONTH: + return SQL_C_INTERVAL_YEAR_TO_MONTH; + case SQL_INTERVAL_MONTH: + return SQL_C_INTERVAL_MONTH; default: throw DriverException("Unknown SQL type: " + std::to_string(record.m_conciseType), @@ -306,7 +317,8 @@ void ODBCStatement::ExecuteDirect(const std::string& query) { m_isPrepared = false; } -bool ODBCStatement::Fetch(size_t rows) { +bool ODBCStatement::Fetch(size_t rows, SQLULEN* rowCountPtr, + SQLUSMALLINT* rowStatusArray) { if (m_hasReachedEndOfResult) { m_ird->SetRowsProcessed(0); return false; @@ -339,11 +351,24 @@ bool ODBCStatement::Fetch(size_t rows) { m_currentArd->NotifyBindingsHavePropagated(); } - size_t rowsFetched = m_currenResult->Move(rows, m_currentArd->GetBindOffset(), - m_currentArd->GetBoundStructOffset(), - m_ird->GetArrayStatusPtr()); + uint16_t* arrayStatusPtr; + if (rowStatusArray) { + // For SQLExtendedFetch only + arrayStatusPtr = rowStatusArray; + } else { + arrayStatusPtr = m_ird->GetArrayStatusPtr(); + } + + size_t rowsFetched = + m_currenResult->Move(rows, m_currentArd->GetBindOffset(), + m_currentArd->GetBoundStructOffset(), arrayStatusPtr); m_ird->SetRowsProcessed(static_cast(rowsFetched)); + if (rowCountPtr) { + // For SQLExtendedFetch only + *rowCountPtr = rowsFetched; + } + m_rowNumber += rowsFetched; m_hasReachedEndOfResult = rowsFetched != rows; return rowsFetched != 0; @@ -580,6 +605,7 @@ void ODBCStatement::SetStmtAttr(SQLINTEGER statementAttribute, SQLPOINTER value, return; case SQL_ATTR_ASYNC_ENABLE: + throw DriverException("Unsupported attribute", "HYC00"); #ifdef SQL_ATTR_ASYNC_STMT_EVENT case SQL_ATTR_ASYNC_STMT_EVENT: throw DriverException("Unsupported attribute", "HYC00"); @@ -627,7 +653,7 @@ void ODBCStatement::SetStmtAttr(SQLINTEGER statementAttribute, SQLPOINTER value, CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_UB_OFF)); return; case SQL_ATTR_RETRIEVE_DATA: - CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_TRUE)); + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_RD_ON)); return; case SQL_ROWSET_SIZE: SetAttribute(value, m_rowsetSize); @@ -677,7 +703,7 @@ void ODBCStatement::RevertAppDescriptor(bool isApd) { void ODBCStatement::closeCursor(bool suppressErrors) { if (!suppressErrors && !m_currenResult) { - throw DriverException("Invalid cursor state", "28000"); + throw DriverException("Invalid cursor state", "24000"); } if (m_currenResult) { @@ -691,9 +717,9 @@ void ODBCStatement::closeCursor(bool suppressErrors) { m_hasReachedEndOfResult = false; } -bool ODBCStatement::GetData(SQLSMALLINT recordNumber, SQLSMALLINT cType, - SQLPOINTER dataPtr, SQLLEN bufferLength, - SQLLEN* indicatorPtr) { +SQLRETURN ODBCStatement::GetData(SQLSMALLINT recordNumber, SQLSMALLINT cType, + SQLPOINTER dataPtr, SQLLEN bufferLength, + SQLLEN* indicatorPtr) { if (recordNumber == 0) { throw DriverException("Bookmarks are not supported", "07009"); } else if (recordNumber > m_ird->GetRecords().size()) { @@ -735,6 +761,34 @@ bool ODBCStatement::GetData(SQLSMALLINT recordNumber, SQLSMALLINT cType, bufferLength, indicatorPtr); } +SQLRETURN ODBCStatement::getMoreResults() { + // Multiple result sets are not supported. + if (m_currenResult) { + return SQL_NO_DATA; + } else { + throw DriverException("Function sequence error", "HY010"); + } +} + +void ODBCStatement::getColumnCount(SQLSMALLINT* columnCountPtr) { + if (!columnCountPtr) { + // columnCountPtr is not valid, do nothing as ODBC spec does not mention this as an + // error + return; + } + size_t columnCount = m_ird->GetRecords().size(); + *columnCountPtr = static_cast(columnCount); +} + +void ODBCStatement::getRowCount(SQLLEN* rowCountPtr) { + if (!rowCountPtr) { + // rowCountPtr is not valid, do nothing as ODBC spec does not mention this as an error + return; + } + // Will always be -1 (number of rows unknown) if only SELECT is supported + *rowCountPtr = -1; +} + void ODBCStatement::releaseStatement() { closeCursor(true); m_connection.dropStatement(this); diff --git a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/utils.cc b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/utils.cc index f1d2d14744d..6feb7ff3be2 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbcabstraction/utils.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbcabstraction/utils.cc @@ -42,7 +42,7 @@ boost::optional AsBool(const std::string& value) { boost::optional AsBool(const Connection::ConnPropertyMap& connPropertyMap, const std::string_view& property_name) { - auto extracted_property = connPropertyMap.find(property_name); + auto extracted_property = connPropertyMap.find(std::string(property_name)); if (extracted_property != connPropertyMap.end()) { return AsBool(extracted_property->second); @@ -54,7 +54,7 @@ boost::optional AsBool(const Connection::ConnPropertyMap& connPropertyMap, boost::optional AsInt32(int32_t min_value, const Connection::ConnPropertyMap& connPropertyMap, const std::string_view& property_name) { - auto extracted_property = connPropertyMap.find(property_name); + auto extracted_property = connPropertyMap.find(std::string(property_name)); if (extracted_property != connPropertyMap.end()) { const int32_t stringColumnLength = std::stoi(extracted_property->second); @@ -81,9 +81,8 @@ std::string GetModulePath() { return std::string(path.begin(), path.begin() + dirname_length); } -void ReadConfigFile(PropertyMap& properties, const std::string& config_file_name) { - auto config_path = GetModulePath(); - +void ReadConfigFile(PropertyMap& properties, const std::string& config_path, + const std::string& config_file_name) { std::ifstream config_file; auto config_file_path = config_path + "/" + config_file_name; config_file.open(config_file_path); diff --git a/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/001CompareQueries/CompareQueriesOutput.txt b/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/001CompareQueries/CompareQueriesOutput.txt new file mode 100644 index 00000000000..269667c25d6 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/001CompareQueries/CompareQueriesOutput.txt @@ -0,0 +1,410 @@ +PS C:\Users\Administrator\GitHub\arrow\cpp\src\arrow\flight\sql\odbc\performance_tests> python compare_queries_bar.py --driver_a "Arrow Flight SQL ODBC Driver" --driver_b "Apache Arrow Flight SQL ODBC Driver" --iterations 5 --outfile compare_results.csv --plotfile compare_plot.png + +Running Limit... +Iteration 1: 34765.16 ms (rows returned: 1000) +Iteration 2: 33567.84 ms (rows returned: 1000) +Iteration 3: 33591.92 ms (rows returned: 1000) +Iteration 4: 33771.59 ms (rows returned: 1000) +Iteration 5: 33719.72 ms (rows returned: 1000) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": null, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" LIMIT 1000", + "avg_ms": 33883.24728012085, + "min_ms": 33567.84391403198, + "max_ms": 34765.1641368866, + "all_runs_ms": [ + 34765.1641368866, + 33567.84391403198, + 33591.91560745239, + 33771.59285545349, + 33719.719886779785 + ], + "rows_returned": [ + 1000, + 1000, + 1000, + 1000, + 1000 + ] +} +Iteration 1: 36034.89 ms (rows returned: 1000) +Iteration 2: 31752.09 ms (rows returned: 1000) +Iteration 3: 35200.15 ms (rows returned: 1000) +Iteration 4: 34950.92 ms (rows returned: 1000) +Iteration 5: 35200.88 ms (rows returned: 1000) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": null, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" LIMIT 1000", + "avg_ms": 34627.785539627075, + "min_ms": 31752.08568572998, + "max_ms": 36034.88802909851, + "all_runs_ms": [ + 36034.88802909851, + 31752.08568572998, + 35200.14786720276, + 34950.92225074768, + 35200.883865356445 + ], + "rows_returned": [ + 1000, + 1000, + 1000, + 1000, + 1000 + ] +} + +Running AvgGroupBy... +Iteration 1: 17156.91 ms (rows returned: 13) +Iteration 2: 16672.24 ms (rows returned: 13) +Iteration 3: 16697.97 ms (rows returned: 13) +Iteration 4: 16869.58 ms (rows returned: 13) +Iteration 5: 16784.30 ms (rows returned: 13) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": null, + "query": "SELECT passenger_count, AVG(fare_amount) FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" GROUP BY passenger_count", + "avg_ms": 16836.19828224182, + "min_ms": 16672.237634658813, + "max_ms": 17156.906366348267, + "all_runs_ms": [ + 17156.906366348267, + 16672.237634658813, + 16697.968244552612, + 16869.58408355713, + 16784.295082092285 + ], + "rows_returned": [ + 13, + 13, + 13, + 13, + 13 + ] +} +Iteration 1: 16770.03 ms (rows returned: 13) +Iteration 2: 17078.22 ms (rows returned: 13) +Iteration 3: 16742.22 ms (rows returned: 13) +Iteration 4: 16740.14 ms (rows returned: 13) +Iteration 5: 16666.99 ms (rows returned: 13) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": null, + "query": "SELECT passenger_count, AVG(fare_amount) FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" GROUP BY passenger_count", + "avg_ms": 16799.520874023438, + "min_ms": 16666.994333267212, + "max_ms": 17078.22322845459, + "all_runs_ms": [ + 16770.0252532959, + 17078.22322845459, + 16742.223978042603, + 16740.137577056885, + 16666.994333267212 + ], + "rows_returned": [ + 13, + 13, + 13, + 13, + 13 + ] +} + +Running SumGroupBy... +Iteration 1: 16453.79 ms (rows returned: 13) +Iteration 2: 16430.96 ms (rows returned: 13) +Iteration 3: 16337.73 ms (rows returned: 13) +Iteration 4: 16466.14 ms (rows returned: 13) +Iteration 5: 16312.47 ms (rows returned: 13) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": null, + "query": "SELECT passenger_count, SUM(fare_amount) FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" GROUP BY passenger_count", + "avg_ms": 16400.220680236816, + "min_ms": 16312.472820281982, + "max_ms": 16466.143131256104, + "all_runs_ms": [ + 16453.79376411438, + 16430.964708328247, + 16337.72897720337, + 16466.143131256104, + 16312.472820281982 + ], + "rows_returned": [ + 13, + 13, + 13, + 13, + 13 + ] +} +Iteration 1: 16365.31 ms (rows returned: 13) +Iteration 2: 16315.78 ms (rows returned: 13) +Iteration 3: 16266.77 ms (rows returned: 13) +Iteration 4: 16450.32 ms (rows returned: 13) +Iteration 5: 16380.79 ms (rows returned: 13) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": null, + "query": "SELECT passenger_count, SUM(fare_amount) FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" GROUP BY passenger_count", + "avg_ms": 16355.792951583862, + "min_ms": 16266.765356063843, + "max_ms": 16450.316429138184, + "all_runs_ms": [ + 16365.306615829468, + 16315.782070159912, + 16266.765356063843, + 16450.316429138184, + 16380.794286727905 + ], + "rows_returned": [ + 13, + 13, + 13, + 13, + 13 + ] +} + +Running GreaterThan... +Iteration 1: 16431.14 ms (rows returned: 500) +Iteration 2: 16581.48 ms (rows returned: 500) +Iteration 3: 17463.49 ms (rows returned: 500) +Iteration 4: 16401.79 ms (rows returned: 500) +Iteration 5: 16410.06 ms (rows returned: 500) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": null, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" WHERE fare_amount > 50 LIMIT 500", + "avg_ms": 16657.591247558594, + "min_ms": 16401.78942680359, + "max_ms": 17463.488817214966, + "all_runs_ms": [ + 16431.142330169678, + 16581.480026245117, + 17463.488817214966, + 16401.78942680359, + 16410.05563735962 + ], + "rows_returned": [ + 500, + 500, + 500, + 500, + 500 + ] +} +Iteration 1: 16492.02 ms (rows returned: 500) +Iteration 2: 16609.02 ms (rows returned: 500) +Iteration 3: 16921.74 ms (rows returned: 500) +Iteration 4: 19105.62 ms (rows returned: 500) +Iteration 5: 17268.11 ms (rows returned: 500) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": null, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" WHERE fare_amount > 50 LIMIT 500", + "avg_ms": 17279.300785064697, + "min_ms": 16492.016792297363, + "max_ms": 19105.62229156494, + "all_runs_ms": [ + 16492.016792297363, + 16609.01641845703, + 16921.739101409912, + 19105.62229156494, + 17268.10932159424 + ], + "rows_returned": [ + 500, + 500, + 500, + 500, + 500 + ] +} + +Running OrderBy... +Iteration 1: 53609.11 ms (rows returned: 100) +Iteration 2: 51697.37 ms (rows returned: 100) +Iteration 3: 52442.85 ms (rows returned: 100) +Iteration 4: 51915.98 ms (rows returned: 100) +Iteration 5: 52732.26 ms (rows returned: 100) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": null, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" ORDER BY fare_amount DESC LIMIT 100", + "avg_ms": 52479.51292991638, + "min_ms": 51697.365045547485, + "max_ms": 53609.11321640015, + "all_runs_ms": [ + 53609.11321640015, + 51697.365045547485, + 52442.848205566406, + 51915.97771644592, + 52732.26046562195 + ], + "rows_returned": [ + 100, + 100, + 100, + 100, + 100 + ] +} +Iteration 1: 51679.21 ms (rows returned: 100) +Iteration 2: 51655.93 ms (rows returned: 100) +Iteration 3: 52763.46 ms (rows returned: 100) +Iteration 4: 52017.11 ms (rows returned: 100) +Iteration 5: 51290.61 ms (rows returned: 100) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": null, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" ORDER BY fare_amount DESC LIMIT 100", + "avg_ms": 51881.261920928955, + "min_ms": 51290.605545043945, + "max_ms": 52763.455629348755, + "all_runs_ms": [ + 51679.210901260376, + 51655.92646598816, + 52763.455629348755, + 52017.11106300354, + 51290.605545043945 + ], + "rows_returned": [ + 100, + 100, + 100, + 100, + 100 + ] +} + +Running SingleCountGroupBy... +Iteration 1: 12628.46 ms (rows returned: 13) +Iteration 2: 11824.79 ms (rows returned: 13) +Iteration 3: 11797.74 ms (rows returned: 13) +Iteration 4: 11769.57 ms (rows returned: 13) +Iteration 5: 12707.33 ms (rows returned: 13) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": null, + "query": "SELECT passenger_count, COUNT(*) FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" GROUP BY passenger_count", + "avg_ms": 12145.577430725098, + "min_ms": 11769.570350646973, + "max_ms": 12707.332849502563, + "all_runs_ms": [ + 12628.455400466919, + 11824.785709381104, + 11797.74284362793, + 11769.570350646973, + 12707.332849502563 + ], + "rows_returned": [ + 13, + 13, + 13, + 13, + 13 + ] +} +Iteration 1: 11914.47 ms (rows returned: 13) +Iteration 2: 11833.41 ms (rows returned: 13) +Iteration 3: 11851.74 ms (rows returned: 13) +Iteration 4: 11873.59 ms (rows returned: 13) +Iteration 5: 11799.07 ms (rows returned: 13) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": null, + "query": "SELECT passenger_count, COUNT(*) FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" GROUP BY passenger_count", + "avg_ms": 11854.455614089966, + "min_ms": 11799.065351486206, + "max_ms": 11914.469003677368, + "all_runs_ms": [ + 11914.469003677368, + 11833.407640457153, + 11851.743936538696, + 11873.592138290405, + 11799.065351486206 + ], + "rows_returned": [ + 13, + 13, + 13, + 13, + 13 + ] +} + +Running MultiCountGroupBy... +Iteration 1: 291068.74 ms (rows returned: 14545) +Iteration 2: 276341.55 ms (rows returned: 14545) +Iteration 3: 297687.92 ms (rows returned: 14545) +Iteration 4: 357626.56 ms (rows returned: 14545) +Iteration 5: 333281.98 ms (rows returned: 14545) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": null, + "query": "SELECT passenger_count, fare_amount, COUNT(*) FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" GROUP BY passenger_count, fare_amount", + "avg_ms": 311201.34949684143, + "min_ms": 276341.55106544495, + "max_ms": 357626.5640258789, + "all_runs_ms": [ + 291068.74084472656, + 276341.55106544495, + 297687.9153251648, + 357626.5640258789, + 333281.97622299194 + ], + "rows_returned": [ + 14545, + 14545, + 14545, + 14545, + 14545 + ] +} +Iteration 1: 308960.54 ms (rows returned: 14545) +Iteration 2: 291849.08 ms (rows returned: 14545) +Iteration 3: 289264.62 ms (rows returned: 14545) +Iteration 4: 287014.85 ms (rows returned: 14545) +Iteration 5: 284922.38 ms (rows returned: 14545) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": null, + "query": "SELECT passenger_count, fare_amount, COUNT(*) FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" GROUP BY passenger_count, fare_amount", + "avg_ms": 292402.29506492615, + "min_ms": 284922.3837852478, + "max_ms": 308960.5448246002, + "all_runs_ms": [ + 308960.5448246002, + 291849.0788936615, + 289264.61839675903, + 287014.8494243622, + 284922.3837852478 + ], + "rows_returned": [ + 14545, + 14545, + 14545, + 14545, + 14545 + ] +} + +Results written to compare_results.csv +Plot saved to compare_plot.png \ No newline at end of file diff --git a/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/001CompareQueries/queries_plot.png b/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/001CompareQueries/queries_plot.png new file mode 100644 index 00000000000..888488f98f1 Binary files /dev/null and b/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/001CompareQueries/queries_plot.png differ diff --git a/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/001CompareQueries/queries_results.csv b/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/001CompareQueries/queries_results.csv new file mode 100644 index 00000000000..1c8855d4216 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/001CompareQueries/queries_results.csv @@ -0,0 +1,8 @@ +query,driver_a,a_avg,a_min,a_max,driver_b,b_avg,b_min,b_max +Limit,Arrow Flight SQL ODBC Driver,33883.24728012085,33567.84391403198,34765.1641368866,Apache Arrow Flight SQL ODBC Driver,34627.785539627075,31752.08568572998,36034.88802909851 +AvgGroupBy,Arrow Flight SQL ODBC Driver,16836.19828224182,16672.237634658813,17156.906366348267,Apache Arrow Flight SQL ODBC Driver,16799.520874023438,16666.994333267212,17078.22322845459 +SumGroupBy,Arrow Flight SQL ODBC Driver,16400.220680236816,16312.472820281982,16466.143131256104,Apache Arrow Flight SQL ODBC Driver,16355.792951583862,16266.765356063843,16450.316429138184 +GreaterThan,Arrow Flight SQL ODBC Driver,16657.591247558594,16401.78942680359,17463.488817214966,Apache Arrow Flight SQL ODBC Driver,17279.300785064697,16492.016792297363,19105.62229156494 +OrderBy,Arrow Flight SQL ODBC Driver,52479.51292991638,51697.365045547485,53609.11321640015,Apache Arrow Flight SQL ODBC Driver,51881.261920928955,51290.605545043945,52763.455629348755 +SingleCountGroupBy,Arrow Flight SQL ODBC Driver,12145.577430725098,11769.570350646973,12707.332849502563,Apache Arrow Flight SQL ODBC Driver,11854.455614089966,11799.065351486206,11914.469003677368 +MultiCountGroupBy,Arrow Flight SQL ODBC Driver,311201.34949684143,276341.55106544495,357626.5640258789,Apache Arrow Flight SQL ODBC Driver,292402.29506492615,284922.3837852478,308960.5448246002 diff --git a/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/002CompareLimits100K/CompareLimitsOutput.txt b/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/002CompareLimits100K/CompareLimitsOutput.txt new file mode 100644 index 00000000000..0197052a818 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/002CompareLimits100K/CompareLimitsOutput.txt @@ -0,0 +1,315 @@ +PS C:\Users\Administrator\GitHub\arrow\cpp\src\arrow\flight\sql\odbc\performance_tests> python compare_limits_plot.py --driver_a "Arrow Flight SQL ODBC Driver" --driver_b "Apache Arrow Flight SQL ODBC Driver" --iterations 5 --limits 1 10 100 1000 10000 100000 --outfile limits_results_to_100K.csv --plotfile limits_plot_to_100K.png + +Running Arrow Flight SQL ODBC Driver with LIMIT=1... +Intermediate JSON for Arrow Flight SQL ODBC Driver LIMIT=1: +{ + "driver": "Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": 1, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" LIMIT 1", + "avg_ms": 363.750696182251, + "min_ms": 322.5698471069336, + "max_ms": 447.0713138580322, + "all_runs_ms": [ + 447.0713138580322, + 390.6853199005127, + 333.9381217956543, + 324.48887825012207, + 322.5698471069336 + ], + "rows_returned": [ + 1, + 1, + 1, + 1, + 1 + ] +} + +Running Arrow Flight SQL ODBC Driver with LIMIT=10... +Intermediate JSON for Arrow Flight SQL ODBC Driver LIMIT=10: +{ + "driver": "Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": 10, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" LIMIT 10", + "avg_ms": 611.6148471832275, + "min_ms": 562.8888607025146, + "max_ms": 669.5723533630371, + "all_runs_ms": [ + 562.8888607025146, + 628.2343864440918, + 669.5723533630371, + 595.0183868408203, + 602.3602485656738 + ], + "rows_returned": [ + 10, + 10, + 10, + 10, + 10 + ] +} + +Running Arrow Flight SQL ODBC Driver with LIMIT=100... +Intermediate JSON for Arrow Flight SQL ODBC Driver LIMIT=100: +{ + "driver": "Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": 100, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" LIMIT 100", + "avg_ms": 3315.7440662384033, + "min_ms": 3253.4024715423584, + "max_ms": 3411.7939472198486, + "all_runs_ms": [ + 3253.4024715423584, + 3411.7939472198486, + 3285.139322280884, + 3296.3390350341797, + 3332.045555114746 + ], + "rows_returned": [ + 100, + 100, + 100, + 100, + 100 + ] +} + +Running Arrow Flight SQL ODBC Driver with LIMIT=1000... +Intermediate JSON for Arrow Flight SQL ODBC Driver LIMIT=1000: +{ + "driver": "Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": 1000, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" LIMIT 1000", + "avg_ms": 30438.987493515015, + "min_ms": 29623.3491897583, + "max_ms": 31157.08327293396, + "all_runs_ms": [ + 31029.332160949707, + 29623.3491897583, + 31157.08327293396, + 29719.9490070343, + 30665.223836898804 + ], + "rows_returned": [ + 1000, + 1000, + 1000, + 1000, + 1000 + ] +} + +Running Arrow Flight SQL ODBC Driver with LIMIT=10000... +Intermediate JSON for Arrow Flight SQL ODBC Driver LIMIT=10000: +{ + "driver": "Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": 10000, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" LIMIT 10000", + "avg_ms": 300643.31245422363, + "min_ms": 290261.10315322876, + "max_ms": 306783.6825847626, + "all_runs_ms": [ + 302672.1565723419, + 300335.0327014923, + 303164.5872592926, + 290261.10315322876, + 306783.6825847626 + ], + "rows_returned": [ + 10000, + 10000, + 10000, + 10000, + 10000 + ] +} + +Running Arrow Flight SQL ODBC Driver with LIMIT=100000... +Intermediate JSON for Arrow Flight SQL ODBC Driver LIMIT=100000: +{ + "driver": "Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": 100000, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" LIMIT 100000", + "avg_ms": 3036322.395181656, + "min_ms": 2996272.8984355927, + "max_ms": 3068238.5454177856, + "all_runs_ms": [ + 3038236.1080646515, + 3056891.6964530945, + 3068238.5454177856, + 3021972.727537155, + 2996272.8984355927 + ], + "rows_returned": [ + 100000, + 100000, + 100000, + 100000, + 100000 + ] +} + +Running Apache Arrow Flight SQL ODBC Driver with LIMIT=1... +Intermediate JSON for Apache Arrow Flight SQL ODBC Driver LIMIT=1: +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": 1, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" LIMIT 1", + "avg_ms": 361.22307777404785, + "min_ms": 329.59818840026855, + "max_ms": 457.34381675720215, + "all_runs_ms": [ + 457.34381675720215, + 341.2652015686035, + 340.3174877166748, + 337.59069442749023, + 329.59818840026855 + ], + "rows_returned": [ + 1, + 1, + 1, + 1, + 1 + ] +} + +Running Apache Arrow Flight SQL ODBC Driver with LIMIT=10... +Intermediate JSON for Apache Arrow Flight SQL ODBC Driver LIMIT=10: +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": 10, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" LIMIT 10", + "avg_ms": 609.0284824371338, + "min_ms": 596.207857131958, + "max_ms": 619.4398403167725, + "all_runs_ms": [ + 619.4398403167725, + 609.3175411224365, + 596.207857131958, + 615.5171394348145, + 604.6600341796875 + ], + "rows_returned": [ + 10, + 10, + 10, + 10, + 10 + ] +} + +Running Apache Arrow Flight SQL ODBC Driver with LIMIT=100... +Intermediate JSON for Apache Arrow Flight SQL ODBC Driver LIMIT=100: +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": 100, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" LIMIT 100", + "avg_ms": 3322.842216491699, + "min_ms": 3236.36794090271, + "max_ms": 3433.699607849121, + "all_runs_ms": [ + 3396.2011337280273, + 3236.36794090271, + 3247.5204467773438, + 3300.421953201294, + 3433.699607849121 + ], + "rows_returned": [ + 100, + 100, + 100, + 100, + 100 + ] +} + +Running Apache Arrow Flight SQL ODBC Driver with LIMIT=1000... +Intermediate JSON for Apache Arrow Flight SQL ODBC Driver LIMIT=1000: +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": 1000, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" LIMIT 1000", + "avg_ms": 30824.00050163269, + "min_ms": 29915.04693031311, + "max_ms": 32180.66430091858, + "all_runs_ms": [ + 32180.66430091858, + 29915.04693031311, + 30911.401987075806, + 29925.776720046997, + 31187.11256980896 + ], + "rows_returned": [ + 1000, + 1000, + 1000, + 1000, + 1000 + ] +} + +Running Apache Arrow Flight SQL ODBC Driver with LIMIT=10000... +Intermediate JSON for Apache Arrow Flight SQL ODBC Driver LIMIT=10000: +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": 10000, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" LIMIT 10000", + "avg_ms": 303659.582901001, + "min_ms": 300044.9962615967, + "max_ms": 309345.0765609741, + "all_runs_ms": [ + 309345.0765609741, + 304755.68890571594, + 300044.9962615967, + 301692.6624774933, + 302459.49029922485 + ], + "rows_returned": [ + 10000, + 10000, + 10000, + 10000, + 10000 + ] +} + +Running Apache Arrow Flight SQL ODBC Driver with LIMIT=100000... +Intermediate JSON for Apache Arrow Flight SQL ODBC Driver LIMIT=100000: +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "iterations": 5, + "limit": 100000, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" LIMIT 100000", + "avg_ms": 3027773.0962753296, + "min_ms": 3013334.800004959, + "max_ms": 3070387.4740600586, + "all_runs_ms": [ + 3013334.800004959, + 3013481.8575382233, + 3024003.721475601, + 3070387.4740600586, + 3017657.628297806 + ], + "rows_returned": [ + 100000, + 100000, + 100000, + 100000, + 100000 + ] +} +Results saved to limits_results_to_100K.csv +Plot saved to limits_plot_to_100K.png \ No newline at end of file diff --git a/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/002CompareLimits100K/limits_plot_to_100K.png b/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/002CompareLimits100K/limits_plot_to_100K.png new file mode 100644 index 00000000000..ed4bfb46a5e Binary files /dev/null and b/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/002CompareLimits100K/limits_plot_to_100K.png differ diff --git a/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/002CompareLimits100K/limits_results_to_100K.csv b/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/002CompareLimits100K/limits_results_to_100K.csv new file mode 100644 index 00000000000..c560eadf11e --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/002CompareLimits100K/limits_results_to_100K.csv @@ -0,0 +1,13 @@ +Driver,Limit,Avg_ms,Min_ms,Max_ms +Arrow Flight SQL ODBC Driver,1,363.750696182251,322.5698471069336,447.0713138580322 +Arrow Flight SQL ODBC Driver,10,611.6148471832275,562.8888607025146,669.5723533630371 +Arrow Flight SQL ODBC Driver,100,3315.7440662384033,3253.4024715423584,3411.7939472198486 +Arrow Flight SQL ODBC Driver,1000,30438.987493515015,29623.3491897583,31157.08327293396 +Arrow Flight SQL ODBC Driver,10000,300643.31245422363,290261.10315322876,306783.6825847626 +Arrow Flight SQL ODBC Driver,100000,3036322.395181656,2996272.8984355927,3068238.5454177856 +Apache Arrow Flight SQL ODBC Driver,1,361.22307777404785,329.59818840026855,457.34381675720215 +Apache Arrow Flight SQL ODBC Driver,10,609.0284824371338,596.207857131958,619.4398403167725 +Apache Arrow Flight SQL ODBC Driver,100,3322.842216491699,3236.36794090271,3433.699607849121 +Apache Arrow Flight SQL ODBC Driver,1000,30824.00050163269,29915.04693031311,32180.66430091858 +Apache Arrow Flight SQL ODBC Driver,10000,303659.582901001,300044.9962615967,309345.0765609741 +Apache Arrow Flight SQL ODBC Driver,100000,3027773.0962753296,3013334.800004959,3070387.4740600586 diff --git a/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/003CompareQueries20/CompareQueriesOutput.txt b/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/003CompareQueries20/CompareQueriesOutput.txt new file mode 100644 index 00000000000..ab8733d62a7 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/003CompareQueries20/CompareQueriesOutput.txt @@ -0,0 +1,1205 @@ +PS C:\Users\Administrator\GitHub\arrow\cpp\src\arrow\flight\sql\odbc\performance_tests> python compare_queries_bar.py --driver_a "Arrow Flight SQL ODBC Driver" --driver_b "Apache Arrow Flight SQL ODBC Driver" --schema "Samples.samples.dremio.com" --table "NYC-taxi-trips-iceberg" --iterations 5 --outfile compare_20_queries.csv --plotfile results_20_queries.png + +Running Limit100... +Iteration 1: 3258.13 ms (rows returned: 100) +Iteration 2: 3221.18 ms (rows returned: 100) +Iteration 3: 3090.82 ms (rows returned: 100) +Iteration 4: 3084.55 ms (rows returned: 100) +Iteration 5: 3089.00 ms (rows returned: 100) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" LIMIT 100", + "avg_ms": 3148.7356185913086, + "min_ms": 3084.550380706787, + "max_ms": 3258.1324577331543, + "all_runs_ms": [ + 3258.1324577331543, + 3221.1766242980957, + 3090.8217430114746, + 3084.550380706787, + 3088.9968872070312 + ], + "rows_returned": [ + 100, + 100, + 100, + 100, + 100 + ] +} +Iteration 1: 3078.05 ms (rows returned: 100) +Iteration 2: 3135.37 ms (rows returned: 100) +Iteration 3: 3114.44 ms (rows returned: 100) +Iteration 4: 3110.02 ms (rows returned: 100) +Iteration 5: 3073.05 ms (rows returned: 100) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" LIMIT 100", + "avg_ms": 3102.1849632263184, + "min_ms": 3073.0504989624023, + "max_ms": 3135.3671550750732, + "all_runs_ms": [ + 3078.051805496216, + 3135.3671550750732, + 3114.4394874572754, + 3110.015869140625, + 3073.0504989624023 + ], + "rows_returned": [ + 100, + 100, + 100, + 100, + 100 + ] +} + +Running Limit1000... +Iteration 1: 28295.13 ms (rows returned: 1000) +Iteration 2: 28336.46 ms (rows returned: 1000) +Iteration 3: 28279.17 ms (rows returned: 1000) +Iteration 4: 28276.08 ms (rows returned: 1000) +Iteration 5: 28206.97 ms (rows returned: 1000) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" LIMIT 1000", + "avg_ms": 28278.762102127075, + "min_ms": 28206.9673538208, + "max_ms": 28336.461544036865, + "all_runs_ms": [ + 28295.12882232666, + 28336.461544036865, + 28279.168844223022, + 28276.083946228027, + 28206.9673538208 + ], + "rows_returned": [ + 1000, + 1000, + 1000, + 1000, + 1000 + ] +} +Iteration 1: 28357.74 ms (rows returned: 1000) +Iteration 2: 28239.28 ms (rows returned: 1000) +Iteration 3: 28254.01 ms (rows returned: 1000) +Iteration 4: 28299.25 ms (rows returned: 1000) +Iteration 5: 28472.38 ms (rows returned: 1000) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" LIMIT 1000", + "avg_ms": 28324.5304107666, + "min_ms": 28239.282846450806, + "max_ms": 28472.376108169556, + "all_runs_ms": [ + 28357.738256454468, + 28239.282846450806, + 28254.00972366333, + 28299.24511909485, + 28472.376108169556 + ], + "rows_returned": [ + 1000, + 1000, + 1000, + 1000, + 1000 + ] +} + +Running AvgFareByPassenger... +Iteration 1: 18290.48 ms (rows returned: 13) +Iteration 2: 17481.93 ms (rows returned: 13) +Iteration 3: 17523.29 ms (rows returned: 13) +Iteration 4: 17411.47 ms (rows returned: 13) +Iteration 5: 17423.48 ms (rows returned: 13) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT passenger_count, AVG(fare_amount) AS avg_fare FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" GROUP BY passenger_count", + "avg_ms": 17626.128387451172, + "min_ms": 17411.471366882324, + "max_ms": 18290.47703742981, + "all_runs_ms": [ + 18290.47703742981, + 17481.927156448364, + 17523.289442062378, + 17411.471366882324, + 17423.476934432983 + ], + "rows_returned": [ + 13, + 13, + 13, + 13, + 13 + ] +} +Iteration 1: 17427.29 ms (rows returned: 13) +Iteration 2: 17412.03 ms (rows returned: 13) +Iteration 3: 17631.00 ms (rows returned: 13) +Iteration 4: 17411.21 ms (rows returned: 13) +Iteration 5: 17423.88 ms (rows returned: 13) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT passenger_count, AVG(fare_amount) AS avg_fare FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" GROUP BY passenger_count", + "avg_ms": 17461.08193397522, + "min_ms": 17411.21244430542, + "max_ms": 17630.999088287354, + "all_runs_ms": [ + 17427.2882938385, + 17412.02735900879, + 17630.999088287354, + 17411.21244430542, + 17423.882484436035 + ], + "rows_returned": [ + 13, + 13, + 13, + 13, + 13 + ] +} + +Running SumFareByPassenger... +Iteration 1: 17029.83 ms (rows returned: 13) +Iteration 2: 17219.08 ms (rows returned: 13) +Iteration 3: 17086.83 ms (rows returned: 13) +Iteration 4: 17048.02 ms (rows returned: 13) +Iteration 5: 17191.18 ms (rows returned: 13) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT passenger_count, SUM(fare_amount) AS total_fare FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" GROUP BY passenger_count", + "avg_ms": 17114.98770713806, + "min_ms": 17029.832124710083, + "max_ms": 17219.08450126648, + "all_runs_ms": [ + 17029.832124710083, + 17219.08450126648, + 17086.82870864868, + 17048.015594482422, + 17191.17760658264 + ], + "rows_returned": [ + 13, + 13, + 13, + 13, + 13 + ] +} +Iteration 1: 16788.63 ms (rows returned: 13) +Iteration 2: 16720.72 ms (rows returned: 13) +Iteration 3: 16698.86 ms (rows returned: 13) +Iteration 4: 16762.84 ms (rows returned: 13) +Iteration 5: 16722.51 ms (rows returned: 13) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT passenger_count, SUM(fare_amount) AS total_fare FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" GROUP BY passenger_count", + "avg_ms": 16738.711404800415, + "min_ms": 16698.858499526978, + "max_ms": 16788.630723953247, + "all_runs_ms": [ + 16788.630723953247, + 16720.719814300537, + 16698.858499526978, + 16762.84146308899, + 16722.506523132324 + ], + "rows_returned": [ + 13, + 13, + 13, + 13, + 13 + ] +} + +Running AvgTipByPassenger... +Iteration 1: 17996.39 ms (rows returned: 13) +Iteration 2: 17289.82 ms (rows returned: 13) +Iteration 3: 17209.62 ms (rows returned: 13) +Iteration 4: 17048.29 ms (rows returned: 13) +Iteration 5: 17133.30 ms (rows returned: 13) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT passenger_count, AVG(tip_amount) AS avg_tip FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" GROUP BY passenger_count", + "avg_ms": 17335.482931137085, + "min_ms": 17048.28643798828, + "max_ms": 17996.389627456665, + "all_runs_ms": [ + 17996.389627456665, + 17289.820194244385, + 17209.61618423462, + 17048.28643798828, + 17133.302211761475 + ], + "rows_returned": [ + 13, + 13, + 13, + 13, + 13 + ] +} +Iteration 1: 17100.37 ms (rows returned: 13) +Iteration 2: 17144.93 ms (rows returned: 13) +Iteration 3: 17263.51 ms (rows returned: 13) +Iteration 4: 17126.49 ms (rows returned: 13) +Iteration 5: 17094.25 ms (rows returned: 13) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT passenger_count, AVG(tip_amount) AS avg_tip FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" GROUP BY passenger_count", + "avg_ms": 17145.911836624146, + "min_ms": 17094.25401687622, + "max_ms": 17263.51237297058, + "all_runs_ms": [ + 17100.367307662964, + 17144.932985305786, + 17263.51237297058, + 17126.492500305176, + 17094.25401687622 + ], + "rows_returned": [ + 13, + 13, + 13, + 13, + 13 + ] +} + +Running TotalAmountByPassenger... +Iteration 1: 18626.53 ms (rows returned: 13) +Iteration 2: 17192.35 ms (rows returned: 13) +Iteration 3: 17178.07 ms (rows returned: 13) +Iteration 4: 17203.18 ms (rows returned: 13) +Iteration 5: 17205.09 ms (rows returned: 13) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT passenger_count, SUM(total_amount) AS total_amount FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" GROUP BY passenger_count", + "avg_ms": 17481.04419708252, + "min_ms": 17178.06601524353, + "max_ms": 18626.532077789307, + "all_runs_ms": [ + 18626.532077789307, + 17192.352533340454, + 17178.06601524353, + 17203.18055152893, + 17205.089807510376 + ], + "rows_returned": [ + 13, + 13, + 13, + 13, + 13 + ] +} +Iteration 1: 17208.52 ms (rows returned: 13) +Iteration 2: 17161.80 ms (rows returned: 13) +Iteration 3: 17121.03 ms (rows returned: 13) +Iteration 4: 17118.08 ms (rows returned: 13) +Iteration 5: 17151.74 ms (rows returned: 13) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT passenger_count, SUM(total_amount) AS total_amount FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" GROUP BY passenger_count", + "avg_ms": 17152.232027053833, + "min_ms": 17118.079662322998, + "max_ms": 17208.51707458496, + "all_runs_ms": [ + 17208.51707458496, + 17161.800146102905, + 17121.02508544922, + 17118.079662322998, + 17151.738166809082 + ], + "rows_returned": [ + 13, + 13, + 13, + 13, + 13 + ] +} + +Running FareGreater50... +Iteration 1: 14541.40 ms (rows returned: 500) +Iteration 2: 14376.75 ms (rows returned: 500) +Iteration 3: 14375.08 ms (rows returned: 500) +Iteration 4: 14287.94 ms (rows returned: 500) +Iteration 5: 14298.50 ms (rows returned: 500) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" WHERE fare_amount > 50 LIMIT 500", + "avg_ms": 14375.934505462646, + "min_ms": 14287.942886352539, + "max_ms": 14541.397094726562, + "all_runs_ms": [ + 14541.397094726562, + 14376.750469207764, + 14375.08249282837, + 14287.942886352539, + 14298.499584197998 + ], + "rows_returned": [ + 500, + 500, + 500, + 500, + 500 + ] +} +Iteration 1: 14336.94 ms (rows returned: 500) +Iteration 2: 14282.82 ms (rows returned: 500) +Iteration 3: 14338.50 ms (rows returned: 500) +Iteration 4: 14307.37 ms (rows returned: 500) +Iteration 5: 14303.22 ms (rows returned: 500) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" WHERE fare_amount > 50 LIMIT 500", + "avg_ms": 14313.771390914917, + "min_ms": 14282.81831741333, + "max_ms": 14338.501691818237, + "all_runs_ms": [ + 14336.944103240967, + 14282.81831741333, + 14338.501691818237, + 14307.369232177734, + 14303.223609924316 + ], + "rows_returned": [ + 500, + 500, + 500, + 500, + 500 + ] +} + +Running TripDistance5to10... +Iteration 1: 14320.81 ms (rows returned: 500) +Iteration 2: 14312.74 ms (rows returned: 500) +Iteration 3: 14323.37 ms (rows returned: 500) +Iteration 4: 14342.22 ms (rows returned: 500) +Iteration 5: 14334.44 ms (rows returned: 500) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" WHERE trip_distance_mi BETWEEN 5 AND 10 LIMIT 500", + "avg_ms": 14326.717376708984, + "min_ms": 14312.744140625, + "max_ms": 14342.220067977905, + "all_runs_ms": [ + 14320.80602645874, + 14312.744140625, + 14323.374509811401, + 14342.220067977905, + 14334.442138671875 + ], + "rows_returned": [ + 500, + 500, + 500, + 500, + 500 + ] +} +Iteration 1: 14335.08 ms (rows returned: 500) +Iteration 2: 14351.85 ms (rows returned: 500) +Iteration 3: 14409.90 ms (rows returned: 500) +Iteration 4: 14369.25 ms (rows returned: 500) +Iteration 5: 14333.16 ms (rows returned: 500) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" WHERE trip_distance_mi BETWEEN 5 AND 10 LIMIT 500", + "avg_ms": 14359.848976135254, + "min_ms": 14333.15634727478, + "max_ms": 14409.904479980469, + "all_runs_ms": [ + 14335.081577301025, + 14351.84907913208, + 14409.904479980469, + 14369.253396987915, + 14333.15634727478 + ], + "rows_returned": [ + 500, + 500, + 500, + 500, + 500 + ] +} + +Running PassengerCount2... +Iteration 1: 14400.87 ms (rows returned: 500) +Iteration 2: 14330.71 ms (rows returned: 500) +Iteration 3: 14326.70 ms (rows returned: 500) +Iteration 4: 14347.51 ms (rows returned: 500) +Iteration 5: 14340.93 ms (rows returned: 500) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" WHERE passenger_count = 2 LIMIT 500", + "avg_ms": 14349.34492111206, + "min_ms": 14326.699256896973, + "max_ms": 14400.872230529785, + "all_runs_ms": [ + 14400.872230529785, + 14330.707311630249, + 14326.699256896973, + 14347.514152526855, + 14340.93165397644 + ], + "rows_returned": [ + 500, + 500, + 500, + 500, + 500 + ] +} +Iteration 1: 14389.24 ms (rows returned: 500) +Iteration 2: 14426.34 ms (rows returned: 500) +Iteration 3: 14356.14 ms (rows returned: 500) +Iteration 4: 14306.25 ms (rows returned: 500) +Iteration 5: 14299.65 ms (rows returned: 500) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" WHERE passenger_count = 2 LIMIT 500", + "avg_ms": 14355.523681640625, + "min_ms": 14299.64542388916, + "max_ms": 14426.344871520996, + "all_runs_ms": [ + 14389.240503311157, + 14426.344871520996, + 14356.136083602905, + 14306.251525878906, + 14299.64542388916 + ], + "rows_returned": [ + 500, + 500, + 500, + 500, + 500 + ] +} + +Running TipGreater10... +Iteration 1: 14351.80 ms (rows returned: 500) +Iteration 2: 14319.90 ms (rows returned: 500) +Iteration 3: 14294.81 ms (rows returned: 500) +Iteration 4: 14296.91 ms (rows returned: 500) +Iteration 5: 14397.65 ms (rows returned: 500) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" WHERE tip_amount > 10 LIMIT 500", + "avg_ms": 14332.215118408203, + "min_ms": 14294.811964035034, + "max_ms": 14397.649049758911, + "all_runs_ms": [ + 14351.803302764893, + 14319.896459579468, + 14294.811964035034, + 14296.91481590271, + 14397.649049758911 + ], + "rows_returned": [ + 500, + 500, + 500, + 500, + 500 + ] +} +Iteration 1: 14357.85 ms (rows returned: 500) +Iteration 2: 14396.18 ms (rows returned: 500) +Iteration 3: 14342.35 ms (rows returned: 500) +Iteration 4: 14400.37 ms (rows returned: 500) +Iteration 5: 14313.89 ms (rows returned: 500) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" WHERE tip_amount > 10 LIMIT 500", + "avg_ms": 14362.126731872559, + "min_ms": 14313.886642456055, + "max_ms": 14400.365591049194, + "all_runs_ms": [ + 14357.853174209595, + 14396.17919921875, + 14342.3490524292, + 14400.365591049194, + 14313.886642456055 + ], + "rows_returned": [ + 500, + 500, + 500, + 500, + 500 + ] +} + +Running OrderByFareDesc... +Iteration 1: 54312.90 ms (rows returned: 100) +Iteration 2: 53397.82 ms (rows returned: 100) +Iteration 3: 53488.77 ms (rows returned: 100) +Iteration 4: 53777.30 ms (rows returned: 100) +Iteration 5: 54228.48 ms (rows returned: 100) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" ORDER BY fare_amount DESC LIMIT 100", + "avg_ms": 53841.05463027954, + "min_ms": 53397.82190322876, + "max_ms": 54312.901735305786, + "all_runs_ms": [ + 54312.901735305786, + 53397.82190322876, + 53488.770961761475, + 53777.30345726013, + 54228.47509384155 + ], + "rows_returned": [ + 100, + 100, + 100, + 100, + 100 + ] +} +Iteration 1: 54100.68 ms (rows returned: 100) +Iteration 2: 53510.49 ms (rows returned: 100) +Iteration 3: 53884.82 ms (rows returned: 100) +Iteration 4: 53539.88 ms (rows returned: 100) +Iteration 5: 54196.78 ms (rows returned: 100) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" ORDER BY fare_amount DESC LIMIT 100", + "avg_ms": 53846.53010368347, + "min_ms": 53510.48803329468, + "max_ms": 54196.778535842896, + "all_runs_ms": [ + 54100.68130493164, + 53510.48803329468, + 53884.81903076172, + 53539.883613586426, + 54196.778535842896 + ], + "rows_returned": [ + 100, + 100, + 100, + 100, + 100 + ] +} + +Running OrderByTripDistance... +Iteration 1: 54123.02 ms (rows returned: 100) +Iteration 2: 54615.67 ms (rows returned: 100) +Iteration 3: 54261.92 ms (rows returned: 100) +Iteration 4: 53612.34 ms (rows returned: 100) +Iteration 5: 54194.30 ms (rows returned: 100) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" ORDER BY trip_distance_mi DESC LIMIT 100", + "avg_ms": 54161.45000457764, + "min_ms": 53612.33949661255, + "max_ms": 54615.665674209595, + "all_runs_ms": [ + 54123.01993370056, + 54615.665674209595, + 54261.92402839661, + 53612.33949661255, + 54194.30088996887 + ], + "rows_returned": [ + 100, + 100, + 100, + 100, + 100 + ] +} +Iteration 1: 53885.21 ms (rows returned: 100) +Iteration 2: 53874.51 ms (rows returned: 100) +Iteration 3: 53506.34 ms (rows returned: 100) +Iteration 4: 53779.29 ms (rows returned: 100) +Iteration 5: 53635.18 ms (rows returned: 100) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT * FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" ORDER BY trip_distance_mi DESC LIMIT 100", + "avg_ms": 53736.1065864563, + "min_ms": 53506.33525848389, + "max_ms": 53885.21337509155, + "all_runs_ms": [ + 53885.21337509155, + 53874.51457977295, + 53506.33525848389, + 53779.29162979126, + 53635.178089141846 + ], + "rows_returned": [ + 100, + 100, + 100, + 100, + 100 + ] +} + +Running TotalRowCount... +Iteration 1: 5569.19 ms (rows returned: 1) +Iteration 2: 4449.97 ms (rows returned: 1) +Iteration 3: 4016.96 ms (rows returned: 1) +Iteration 4: 4170.72 ms (rows returned: 1) +Iteration 5: 3826.20 ms (rows returned: 1) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT COUNT(*) AS total_trips FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\"", + "avg_ms": 4406.608629226685, + "min_ms": 3826.195478439331, + "max_ms": 5569.1938400268555, + "all_runs_ms": [ + 5569.1938400268555, + 4449.97239112854, + 4016.963243484497, + 4170.718193054199, + 3826.195478439331 + ], + "rows_returned": [ + 1, + 1, + 1, + 1, + 1 + ] +} +Iteration 1: 3837.28 ms (rows returned: 1) +Iteration 2: 3836.22 ms (rows returned: 1) +Iteration 3: 3894.01 ms (rows returned: 1) +Iteration 4: 3891.32 ms (rows returned: 1) +Iteration 5: 3926.83 ms (rows returned: 1) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT COUNT(*) AS total_trips FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\"", + "avg_ms": 3877.1294116973877, + "min_ms": 3836.2209796905518, + "max_ms": 3926.828145980835, + "all_runs_ms": [ + 3837.275981903076, + 3836.2209796905518, + 3894.005537033081, + 3891.3164138793945, + 3926.828145980835 + ], + "rows_returned": [ + 1, + 1, + 1, + 1, + 1 + ] +} + +Running MaxFare... +Iteration 1: 9391.77 ms (rows returned: 1) +Iteration 2: 9165.55 ms (rows returned: 1) +Iteration 3: 9127.03 ms (rows returned: 1) +Iteration 4: 9082.91 ms (rows returned: 1) +Iteration 5: 9111.72 ms (rows returned: 1) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT MAX(fare_amount) AS max_fare FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\"", + "avg_ms": 9175.795316696167, + "min_ms": 9082.907915115356, + "max_ms": 9391.76869392395, + "all_runs_ms": [ + 9391.76869392395, + 9165.551900863647, + 9127.027034759521, + 9082.907915115356, + 9111.72103881836 + ], + "rows_returned": [ + 1, + 1, + 1, + 1, + 1 + ] +} +Iteration 1: 9277.34 ms (rows returned: 1) +Iteration 2: 9252.20 ms (rows returned: 1) +Iteration 3: 9093.64 ms (rows returned: 1) +Iteration 4: 9075.61 ms (rows returned: 1) +Iteration 5: 9082.72 ms (rows returned: 1) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT MAX(fare_amount) AS max_fare FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\"", + "avg_ms": 9156.304168701172, + "min_ms": 9075.613260269165, + "max_ms": 9277.344703674316, + "all_runs_ms": [ + 9277.344703674316, + 9252.196550369263, + 9093.641996383667, + 9075.613260269165, + 9082.724332809448 + ], + "rows_returned": [ + 1, + 1, + 1, + 1, + 1 + ] +} + +Running MinFare... +Iteration 1: 9504.61 ms (rows returned: 1) +Iteration 2: 9677.66 ms (rows returned: 1) +Iteration 3: 9393.83 ms (rows returned: 1) +Iteration 4: 9294.61 ms (rows returned: 1) +Iteration 5: 9556.50 ms (rows returned: 1) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT MIN(fare_amount) AS min_fare FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\"", + "avg_ms": 9485.442781448364, + "min_ms": 9294.612646102905, + "max_ms": 9677.664518356323, + "all_runs_ms": [ + 9504.612684249878, + 9677.664518356323, + 9393.826484680176, + 9294.612646102905, + 9556.497573852539 + ], + "rows_returned": [ + 1, + 1, + 1, + 1, + 1 + ] +} +Iteration 1: 9221.09 ms (rows returned: 1) +Iteration 2: 9285.73 ms (rows returned: 1) +Iteration 3: 9213.77 ms (rows returned: 1) +Iteration 4: 9380.58 ms (rows returned: 1) +Iteration 5: 9288.95 ms (rows returned: 1) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT MIN(fare_amount) AS min_fare FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\"", + "avg_ms": 9278.022623062134, + "min_ms": 9213.771343231201, + "max_ms": 9380.575180053711, + "all_runs_ms": [ + 9221.086978912354, + 9285.732746124268, + 9213.771343231201, + 9380.575180053711, + 9288.946866989136 + ], + "rows_returned": [ + 1, + 1, + 1, + 1, + 1 + ] +} + +Running MaxTripDistance... +Iteration 1: 10542.72 ms (rows returned: 1) +Iteration 2: 10567.05 ms (rows returned: 1) +Iteration 3: 10554.66 ms (rows returned: 1) +Iteration 4: 10329.56 ms (rows returned: 1) +Iteration 5: 10375.75 ms (rows returned: 1) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT MAX(trip_distance_mi) AS max_distance FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\"", + "avg_ms": 10473.949575424194, + "min_ms": 10329.558372497559, + "max_ms": 10567.053079605103, + "all_runs_ms": [ + 10542.723655700684, + 10567.053079605103, + 10554.66103553772, + 10329.558372497559, + 10375.751733779907 + ], + "rows_returned": [ + 1, + 1, + 1, + 1, + 1 + ] +} +Iteration 1: 10339.24 ms (rows returned: 1) +Iteration 2: 10281.62 ms (rows returned: 1) +Iteration 3: 11202.08 ms (rows returned: 1) +Iteration 4: 10792.55 ms (rows returned: 1) +Iteration 5: 10545.28 ms (rows returned: 1) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT MAX(trip_distance_mi) AS max_distance FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\"", + "avg_ms": 10632.155895233154, + "min_ms": 10281.619787216187, + "max_ms": 11202.081203460693, + "all_runs_ms": [ + 10339.242935180664, + 10281.619787216187, + 11202.081203460693, + 10792.552709579468, + 10545.28284072876 + ], + "rows_returned": [ + 1, + 1, + 1, + 1, + 1 + ] +} + +Running TripsByYear... +Iteration 1: 24763.79 ms (rows returned: 2) +Iteration 2: 24189.83 ms (rows returned: 2) +Iteration 3: 24179.34 ms (rows returned: 2) +Iteration 4: 24288.34 ms (rows returned: 2) +Iteration 5: 24510.13 ms (rows returned: 2) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT EXTRACT(YEAR FROM pickup_datetime) \"year\", COUNT(*) \"trips\" FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" GROUP BY \"year\" ORDER BY \"year\"", + "avg_ms": 24386.285734176636, + "min_ms": 24179.344177246094, + "max_ms": 24763.78583908081, + "all_runs_ms": [ + 24763.78583908081, + 24189.826726913452, + 24179.344177246094, + 24288.337230682373, + 24510.13469696045 + ], + "rows_returned": [ + 2, + 2, + 2, + 2, + 2 + ] +} +Iteration 1: 24203.49 ms (rows returned: 2) +Iteration 2: 23960.76 ms (rows returned: 2) +Iteration 3: 24465.03 ms (rows returned: 2) +Iteration 4: 24695.98 ms (rows returned: 2) +Iteration 5: 24106.13 ms (rows returned: 2) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT EXTRACT(YEAR FROM pickup_datetime) \"year\", COUNT(*) \"trips\" FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" GROUP BY \"year\" ORDER BY \"year\"", + "avg_ms": 24286.277627944946, + "min_ms": 23960.764408111572, + "max_ms": 24695.979833602905, + "all_runs_ms": [ + 24203.487396240234, + 23960.764408111572, + 24465.02923965454, + 24695.979833602905, + 24106.12726211548 + ], + "rows_returned": [ + 2, + 2, + 2, + 2, + 2 + ] +} + +Running TripsByMonth... +Iteration 1: 24291.11 ms (rows returned: 12) +Iteration 2: 24134.81 ms (rows returned: 12) +Iteration 3: 24545.46 ms (rows returned: 12) +Iteration 4: 24278.44 ms (rows returned: 12) +Iteration 5: 24340.06 ms (rows returned: 12) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT EXTRACT(MONTH FROM pickup_datetime) \"month\", COUNT(*) \"trips\" FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" GROUP BY \"month\" ORDER BY \"month\"", + "avg_ms": 24317.972707748413, + "min_ms": 24134.80544090271, + "max_ms": 24545.45521736145, + "all_runs_ms": [ + 24291.110515594482, + 24134.80544090271, + 24545.45521736145, + 24278.4366607666, + 24340.05570411682 + ], + "rows_returned": [ + 12, + 12, + 12, + 12, + 12 + ] +} +Iteration 1: 24174.38 ms (rows returned: 12) +Iteration 2: 24303.72 ms (rows returned: 12) +Iteration 3: 24389.76 ms (rows returned: 12) +Iteration 4: 24270.72 ms (rows returned: 12) +Iteration 5: 24196.28 ms (rows returned: 12) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT EXTRACT(MONTH FROM pickup_datetime) \"month\", COUNT(*) \"trips\" FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" GROUP BY \"month\" ORDER BY \"month\"", + "avg_ms": 24266.971921920776, + "min_ms": 24174.379348754883, + "max_ms": 24389.7647857666, + "all_runs_ms": [ + 24174.379348754883, + 24303.71594429016, + 24389.7647857666, + 24270.715713500977, + 24196.28381729126 + ], + "rows_returned": [ + 12, + 12, + 12, + 12, + 12 + ] +} + +Running TopFaresRanked... +Iteration 1: 223123.71 ms (rows returned: 100) +Iteration 2: 225523.15 ms (rows returned: 100) +Iteration 3: 229951.83 ms (rows returned: 100) +Iteration 4: 222796.05 ms (rows returned: 100) +Iteration 5: 221088.02 ms (rows returned: 100) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT fare_amount, rank_col FROM (SELECT fare_amount, ROW_NUMBER() OVER (ORDER BY fare_amount DESC) AS rank_col FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\") t LIMIT 100", + "avg_ms": 224496.55227661133, + "min_ms": 221088.01984786987, + "max_ms": 229951.82847976685, + "all_runs_ms": [ + 223123.71277809143, + 225523.15068244934, + 229951.82847976685, + 222796.04959487915, + 221088.01984786987 + ], + "rows_returned": [ + 100, + 100, + 100, + 100, + 100 + ] +} +Iteration 1: 234645.46 ms (rows returned: 100) +Iteration 2: 224257.77 ms (rows returned: 100) +Iteration 3: 225608.64 ms (rows returned: 100) +Iteration 4: 221247.26 ms (rows returned: 100) +Iteration 5: 218252.25 ms (rows returned: 100) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT fare_amount, rank_col FROM (SELECT fare_amount, ROW_NUMBER() OVER (ORDER BY fare_amount DESC) AS rank_col FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\") t LIMIT 100", + "avg_ms": 224802.27527618408, + "min_ms": 218252.2463798523, + "max_ms": 234645.4563140869, + "all_runs_ms": [ + 234645.4563140869, + 224257.76839256287, + 225608.64353179932, + 221247.26176261902, + 218252.2463798523 + ], + "rows_returned": [ + 100, + 100, + 100, + 100, + 100 + ] +} + +Running AvgFareByPassengerWindow... +Iteration 1: 193259.05 ms (rows returned: 500) +Iteration 2: 192692.28 ms (rows returned: 500) +Iteration 3: 194490.74 ms (rows returned: 500) +Iteration 4: 191565.32 ms (rows returned: 500) +Iteration 5: 191919.91 ms (rows returned: 500) +{ + "driver": "Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT passenger_count, AVG(fare_amount) OVER (PARTITION BY passenger_count) AS avg_fare FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" LIMIT 500", + "avg_ms": 192785.4600906372, + "min_ms": 191565.31715393066, + "max_ms": 194490.74482917786, + "all_runs_ms": [ + 193259.05179977417, + 192692.2812461853, + 194490.74482917786, + 191565.31715393066, + 191919.90542411804 + ], + "rows_returned": [ + 500, + 500, + 500, + 500, + 500 + ] +} +Iteration 1: 193995.94 ms (rows returned: 500) +Iteration 2: 191825.51 ms (rows returned: 500) +Iteration 3: 192794.67 ms (rows returned: 500) +Iteration 4: 192621.31 ms (rows returned: 500) +Iteration 5: 192819.00 ms (rows returned: 500) +{ + "driver": "Apache Arrow Flight SQL ODBC Driver", + "schema": "Samples.samples.dremio.com", + "table": "NYC-taxi-trips-iceberg", + "iterations": 5, + "query": "SELECT passenger_count, AVG(fare_amount) OVER (PARTITION BY passenger_count) AS avg_fare FROM \"Samples.samples.dremio.com\".\"NYC-taxi-trips-iceberg\" LIMIT 500", + "avg_ms": 192811.2874031067, + "min_ms": 191825.510263443, + "max_ms": 193995.94283103943, + "all_runs_ms": [ + 193995.94283103943, + 191825.510263443, + 192794.67058181763, + 192621.31428718567, + 192818.99905204773 + ], + "rows_returned": [ + 500, + 500, + 500, + 500, + 500 + ] +} + +Results written to compare_20_queries.csv +Plot saved to results_20_queries.png +PS C:\Users\Administrator\GitHub\arrow\cpp\src\arrow\flight\sql\odbc\performance_tests> \ No newline at end of file diff --git a/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/003CompareQueries20/compare_20_queries.csv b/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/003CompareQueries20/compare_20_queries.csv new file mode 100644 index 00000000000..c642f84467c --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/003CompareQueries20/compare_20_queries.csv @@ -0,0 +1,21 @@ +query,driver_a,a_avg,a_min,a_max,driver_b,b_avg,b_min,b_max +Limit100,Arrow Flight SQL ODBC Driver,3148.7356185913086,3084.550380706787,3258.1324577331543,Apache Arrow Flight SQL ODBC Driver,3102.1849632263184,3073.0504989624023,3135.3671550750732 +Limit1000,Arrow Flight SQL ODBC Driver,28278.762102127075,28206.9673538208,28336.461544036865,Apache Arrow Flight SQL ODBC Driver,28324.5304107666,28239.282846450806,28472.376108169556 +AvgFareByPassenger,Arrow Flight SQL ODBC Driver,17626.128387451172,17411.471366882324,18290.47703742981,Apache Arrow Flight SQL ODBC Driver,17461.08193397522,17411.21244430542,17630.999088287354 +SumFareByPassenger,Arrow Flight SQL ODBC Driver,17114.98770713806,17029.832124710083,17219.08450126648,Apache Arrow Flight SQL ODBC Driver,16738.711404800415,16698.858499526978,16788.630723953247 +AvgTipByPassenger,Arrow Flight SQL ODBC Driver,17335.482931137085,17048.28643798828,17996.389627456665,Apache Arrow Flight SQL ODBC Driver,17145.911836624146,17094.25401687622,17263.51237297058 +TotalAmountByPassenger,Arrow Flight SQL ODBC Driver,17481.04419708252,17178.06601524353,18626.532077789307,Apache Arrow Flight SQL ODBC Driver,17152.232027053833,17118.079662322998,17208.51707458496 +FareGreater50,Arrow Flight SQL ODBC Driver,14375.934505462646,14287.942886352539,14541.397094726562,Apache Arrow Flight SQL ODBC Driver,14313.771390914917,14282.81831741333,14338.501691818237 +TripDistance5to10,Arrow Flight SQL ODBC Driver,14326.717376708984,14312.744140625,14342.220067977905,Apache Arrow Flight SQL ODBC Driver,14359.848976135254,14333.15634727478,14409.904479980469 +PassengerCount2,Arrow Flight SQL ODBC Driver,14349.34492111206,14326.699256896973,14400.872230529785,Apache Arrow Flight SQL ODBC Driver,14355.523681640625,14299.64542388916,14426.344871520996 +TipGreater10,Arrow Flight SQL ODBC Driver,14332.215118408203,14294.811964035034,14397.649049758911,Apache Arrow Flight SQL ODBC Driver,14362.126731872559,14313.886642456055,14400.365591049194 +OrderByFareDesc,Arrow Flight SQL ODBC Driver,53841.05463027954,53397.82190322876,54312.901735305786,Apache Arrow Flight SQL ODBC Driver,53846.53010368347,53510.48803329468,54196.778535842896 +OrderByTripDistance,Arrow Flight SQL ODBC Driver,54161.45000457764,53612.33949661255,54615.665674209595,Apache Arrow Flight SQL ODBC Driver,53736.1065864563,53506.33525848389,53885.21337509155 +TotalRowCount,Arrow Flight SQL ODBC Driver,4406.608629226685,3826.195478439331,5569.1938400268555,Apache Arrow Flight SQL ODBC Driver,3877.1294116973877,3836.2209796905518,3926.828145980835 +MaxFare,Arrow Flight SQL ODBC Driver,9175.795316696167,9082.907915115356,9391.76869392395,Apache Arrow Flight SQL ODBC Driver,9156.304168701172,9075.613260269165,9277.344703674316 +MinFare,Arrow Flight SQL ODBC Driver,9485.442781448364,9294.612646102905,9677.664518356323,Apache Arrow Flight SQL ODBC Driver,9278.022623062134,9213.771343231201,9380.575180053711 +MaxTripDistance,Arrow Flight SQL ODBC Driver,10473.949575424194,10329.558372497559,10567.053079605103,Apache Arrow Flight SQL ODBC Driver,10632.155895233154,10281.619787216187,11202.081203460693 +TripsByYear,Arrow Flight SQL ODBC Driver,24386.285734176636,24179.344177246094,24763.78583908081,Apache Arrow Flight SQL ODBC Driver,24286.277627944946,23960.764408111572,24695.979833602905 +TripsByMonth,Arrow Flight SQL ODBC Driver,24317.972707748413,24134.80544090271,24545.45521736145,Apache Arrow Flight SQL ODBC Driver,24266.971921920776,24174.379348754883,24389.7647857666 +TopFaresRanked,Arrow Flight SQL ODBC Driver,224496.55227661133,221088.01984786987,229951.82847976685,Apache Arrow Flight SQL ODBC Driver,224802.27527618408,218252.2463798523,234645.4563140869 +AvgFareByPassengerWindow,Arrow Flight SQL ODBC Driver,192785.4600906372,191565.31715393066,194490.74482917786,Apache Arrow Flight SQL ODBC Driver,192811.2874031067,191825.510263443,193995.94283103943 diff --git a/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/003CompareQueries20/results_20_queries.png b/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/003CompareQueries20/results_20_queries.png new file mode 100644 index 00000000000..242f2619d58 Binary files /dev/null and b/cpp/src/arrow/flight/sql/odbc/performance_tests/TEST_RUNS/003CompareQueries20/results_20_queries.png differ diff --git a/cpp/src/arrow/flight/sql/odbc/performance_tests/compare_limits_plot.py b/cpp/src/arrow/flight/sql/odbc/performance_tests/compare_limits_plot.py new file mode 100644 index 00000000000..de314376fc6 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/performance_tests/compare_limits_plot.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +import subprocess +import argparse +import json +import matplotlib.pyplot as plt +import csv + +# ---------------------------- +# ARGUMENTS +# ---------------------------- +parser = argparse.ArgumentParser(description="Compare two ODBC Flight SQL drivers using JSON output.") +parser.add_argument("--driver_a", required=True, help="First driver name") +parser.add_argument("--driver_b", required=True, help="Second driver name") +parser.add_argument("--iterations", type=int, default=5, help="Number of iterations per limit") +parser.add_argument("--limits", type=int, nargs="+", default=[100, 1000, 10000], help="List of LIMIT values") +parser.add_argument("--outfile", default="perf_results.csv", help="CSV output filename") +parser.add_argument("--plotfile", default="perf_plot.png", help="Plot output filename") +args = parser.parse_args() + +# ---------------------------- +# HELPER FUNCTION TO RUN DRIVER +# ---------------------------- +def run_driver(driver_name): + results = {} + for limit in args.limits: + cmd = [ + "python", "table_read_test.py", + "--driver", driver_name, + "--limit", str(limit), + "--iterations", str(args.iterations), + "--json" + ] + print(f"\nRunning {driver_name} with LIMIT={limit}...") + + try: + proc = subprocess.run(cmd, capture_output=True, text=True, check=True) + output = proc.stdout.strip() + # Extract JSON from output + json_start = output.find("{") + json_end = output.rfind("}") + 1 + json_text = output[json_start:json_end] + data = json.loads(json_text) + + # --- PRINT INTERMEDIATE JSON RESULTS --- + print(f"Intermediate JSON for {driver_name} LIMIT={limit}:") + print(json.dumps(data, indent=2)) + + avg_ms = data["avg_ms"] + min_ms = data["min_ms"] + max_ms = data["max_ms"] + results[limit] = (avg_ms, min_ms, max_ms) + except subprocess.CalledProcessError as e: + print(f"Error running {driver_name} with LIMIT={limit}:\n{e.stdout}\n{e.stderr}") + results[limit] = (None, None, None) + except Exception as e: + print(f"Failed to parse JSON for {driver_name} with LIMIT={limit}: {e}") + results[limit] = (None, None, None) + + return results + +# ---------------------------- +# RUN DRIVERS +# ---------------------------- +driver_a_results = run_driver(args.driver_a) +driver_b_results = run_driver(args.driver_b) + +# ---------------------------- +# SAVE RESULTS TO CSV +# ---------------------------- +with open(args.outfile, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["Driver", "Limit", "Avg_ms", "Min_ms", "Max_ms"]) + for limit, (avg, min_, max_) in driver_a_results.items(): + writer.writerow([args.driver_a, limit, avg, min_, max_]) + for limit, (avg, min_, max_) in driver_b_results.items(): + writer.writerow([args.driver_b, limit, avg, min_, max_]) +print(f"Results saved to {args.outfile}") + +# ---------------------------- +# PLOT RESULTS WITH ERROR BARS +# ---------------------------- +limits = args.limits +avg_a = [driver_a_results[l][0] for l in limits] +avg_b = [driver_b_results[l][0] for l in limits] + +err_a = [[avg_a[i] - driver_a_results[l][1] if driver_a_results[l][1] else 0 for i, l in enumerate(limits)], + [driver_a_results[l][2] - avg_a[i] if driver_a_results[l][2] else 0 for i, l in enumerate(limits)]] + +err_b = [[avg_b[i] - driver_b_results[l][1] if driver_b_results[l][1] else 0 for i, l in enumerate(limits)], + [driver_b_results[l][2] - avg_b[i] if driver_b_results[l][2] else 0 for i, l in enumerate(limits)]] + +plt.figure(figsize=(10,6)) +plt.errorbar(limits, avg_a, yerr=err_a, fmt='o-', capsize=5, label=args.driver_a) +plt.errorbar(limits, avg_b, yerr=err_b, fmt='s-', capsize=5, label=args.driver_b) +plt.xlabel("LIMIT value") +plt.ylabel("Query time (ms)") +plt.title("ODBC Flight SQL Driver Performance Comparison") +plt.legend() +plt.grid(True) +plt.xscale("log") +plt.yscale("log") +plt.savefig(args.plotfile) +print(f"Plot saved to {args.plotfile}") +plt.show() diff --git a/cpp/src/arrow/flight/sql/odbc/performance_tests/compare_queries_bar.py b/cpp/src/arrow/flight/sql/odbc/performance_tests/compare_queries_bar.py new file mode 100644 index 00000000000..6bc009792ca --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/performance_tests/compare_queries_bar.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +import argparse +import subprocess +import csv +import json +import matplotlib.pyplot as plt + +def run_query(driver, schema, table, query, iterations, label): + cmd = [ + "python", "table_read_test.py", + "--driver", driver, + "--schema", schema, + "--table", table, + "--iterations", str(iterations), + "--json", + "--query", query + ] + try: + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) + full_output = [] + json_lines = [] + inside_json = False + for line in proc.stdout: + line = line.rstrip() + full_output.append(line) + print(line) + if line.startswith("{"): + inside_json = True + if inside_json: + json_lines.append(line) + if line.endswith("}"): + inside_json = False + proc.wait() + + if not json_lines: + print(f"\n--- FULL CHILD OUTPUT ({label}, {driver}) ---") + print("\n".join(full_output)) + print("--- END CHILD OUTPUT ---\n") + return None + + json_text = "\n".join(json_lines) + return json.loads(json_text) + + except Exception as e: + print(f"Error running {label} for {driver}: {e}") + return None + +def main(): + parser = argparse.ArgumentParser(description="Compare two ODBC Flight SQL drivers on multiple queries") + parser.add_argument("--driver_a", required=True, help="Driver A name") + parser.add_argument("--driver_b", required=True, help="Driver B name") + parser.add_argument("--schema", required=True, help="Schema name") + parser.add_argument("--table", required=True, help="Table name") + parser.add_argument("--iterations", type=int, default=5, help="Number of iterations per query") + parser.add_argument("--outfile", default="compare_results.csv", help="CSV output file") + parser.add_argument("--plotfile", default="compare_plot.png", help="Plot output file") + args = parser.parse_args() + + # ---------------------------- + # 20 Queries + # ---------------------------- + queries = { + # Simple selects + "Limit100": f'SELECT * FROM "{args.schema}"."{args.table}" LIMIT 100', + "Limit1000": f'SELECT * FROM "{args.schema}"."{args.table}" LIMIT 1000', + + # Aggregations & Group By + "AvgFareByPassenger": f'SELECT passenger_count, AVG(fare_amount) AS avg_fare FROM "{args.schema}"."{args.table}" GROUP BY passenger_count', + "SumFareByPassenger": f'SELECT passenger_count, SUM(fare_amount) AS total_fare FROM "{args.schema}"."{args.table}" GROUP BY passenger_count', + "AvgTipByPassenger": f'SELECT passenger_count, AVG(tip_amount) AS avg_tip FROM "{args.schema}"."{args.table}" GROUP BY passenger_count', + "TotalAmountByPassenger": f'SELECT passenger_count, SUM(total_amount) AS total_amount FROM "{args.schema}"."{args.table}" GROUP BY passenger_count', + + # Filters + "FareGreater50": f'SELECT * FROM "{args.schema}"."{args.table}" WHERE fare_amount > 50 LIMIT 500', + "TripDistance5to10": f'SELECT * FROM "{args.schema}"."{args.table}" WHERE trip_distance_mi BETWEEN 5 AND 10 LIMIT 500', + "PassengerCount2": f'SELECT * FROM "{args.schema}"."{args.table}" WHERE passenger_count = 2 LIMIT 500', + "TipGreater10": f'SELECT * FROM "{args.schema}"."{args.table}" WHERE tip_amount > 10 LIMIT 500', + + # Ordering + "OrderByFareDesc": f'SELECT * FROM "{args.schema}"."{args.table}" ORDER BY fare_amount DESC LIMIT 100', + "OrderByTripDistance": f'SELECT * FROM "{args.schema}"."{args.table}" ORDER BY trip_distance_mi DESC LIMIT 100', + + # Aggregates without group by + "TotalRowCount": f'SELECT COUNT(*) AS total_trips FROM "{args.schema}"."{args.table}"', + "MaxFare": f'SELECT MAX(fare_amount) AS max_fare FROM "{args.schema}"."{args.table}"', + "MinFare": f'SELECT MIN(fare_amount) AS min_fare FROM "{args.schema}"."{args.table}"', + "MaxTripDistance": f'SELECT MAX(trip_distance_mi) AS max_distance FROM "{args.schema}"."{args.table}"', + + # Date/time functions + "TripsByYear": f'SELECT EXTRACT(YEAR FROM pickup_datetime) "year", COUNT(*) "trips" FROM "{args.schema}"."{args.table}" GROUP BY "year" ORDER BY "year"', + + "TripsByMonth": f'SELECT EXTRACT(MONTH FROM pickup_datetime) "month", COUNT(*) "trips" FROM "{args.schema}"."{args.table}" GROUP BY "month" ORDER BY "month"', + + # Window functions + "TopFaresRanked": f'SELECT fare_amount, rank_col FROM (SELECT fare_amount, ROW_NUMBER() OVER (ORDER BY fare_amount DESC) AS rank_col FROM "{args.schema}"."{args.table}") t LIMIT 100', + "AvgFareByPassengerWindow": f'SELECT passenger_count, AVG(fare_amount) OVER (PARTITION BY passenger_count) AS avg_fare FROM "{args.schema}"."{args.table}" LIMIT 500' + } + + results = [] + + for label, query in queries.items(): + print(f"\nRunning {label}...") + result_a = run_query(args.driver_a, args.schema, args.table, query, args.iterations, label) + result_b = run_query(args.driver_b, args.schema, args.table, query, args.iterations, label) + + row = {"query": label} + if result_a: + row.update({ + "driver_a": args.driver_a, + "a_avg": result_a["avg_ms"], + "a_min": result_a["min_ms"], + "a_max": result_a["max_ms"], + }) + else: + row.update({ + "driver_a": args.driver_a, + "a_avg": "N/A", + "a_min": "N/A", + "a_max": "N/A", + }) + + if result_b: + row.update({ + "driver_b": args.driver_b, + "b_avg": result_b["avg_ms"], + "b_min": result_b["min_ms"], + "b_max": result_b["max_ms"], + }) + else: + row.update({ + "driver_b": args.driver_b, + "b_avg": "N/A", + "b_min": "N/A", + "b_max": "N/A", + }) + + results.append(row) + + # ---------------------------- + # Write results to CSV + # ---------------------------- + with open(args.outfile, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=[ + "query", "driver_a", "a_avg", "a_min", "a_max", + "driver_b", "b_avg", "b_min", "b_max" + ]) + writer.writeheader() + for row in results: + writer.writerow(row) + print(f"\nResults written to {args.outfile}") + + # ---------------------------- + # Plot results + # ---------------------------- + queries_labels = [r["query"] for r in results] + x = range(len(queries_labels)) + width = 0.35 + fig, ax = plt.subplots(figsize=(14, 6)) + + driver_colors = { + "a": "#1f77b4", + "b": "#ff7f0e" + } + + for i, r in enumerate(results): + if r["a_avg"] != "N/A": + a_mean = r["a_avg"] + a_err = [[a_mean - r["a_min"]], [r["a_max"] - a_mean]] + ax.bar(i - width/2, a_mean, width, yerr=a_err, capsize=5, + color=driver_colors["a"], alpha=0.8, label=args.driver_a if i == 0 else "") + if r["b_avg"] != "N/A": + b_mean = r["b_avg"] + b_err = [[b_mean - r["b_min"]], [r["b_max"] - b_mean]] + ax.bar(i + width/2, b_mean, width, yerr=b_err, capsize=5, + color=driver_colors["b"], alpha=0.8, label=args.driver_b if i == 0 else "") + + ax.set_ylabel("Execution time (ms)") + ax.set_title("Query Performance Comparison: Driver A vs Driver B") + ax.set_xticks(x) + ax.set_xticklabels(queries_labels, rotation=45, ha="right") + ax.legend(title="Driver") + + plt.tight_layout() + plt.savefig(args.plotfile) + print(f"Plot saved to {args.plotfile}") + plt.show() + + +if __name__ == "__main__": + main() diff --git a/cpp/src/arrow/flight/sql/odbc/performance_tests/table_read_test.py b/cpp/src/arrow/flight/sql/odbc/performance_tests/table_read_test.py new file mode 100644 index 00000000000..f75c82083ba --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/performance_tests/table_read_test.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +import argparse +import pyodbc +import time +import json +import csv +import sys +from os import environ + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark ODBC Flight SQL Driver query performance." + ) + parser.add_argument("--driver", required=True, help="ODBC driver name") + parser.add_argument("--schema", required=True, help="Schema name") + parser.add_argument("--table", required=True, help="Table name") + parser.add_argument("--limit", type=int, help="LIMIT for SELECT query (ignored if --query provided)") + parser.add_argument("--iterations", type=int, default=1, help="Number of iterations to run") + parser.add_argument("--json", action="store_true", help="Output JSON") + parser.add_argument("--csv", help="Output CSV file") + parser.add_argument("--query", type=str, help="Optional SQL query to run") + args = parser.parse_args() + + token = environ.get("token") + if not token: + raise RuntimeError("Environment variable 'token' must be set.") + + # Build SQL query + if args.query: + sql = args.query + else: + limit = args.limit if args.limit else 100 + sql = f'SELECT * FROM "{args.schema}"."{args.table}" LIMIT {limit}' + + # Connection string for Arrow Flight SQL ODBC + conn_str = ( + f"Driver={{{args.driver}}};" + "ConnectionType=Direct;" + "HOST=dremio-clients-demo.test.drem.io;" + "PORT=32010;" + "AuthenticationType=Plain;" + f"UID=improving;PWD={token};ssl=true;" + "DisableCertificateVerification=1;" + ) + + # Connect + try: + conn = pyodbc.connect(conn_str, autocommit=True) + except pyodbc.Error as e: + print(f"Failed to connect using driver '{args.driver}': {e}") + sys.exit(1) + + # UTF-8 decoding for CHAR columns + conn.setdecoding(pyodbc.SQL_CHAR, encoding="utf-8") + conn.setdecoding(pyodbc.SQL_WCHAR, encoding="utf-16le") + cursor = conn.cursor() + + # Run benchmark + timings = [] + rows_returned = [] + for i in range(args.iterations): + start = time.time() + try: + cursor.execute(sql) + rows = cursor.fetchall() + elapsed_ms = (time.time() - start) * 1000 + timings.append(elapsed_ms) + rows_returned.append(len(rows)) + print(f"Iteration {i+1}: {elapsed_ms:.2f} ms (rows returned: {len(rows)})") + except pyodbc.Error as e: + print(f"Query execution failed: {e}") + timings.append(None) + rows_returned.append(None) + + valid_timings = [t for t in timings if t is not None] + avg_ms = sum(valid_timings)/len(valid_timings) if valid_timings else None + min_ms = min(valid_timings) if valid_timings else None + max_ms = max(valid_timings) if valid_timings else None + + result = { + "driver": args.driver, + "schema": args.schema, + "table": args.table, + "iterations": args.iterations, + "query": sql, + "avg_ms": avg_ms, + "min_ms": min_ms, + "max_ms": max_ms, + "all_runs_ms": timings, + "rows_returned": rows_returned, + } + + # JSON output + if args.json: + print(json.dumps(result, indent=2)) + + # CSV output + if args.csv: + with open(args.csv, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(result.keys()) + writer.writerow(result.values()) + + cursor.close() + conn.close() + +if __name__ == "__main__": + main() diff --git a/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt b/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt new file mode 100644 index 00000000000..2dc719fa05e --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +add_custom_target(tests) + +include_directories(${ODBC_INCLUDE_DIRS}) + +find_package(SQLite3Alt REQUIRED) + +set(ARROW_FLIGHT_SQL_MOCK_SERVER_SRCS + ../../example/sqlite_sql_info.cc + ../../example/sqlite_type_info.cc + ../../example/sqlite_statement.cc + ../../example/sqlite_statement_batch_reader.cc + ../../example/sqlite_server.cc + ../../example/sqlite_tables_schema_batch_reader.cc) + +add_arrow_test(flight_sql_odbc_test + SOURCES + columns_test.cc + connection_attr_test.cc + connection_info_test.cc + errors_test.cc + get_functions_test.cc + statement_attr_test.cc + statement_test.cc + tables_test.cc + type_info_test.cc + # Connection test needs to be put last to resolve segfault issue + connection_test.cc + odbc_test_suite.cc + odbc_test_suite.h + # Enable Protobuf cleanup after test execution + # GH-46889: move protobuf_test_util to a more common location + ../../../../engine/substrait/protobuf_test_util.cc + ${ARROW_FLIGHT_SQL_MOCK_SERVER_SRCS} + EXTRA_LINK_LIBS + ${ODBC_LIBRARIES} + ${ODBCINST} + ${SQLite3_LIBRARIES} + arrow_odbc_spi_impl + odbcabstraction) diff --git a/cpp/src/arrow/flight/sql/odbc/tests/README b/cpp/src/arrow/flight/sql/odbc/tests/README new file mode 100644 index 00000000000..8e43296edff --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/README @@ -0,0 +1,4 @@ +Prior to running the tests, set environment variable `ARROW_FLIGHT_SQL_ODBC_CONN` +to a valid connection string. +A valid connection string looks like: +driver={Apache Arrow Flight SQL ODBC Driver};HOST=localhost;port=32010;pwd=myPassword;uid=myName;useEncryption=false; diff --git a/cpp/src/arrow/flight/sql/odbc/tests/columns_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/columns_test.cc new file mode 100644 index 00000000000..60a8e251576 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/columns_test.cc @@ -0,0 +1,2887 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#ifdef _WIN32 +# include +#endif + +#include +#include +#include + +#include "gtest/gtest.h" + +namespace arrow::flight::sql::odbc { +// Helper functions +void checkSQLColumns( + SQLHSTMT stmt, const std::wstring& expectedTable, const std::wstring& expectedColumn, + const SQLINTEGER& expectedDataType, const std::wstring& expectedTypeName, + const SQLINTEGER& expectedColumnSize, const SQLINTEGER& expectedBufferLength, + const SQLSMALLINT& expectedDecimalDigits, const SQLSMALLINT& expectedNumPrecRadix, + const SQLSMALLINT& expectedNullable, const SQLSMALLINT& expectedSqlDataType, + const SQLSMALLINT& expectedDateTimeSub, const SQLINTEGER& expectedOctetCharLength, + const SQLINTEGER& expectedOrdinalPosition, const std::wstring& expectedIsNullable) { + CheckStringColumnW(stmt, 3, expectedTable); // table name + CheckStringColumnW(stmt, 4, expectedColumn); // column name + + CheckIntColumn(stmt, 5, expectedDataType); // data type + + CheckStringColumnW(stmt, 6, expectedTypeName); // type name + + CheckIntColumn(stmt, 7, expectedColumnSize); // column size + CheckIntColumn(stmt, 8, expectedBufferLength); // buffer length + + CheckSmallIntColumn(stmt, 9, expectedDecimalDigits); // decimal digits + CheckSmallIntColumn(stmt, 10, expectedNumPrecRadix); // num prec radix + CheckSmallIntColumn(stmt, 11, + expectedNullable); // nullable + + CheckNullColumnW(stmt, 12); // remarks + CheckNullColumnW(stmt, 13); // column def + + CheckSmallIntColumn(stmt, 14, expectedSqlDataType); // sql data type + CheckSmallIntColumn(stmt, 15, expectedDateTimeSub); // sql date type sub + CheckIntColumn(stmt, 16, expectedOctetCharLength); // char octet length + CheckIntColumn(stmt, 17, + expectedOrdinalPosition); // oridinal position + + CheckStringColumnW(stmt, 18, expectedIsNullable); // is nullable +} + +void checkMockSQLColumns( + SQLHSTMT stmt, const std::wstring& expectedCatalog, const std::wstring& expectedTable, + const std::wstring& expectedColumn, const SQLINTEGER& expectedDataType, + const std::wstring& expectedTypeName, const SQLINTEGER& expectedColumnSize, + const SQLINTEGER& expectedBufferLength, const SQLSMALLINT& expectedDecimalDigits, + const SQLSMALLINT& expectedNumPrecRadix, const SQLSMALLINT& expectedNullable, + const SQLSMALLINT& expectedSqlDataType, const SQLSMALLINT& expectedDateTimeSub, + const SQLINTEGER& expectedOctetCharLength, const SQLINTEGER& expectedOrdinalPosition, + const std::wstring& expectedIsNullable) { + CheckStringColumnW(stmt, 1, expectedCatalog); // catalog + CheckNullColumnW(stmt, 2); // schema + + checkSQLColumns(stmt, expectedTable, expectedColumn, expectedDataType, expectedTypeName, + expectedColumnSize, expectedBufferLength, expectedDecimalDigits, + expectedNumPrecRadix, expectedNullable, expectedSqlDataType, + expectedDateTimeSub, expectedOctetCharLength, expectedOrdinalPosition, + expectedIsNullable); +} + +void checkRemoteSQLColumns( + SQLHSTMT stmt, const std::wstring& expectedSchema, const std::wstring& expectedTable, + const std::wstring& expectedColumn, const SQLINTEGER& expectedDataType, + const std::wstring& expectedTypeName, const SQLINTEGER& expectedColumnSize, + const SQLINTEGER& expectedBufferLength, const SQLSMALLINT& expectedDecimalDigits, + const SQLSMALLINT& expectedNumPrecRadix, const SQLSMALLINT& expectedNullable, + const SQLSMALLINT& expectedSqlDataType, const SQLSMALLINT& expectedDateTimeSub, + const SQLINTEGER& expectedOctetCharLength, const SQLINTEGER& expectedOrdinalPosition, + const std::wstring& expectedIsNullable) { + CheckNullColumnW(stmt, 1); // catalog + CheckStringColumnW(stmt, 2, expectedSchema); // schema + checkSQLColumns(stmt, expectedTable, expectedColumn, expectedDataType, expectedTypeName, + expectedColumnSize, expectedBufferLength, expectedDecimalDigits, + expectedNumPrecRadix, expectedNullable, expectedSqlDataType, + expectedDateTimeSub, expectedOctetCharLength, expectedOrdinalPosition, + expectedIsNullable); +} + +void checkSQLColAttribute(SQLHSTMT stmt, SQLUSMALLINT idx, + const std::wstring& expectedColmnName, SQLLEN expectedDataType, + SQLLEN expectedConciseType, SQLLEN expectedDisplaySize, + SQLLEN expectedPrecScale, SQLLEN expectedLength, + const std::wstring& expectedLiteralPrefix, + const std::wstring& expectedLiteralSuffix, + SQLLEN expectedColumnSize, SQLLEN expectedColumnScale, + SQLLEN expectedColumnNullability, SQLLEN expectedNumPrecRadix, + SQLLEN expectedOctetLength, SQLLEN expectedSearchable, + SQLLEN expectedUnsignedColumn) { + std::vector name(ODBC_BUFFER_SIZE); + SQLSMALLINT nameLen = 0; + std::vector baseColumnName(ODBC_BUFFER_SIZE); + SQLSMALLINT columnNameLen = 0; + std::vector label(ODBC_BUFFER_SIZE); + SQLSMALLINT labelLen = 0; + std::vector prefix(ODBC_BUFFER_SIZE); + SQLSMALLINT prefixLen = 0; + std::vector suffix(ODBC_BUFFER_SIZE); + SQLSMALLINT suffixLen = 0; + SQLLEN dataType = 0; + SQLLEN conciseType = 0; + SQLLEN displaySize = 0; + SQLLEN precScale = 0; + SQLLEN length = 0; + SQLLEN size = 0; + SQLLEN scale = 0; + SQLLEN nullability = 0; + SQLLEN numPrecRadix = 0; + SQLLEN octetLength = 0; + SQLLEN searchable = 0; + SQLLEN unsignedCol = 0; + + SQLRETURN ret = SQLColAttribute(stmt, idx, SQL_DESC_NAME, &name[0], + (SQLSMALLINT)name.size(), &nameLen, 0); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttribute(stmt, idx, SQL_DESC_BASE_COLUMN_NAME, &baseColumnName[0], + (SQLSMALLINT)baseColumnName.size(), &columnNameLen, 0); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttribute(stmt, idx, SQL_DESC_LABEL, &label[0], (SQLSMALLINT)label.size(), + &labelLen, 0); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttribute(stmt, idx, SQL_DESC_TYPE, 0, 0, 0, &dataType); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttribute(stmt, idx, SQL_DESC_CONCISE_TYPE, 0, 0, 0, &conciseType); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttribute(stmt, idx, SQL_DESC_DISPLAY_SIZE, 0, 0, 0, &displaySize); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttribute(stmt, idx, SQL_DESC_FIXED_PREC_SCALE, 0, 0, 0, &precScale); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttribute(stmt, idx, SQL_DESC_LENGTH, 0, 0, 0, &length); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttribute(stmt, idx, SQL_DESC_LITERAL_PREFIX, &prefix[0], + (SQLSMALLINT)prefix.size(), &prefixLen, 0); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttribute(stmt, idx, SQL_DESC_LITERAL_SUFFIX, &suffix[0], + (SQLSMALLINT)suffix.size(), &suffixLen, 0); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttribute(stmt, idx, SQL_DESC_PRECISION, 0, 0, 0, &size); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttribute(stmt, idx, SQL_DESC_SCALE, 0, 0, 0, &scale); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttribute(stmt, idx, SQL_DESC_NULLABLE, 0, 0, 0, &nullability); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttribute(stmt, idx, SQL_DESC_NUM_PREC_RADIX, 0, 0, 0, &numPrecRadix); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttribute(stmt, idx, SQL_DESC_OCTET_LENGTH, 0, 0, 0, &octetLength); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttribute(stmt, idx, SQL_DESC_SEARCHABLE, 0, 0, 0, &searchable); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttribute(stmt, idx, SQL_DESC_UNSIGNED, 0, 0, 0, &unsignedCol); + EXPECT_EQ(ret, SQL_SUCCESS); + + std::wstring nameStr = ConvertToWString(name, nameLen); + std::wstring baseColumnNameStr = ConvertToWString(baseColumnName, columnNameLen); + std::wstring labelStr = ConvertToWString(label, labelLen); + std::wstring prefixStr = ConvertToWString(prefix, prefixLen); + + // Assume column name, base column name, and label are equivalent in the result set + EXPECT_EQ(nameStr, expectedColmnName); + EXPECT_EQ(baseColumnNameStr, expectedColmnName); + EXPECT_EQ(labelStr, expectedColmnName); + EXPECT_EQ(dataType, expectedDataType); + EXPECT_EQ(conciseType, expectedConciseType); + EXPECT_EQ(displaySize, expectedDisplaySize); + EXPECT_EQ(precScale, expectedPrecScale); + EXPECT_EQ(length, expectedLength); + EXPECT_EQ(prefixStr, expectedLiteralPrefix); + EXPECT_EQ(size, expectedColumnSize); + EXPECT_EQ(scale, expectedColumnScale); + EXPECT_EQ(nullability, expectedColumnNullability); + EXPECT_EQ(numPrecRadix, expectedNumPrecRadix); + EXPECT_EQ(octetLength, expectedOctetLength); + EXPECT_EQ(searchable, expectedSearchable); + EXPECT_EQ(unsignedCol, expectedUnsignedColumn); +} + +void checkSQLColAttributes(SQLHSTMT stmt, SQLUSMALLINT idx, + const std::wstring& expectedColmnName, SQLLEN expectedDataType, + SQLLEN expectedDisplaySize, SQLLEN expectedPrecScale, + SQLLEN expectedLength, SQLLEN expectedColumnSize, + SQLLEN expectedColumnScale, SQLLEN expectedColumnNullability, + SQLLEN expectedSearchable, SQLLEN expectedUnsignedColumn) { + std::vector name(ODBC_BUFFER_SIZE); + SQLSMALLINT nameLen = 0; + std::vector label(ODBC_BUFFER_SIZE); + SQLSMALLINT labelLen = 0; + SQLLEN dataType = 0; + SQLLEN displaySize = 0; + SQLLEN precScale = 0; + SQLLEN length = 0; + SQLLEN size = 0; + SQLLEN scale = 0; + SQLLEN nullability = 0; + SQLLEN searchable = 0; + SQLLEN unsignedCol = 0; + + SQLRETURN ret = SQLColAttributes(stmt, idx, SQL_COLUMN_NAME, &name[0], + (SQLSMALLINT)name.size(), &nameLen, 0); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttributes(stmt, idx, SQL_COLUMN_LABEL, &label[0], + (SQLSMALLINT)label.size(), &labelLen, 0); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttributes(stmt, idx, SQL_COLUMN_TYPE, 0, 0, 0, &dataType); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttributes(stmt, idx, SQL_COLUMN_DISPLAY_SIZE, 0, 0, 0, &displaySize); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttribute(stmt, idx, SQL_COLUMN_MONEY, 0, 0, 0, &precScale); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttributes(stmt, idx, SQL_COLUMN_LENGTH, 0, 0, 0, &length); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttributes(stmt, idx, SQL_COLUMN_PRECISION, 0, 0, 0, &size); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttributes(stmt, idx, SQL_COLUMN_SCALE, 0, 0, 0, &scale); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttributes(stmt, idx, SQL_COLUMN_NULLABLE, 0, 0, 0, &nullability); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttributes(stmt, idx, SQL_COLUMN_SEARCHABLE, 0, 0, 0, &searchable); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttributes(stmt, idx, SQL_COLUMN_UNSIGNED, 0, 0, 0, &unsignedCol); + EXPECT_EQ(ret, SQL_SUCCESS); + + std::wstring nameStr = ConvertToWString(name, nameLen); + std::wstring labelStr = ConvertToWString(label, labelLen); + + EXPECT_EQ(nameStr, expectedColmnName); + EXPECT_EQ(labelStr, expectedColmnName); + EXPECT_EQ(dataType, expectedDataType); + EXPECT_EQ(displaySize, expectedDisplaySize); + EXPECT_EQ(length, expectedLength); + EXPECT_EQ(size, expectedColumnSize); + EXPECT_EQ(scale, expectedColumnScale); + EXPECT_EQ(nullability, expectedColumnNullability); + EXPECT_EQ(searchable, expectedSearchable); + EXPECT_EQ(unsignedCol, expectedUnsignedColumn); +} + +void checkSQLColAttributeString(SQLHSTMT stmt, const std::wstring& wsql, SQLUSMALLINT idx, + SQLUSMALLINT fieldIdentifier, + const std::wstring& expectedAttrString) { + // Execute query and check SQLColAttribute string attribute + std::vector sql0(wsql.begin(), wsql.end()); + SQLRETURN ret = SQLExecDirect(stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + std::vector strVal(ODBC_BUFFER_SIZE); + SQLSMALLINT strLen = 0; + + ret = SQLColAttribute(stmt, idx, fieldIdentifier, &strVal[0], + (SQLSMALLINT)strVal.size(), &strLen, 0); + EXPECT_EQ(ret, SQL_SUCCESS); + + std::wstring attrStr = ConvertToWString(strVal, strLen); + EXPECT_EQ(attrStr, expectedAttrString); +} + +void checkSQLColAttributeNumeric(SQLHSTMT stmt, const std::wstring& wsql, + SQLUSMALLINT idx, SQLUSMALLINT fieldIdentifier, + SQLLEN expectedAttrNumeric) { + // Execute query and check SQLColAttribute numeric attribute + std::vector sql0(wsql.begin(), wsql.end()); + SQLRETURN ret = SQLExecDirect(stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + SQLLEN numVal = 0; + ret = SQLColAttribute(stmt, idx, fieldIdentifier, 0, 0, 0, &numVal); + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(numVal, expectedAttrNumeric); +} + +void checkSQLColAttributesString(SQLHSTMT stmt, const std::wstring& wsql, + SQLUSMALLINT idx, SQLUSMALLINT fieldIdentifier, + const std::wstring& expectedAttrString) { + // Execute query and check ODBC 2.0 API SQLColAttributes string attribute + std::vector sql0(wsql.begin(), wsql.end()); + SQLRETURN ret = SQLExecDirect(stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + std::vector strVal(ODBC_BUFFER_SIZE); + SQLSMALLINT strLen = 0; + + ret = SQLColAttributes(stmt, idx, fieldIdentifier, &strVal[0], + (SQLSMALLINT)strVal.size(), &strLen, 0); + EXPECT_EQ(ret, SQL_SUCCESS); + + std::wstring attrStr = ConvertToWString(strVal, strLen); + EXPECT_EQ(attrStr, expectedAttrString); +} + +void checkSQLColAttributesNumeric(SQLHSTMT stmt, const std::wstring& wsql, + SQLUSMALLINT idx, SQLUSMALLINT fieldIdentifier, + SQLLEN expectedAttrNumeric) { + // Execute query and check ODBC 2.0 API SQLColAttributes numeric attribute + std::vector sql0(wsql.begin(), wsql.end()); + SQLRETURN ret = SQLExecDirect(stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + SQLLEN numVal = 0; + ret = SQLColAttributes(stmt, idx, fieldIdentifier, 0, 0, 0, &numVal); + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(numVal, expectedAttrNumeric); +} + +TYPED_TEST(FlightSQLODBCTestBase, SQLColumnsTestInputData) { + this->connect(); + + SQLWCHAR catalogName[] = L""; + SQLWCHAR schemaName[] = L""; + SQLWCHAR tableName[] = L""; + SQLWCHAR columnName[] = L""; + + // All values populated + SQLRETURN ret = SQLColumns(this->stmt, catalogName, sizeof(catalogName), schemaName, + sizeof(schemaName), tableName, sizeof(tableName), columnName, + sizeof(columnName)); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ValidateFetch(this->stmt, SQL_NO_DATA); + + // Sizes are nulls + ret = + SQLColumns(this->stmt, catalogName, 0, schemaName, 0, tableName, 0, columnName, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ValidateFetch(this->stmt, SQL_NO_DATA); + + // Values are nulls + ret = SQLColumns(this->stmt, 0, sizeof(catalogName), 0, sizeof(schemaName), 0, + sizeof(tableName), 0, sizeof(columnName)); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ValidateFetch(this->stmt, SQL_SUCCESS); + // Close statement cursor to avoid leaving in an invalid state + SQLFreeStmt(this->stmt, SQL_CLOSE); + + // All values and sizes are nulls + ret = SQLColumns(this->stmt, 0, 0, 0, 0, 0, 0, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ValidateFetch(this->stmt, SQL_SUCCESS); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLColumnsAllColumns) { + // Check table pattern and column pattern returns all columns + this->connect(); + + // Attempt to get all columns + SQLWCHAR tablePattern[] = L"%"; + SQLWCHAR columnPattern[] = L"%"; + + SQLRETURN ret = SQLColumns(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, tablePattern, + SQL_NTS, columnPattern, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // mock limitation: SQLite mock server returns 10 for bigint size when spec indicates + // should be 19 + // DECIMAL_DIGITS should be 0 for bigint type since it is exact + // mock limitation: SQLite mock server returns 10 for bigint decimal digits when spec + // indicates should be 0 + checkMockSQLColumns(this->stmt, + std::wstring(L"main"), // expectedCatalog + std::wstring(L"foreignTable"), // expectedTable + std::wstring(L"id"), // expectedColumn + SQL_BIGINT, // expectedDataType + std::wstring(L"BIGINT"), // expectedTypeName + 10, // expectedColumnSize (mock returns 10 instead of 19) + 8, // expectedBufferLength + 15, // expectedDecimalDigits (mock returns 15 instead of 0) + 10, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_BIGINT, // expectedSqlDataType + NULL, // expectedDateTimeSub + 8, // expectedOctetCharLength + 1, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check 2nd Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkMockSQLColumns(this->stmt, + std::wstring(L"main"), // expectedCatalog + std::wstring(L"foreignTable"), // expectedTable + std::wstring(L"foreignName"), // expectedColumn + SQL_WVARCHAR, // expectedDataType + std::wstring(L"WVARCHAR"), // expectedTypeName + 0, // expectedColumnSize (mock server limitation: returns 0 for + // varchar(100), the ODBC spec expects 100) + 0, // expectedBufferLength + 15, // expectedDecimalDigits + 0, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_WVARCHAR, // expectedSqlDataType + NULL, // expectedDateTimeSub + 0, // expectedOctetCharLength + 2, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check 3rd Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkMockSQLColumns(this->stmt, + std::wstring(L"main"), // expectedCatalog + std::wstring(L"foreignTable"), // expectedTable + std::wstring(L"value"), // expectedColumn + SQL_BIGINT, // expectedDataType + std::wstring(L"BIGINT"), // expectedTypeName + 10, // expectedColumnSize (mock returns 10 instead of 19) + 8, // expectedBufferLength + 15, // expectedDecimalDigits (mock returns 15 instead of 0) + 10, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_BIGINT, // expectedSqlDataType + NULL, // expectedDateTimeSub + 8, // expectedOctetCharLength + 3, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check 4th Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkMockSQLColumns(this->stmt, + std::wstring(L"main"), // expectedCatalog + std::wstring(L"intTable"), // expectedTable + std::wstring(L"id"), // expectedColumn + SQL_BIGINT, // expectedDataType + std::wstring(L"BIGINT"), // expectedTypeName + 10, // expectedColumnSize (mock returns 10 instead of 19) + 8, // expectedBufferLength + 15, // expectedDecimalDigits (mock returns 15 instead of 0) + 10, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_BIGINT, // expectedSqlDataType + NULL, // expectedDateTimeSub + 8, // expectedOctetCharLength + 1, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check 5th Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkMockSQLColumns(this->stmt, + std::wstring(L"main"), // expectedCatalog + std::wstring(L"intTable"), // expectedTable + std::wstring(L"keyName"), // expectedColumn + SQL_WVARCHAR, // expectedDataType + std::wstring(L"WVARCHAR"), // expectedTypeName + 0, // expectedColumnSize (mock server limitation: returns 0 for + // varchar(100), the ODBC spec expects 100) + 0, // expectedBufferLength + 15, // expectedDecimalDigits + 0, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_WVARCHAR, // expectedSqlDataType + NULL, // expectedDateTimeSub + 0, // expectedOctetCharLength + 2, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check 6th Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkMockSQLColumns(this->stmt, + std::wstring(L"main"), // expectedCatalog + std::wstring(L"intTable"), // expectedTable + std::wstring(L"value"), // expectedColumn + SQL_BIGINT, // expectedDataType + std::wstring(L"BIGINT"), // expectedTypeName + 10, // expectedColumnSize (mock returns 10 instead of 19) + 8, // expectedBufferLength + 15, // expectedDecimalDigits (mock returns 15 instead of 0) + 10, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_BIGINT, // expectedSqlDataType + NULL, // expectedDateTimeSub + 8, // expectedOctetCharLength + 3, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check 7th Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkMockSQLColumns(this->stmt, + std::wstring(L"main"), // expectedCatalog + std::wstring(L"intTable"), // expectedTable + std::wstring(L"foreignId"), // expectedColumn + SQL_BIGINT, // expectedDataType + std::wstring(L"BIGINT"), // expectedTypeName + 10, // expectedColumnSize (mock returns 10 instead of 19) + 8, // expectedBufferLength + 15, // expectedDecimalDigits (mock returns 15 instead of 0) + 10, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_BIGINT, // expectedSqlDataType + NULL, // expectedDateTimeSub + 8, // expectedOctetCharLength + 4, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLColumnsAllTypes) { + // Limitation: Mock server returns incorrect values for column size for some columns. + // For character and binary type columns, the driver calculates buffer length and char + // octet length from column size. + + // Checks filtering table with table name pattern + this->connect(); + this->CreateTableAllDataType(); + + // Attempt to get all columns from AllTypesTable + SQLWCHAR tablePattern[] = L"AllTypesTable"; + SQLWCHAR columnPattern[] = L"%"; + + SQLRETURN ret = SQLColumns(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, tablePattern, + SQL_NTS, columnPattern, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Fetch SQLColumn data for 1st column in AllTypesTable + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkMockSQLColumns(this->stmt, + std::wstring(L"main"), // expectedCatalog + std::wstring(L"AllTypesTable"), // expectedTable + std::wstring(L"bigint_col"), // expectedColumn + SQL_BIGINT, // expectedDataType + std::wstring(L"BIGINT"), // expectedTypeName + 10, // expectedColumnSize (mock server limitation: returns 10, + // the ODBC spec expects 19) + 8, // expectedBufferLength + 15, // expectedDecimalDigits (mock server limitation: returns 15, + // the ODBC spec expects 0) + 10, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_BIGINT, // expectedSqlDataType + NULL, // expectedDateTimeSub + 8, // expectedOctetCharLength + 1, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check SQLColumn data for 2nd column in AllTypesTable + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkMockSQLColumns(this->stmt, + std::wstring(L"main"), // expectedCatalog + std::wstring(L"AllTypesTable"), // expectedTable + std::wstring(L"char_col"), // expectedColumn + SQL_WVARCHAR, // expectedDataType + std::wstring(L"WVARCHAR"), // expectedTypeName + 0, // expectedColumnSize (mock server limitation: returns 0 for + // varchar(100), the ODBC spec expects 100) + 0, // expectedBufferLength + 15, // expectedDecimalDigits + 0, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_WVARCHAR, // expectedSqlDataType + NULL, // expectedDateTimeSub + 0, // expectedOctetCharLength + 2, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check SQLColumn data for 3rd column in AllTypesTable + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkMockSQLColumns(this->stmt, + std::wstring(L"main"), // expectedCatalog + std::wstring(L"AllTypesTable"), // expectedTable + std::wstring(L"varbinary_col"), // expectedColumn + SQL_BINARY, // expectedDataType + std::wstring(L"BINARY"), // expectedTypeName + 0, // expectedColumnSize (mock server limitation: returns 0 for + // BLOB column, spec expects binary data limit) + 0, // expectedBufferLength + 15, // expectedDecimalDigits + 0, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_BINARY, // expectedSqlDataType + NULL, // expectedDateTimeSub + 0, // expectedOctetCharLength + 3, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check SQLColumn data for 4th column in AllTypesTable + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkMockSQLColumns(this->stmt, + std::wstring(L"main"), // expectedCatalog + std::wstring(L"AllTypesTable"), // expectedTable + std::wstring(L"double_col"), // expectedColumn + SQL_DOUBLE, // expectedDataType + std::wstring(L"DOUBLE"), // expectedTypeName + 15, // expectedColumnSize + 8, // expectedBufferLength + 15, // expectedDecimalDigits + 2, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_DOUBLE, // expectedSqlDataType + NULL, // expectedDateTimeSub + 8, // expectedOctetCharLength + 4, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // There should be no more column data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLColumnsUnicode) { + // Limitation: Mock server returns incorrect values for column size for some columns. + // For character and binary type columns, the driver calculates buffer length and char + // octet length from column size. + this->connect(); + this->CreateUnicodeTable(); + + // Attempt to get all columns + SQLWCHAR tablePattern[] = L"数据"; + SQLWCHAR columnPattern[] = L"%"; + + SQLRETURN ret = SQLColumns(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, tablePattern, + SQL_NTS, columnPattern, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check SQLColumn data for 1st column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkMockSQLColumns(this->stmt, + std::wstring(L"main"), // expectedCatalog + std::wstring(L"数据"), // expectedTable + std::wstring(L"资料"), // expectedColumn + SQL_WVARCHAR, // expectedDataType + std::wstring(L"WVARCHAR"), // expectedTypeName + 0, // expectedColumnSize (mock server limitation: returns 0 for + // varchar(100), spec expects 100) + 0, // expectedBufferLength + 15, // expectedDecimalDigits + 0, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_WVARCHAR, // expectedSqlDataType + NULL, // expectedDateTimeSub + 0, // expectedOctetCharLength + 1, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // There should be no more column data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLColumnsAllTypes) { + // GH-47159: Return NUM_PREC_RADIX based on whether COLUMN_SIZE contains number of + // digits or bits + this->connect(); + + SQLWCHAR tablePattern[] = L"ODBCTest"; + SQLWCHAR columnPattern[] = L"%"; + + SQLRETURN ret = SQLColumns(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, tablePattern, + SQL_NTS, columnPattern, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check 1st Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkRemoteSQLColumns(this->stmt, + std::wstring(L"$scratch"), // expectedSchema + std::wstring(L"ODBCTest"), // expectedTable + std::wstring(L"sinteger_max"), // expectedColumn + SQL_INTEGER, // expectedDataType + std::wstring(L"INTEGER"), // expectedTypeName + 32, // expectedColumnSize (remote server returns number of bits) + 4, // expectedBufferLength + 0, // expectedDecimalDigits + 10, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_INTEGER, // expectedSqlDataType + NULL, // expectedDateTimeSub + 4, // expectedOctetCharLength + 1, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check 2nd Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkRemoteSQLColumns(this->stmt, + std::wstring(L"$scratch"), // expectedSchema + std::wstring(L"ODBCTest"), // expectedTable + std::wstring(L"sbigint_max"), // expectedColumn + SQL_BIGINT, // expectedDataType + std::wstring(L"BIGINT"), // expectedTypeName + 64, // expectedColumnSize (remote server returns number of bits) + 8, // expectedBufferLength + 0, // expectedDecimalDigits + 10, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_BIGINT, // expectedSqlDataType + NULL, // expectedDateTimeSub + 8, // expectedOctetCharLength + 2, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check 3rd Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkRemoteSQLColumns(this->stmt, + std::wstring(L"$scratch"), // expectedSchema + std::wstring(L"ODBCTest"), // expectedTable + std::wstring(L"decimal_positive"), // expectedColumn + SQL_DECIMAL, // expectedDataType + std::wstring(L"DECIMAL"), // expectedTypeName + 38, // expectedColumnSize + 19, // expectedBufferLength + 0, // expectedDecimalDigits + 10, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_DECIMAL, // expectedSqlDataType + NULL, // expectedDateTimeSub + 2, // expectedOctetCharLength + 3, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check 4th Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkRemoteSQLColumns(this->stmt, + std::wstring(L"$scratch"), // expectedSchema + std::wstring(L"ODBCTest"), // expectedTable + std::wstring(L"float_max"), // expectedColumn + SQL_FLOAT, // expectedDataType + std::wstring(L"FLOAT"), // expectedTypeName + 24, // expectedColumnSize (precision bits from IEEE 754) + 8, // expectedBufferLength + 0, // expectedDecimalDigits + 2, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_FLOAT, // expectedSqlDataType + NULL, // expectedDateTimeSub + 8, // expectedOctetCharLength + 4, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check 5th Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkRemoteSQLColumns(this->stmt, + std::wstring(L"$scratch"), // expectedSchema + std::wstring(L"ODBCTest"), // expectedTable + std::wstring(L"double_max"), // expectedColumn + SQL_DOUBLE, // expectedDataType + std::wstring(L"DOUBLE"), // expectedTypeName + 53, // expectedColumnSize (precision bits from IEEE 754) + 8, // expectedBufferLength + 0, // expectedDecimalDigits + 2, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_DOUBLE, // expectedSqlDataType + NULL, // expectedDateTimeSub + 8, // expectedOctetCharLength + 5, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check 6th Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkRemoteSQLColumns(this->stmt, + std::wstring(L"$scratch"), // expectedSchema + std::wstring(L"ODBCTest"), // expectedTable + std::wstring(L"bit_true"), // expectedColumn + SQL_BIT, // expectedDataType + std::wstring(L"BOOLEAN"), // expectedTypeName + 0, // expectedColumnSize (limitation: remote server remote server + // returns 0, should be 1) + 1, // expectedBufferLength + 0, // expectedDecimalDigits + 0, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_BIT, // expectedSqlDataType + NULL, // expectedDateTimeSub + 1, // expectedOctetCharLength + 6, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // ODBC ver 3 returns SQL_TYPE_DATE, SQL_TYPE_TIME, and SQL_TYPE_TIMESTAMP in the + // DATA_TYPE field + + // Check 7th Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkRemoteSQLColumns( + this->stmt, + std::wstring(L"$scratch"), // expectedSchema + std::wstring(L"ODBCTest"), // expectedTable + std::wstring(L"date_max"), // expectedColumn + SQL_TYPE_DATE, // expectedDataType + std::wstring(L"DATE"), // expectedTypeName + 0, // expectedColumnSize (limitation: remote server returns 0, should be 10) + 10, // expectedBufferLength + 0, // expectedDecimalDigits + 0, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_DATETIME, // expectedSqlDataType + SQL_CODE_DATE, // expectedDateTimeSub + 6, // expectedOctetCharLength + 7, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check 8th Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkRemoteSQLColumns( + this->stmt, + std::wstring(L"$scratch"), // expectedSchema + std::wstring(L"ODBCTest"), // expectedTable + std::wstring(L"time_max"), // expectedColumn + SQL_TYPE_TIME, // expectedDataType + std::wstring(L"TIME"), // expectedTypeName + 3, // expectedColumnSize (limitation: should be 9+fractional digits) + 12, // expectedBufferLength + 0, // expectedDecimalDigits + 0, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_DATETIME, // expectedSqlDataType + SQL_CODE_TIME, // expectedDateTimeSub + 6, // expectedOctetCharLength + 8, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check 9th Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkRemoteSQLColumns( + this->stmt, + std::wstring(L"$scratch"), // expectedSchema + std::wstring(L"ODBCTest"), // expectedTable + std::wstring(L"timestamp_max"), // expectedColumn + SQL_TYPE_TIMESTAMP, // expectedDataType + std::wstring(L"TIMESTAMP"), // expectedTypeName + 3, // expectedColumnSize (limitation: should be 20+fractional digits) + 23, // expectedBufferLength + 0, // expectedDecimalDigits + 0, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_DATETIME, // expectedSqlDataType + SQL_CODE_TIMESTAMP, // expectedDateTimeSub + 16, // expectedOctetCharLength + 9, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // There is no more column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLColumnsAllTypesODBCVer2) { + // GH-47159: Return NUM_PREC_RADIX based on whether COLUMN_SIZE contains number of + // digits or bits + this->connect(SQL_OV_ODBC2); + + SQLWCHAR tablePattern[] = L"ODBCTest"; + SQLWCHAR columnPattern[] = L"%"; + + SQLRETURN ret = SQLColumns(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, tablePattern, + SQL_NTS, columnPattern, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check 1st Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkRemoteSQLColumns(this->stmt, + std::wstring(L"$scratch"), // expectedSchema + std::wstring(L"ODBCTest"), // expectedTable + std::wstring(L"sinteger_max"), // expectedColumn + SQL_INTEGER, // expectedDataType + std::wstring(L"INTEGER"), // expectedTypeName + 32, // expectedColumnSize (remote server returns number of bits) + 4, // expectedBufferLength + 0, // expectedDecimalDigits + 10, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_INTEGER, // expectedSqlDataType + NULL, // expectedDateTimeSub + 4, // expectedOctetCharLength + 1, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check 2nd Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkRemoteSQLColumns(this->stmt, + std::wstring(L"$scratch"), // expectedSchema + std::wstring(L"ODBCTest"), // expectedTable + std::wstring(L"sbigint_max"), // expectedColumn + SQL_BIGINT, // expectedDataType + std::wstring(L"BIGINT"), // expectedTypeName + 64, // expectedColumnSize (remote server returns number of bits) + 8, // expectedBufferLength + 0, // expectedDecimalDigits + 10, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_BIGINT, // expectedSqlDataType + NULL, // expectedDateTimeSub + 8, // expectedOctetCharLength + 2, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check 3rd Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkRemoteSQLColumns(this->stmt, + std::wstring(L"$scratch"), // expectedSchema + std::wstring(L"ODBCTest"), // expectedTable + std::wstring(L"decimal_positive"), // expectedColumn + SQL_DECIMAL, // expectedDataType + std::wstring(L"DECIMAL"), // expectedTypeName + 38, // expectedColumnSize + 19, // expectedBufferLength + 0, // expectedDecimalDigits + 10, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_DECIMAL, // expectedSqlDataType + NULL, // expectedDateTimeSub + 2, // expectedOctetCharLength + 3, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check 4th Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkRemoteSQLColumns(this->stmt, + std::wstring(L"$scratch"), // expectedSchema + std::wstring(L"ODBCTest"), // expectedTable + std::wstring(L"float_max"), // expectedColumn + SQL_FLOAT, // expectedDataType + std::wstring(L"FLOAT"), // expectedTypeName + 24, // expectedColumnSize (precision bits from IEEE 754) + 8, // expectedBufferLength + 0, // expectedDecimalDigits + 2, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_FLOAT, // expectedSqlDataType + NULL, // expectedDateTimeSub + 8, // expectedOctetCharLength + 4, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check 5th Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkRemoteSQLColumns(this->stmt, + std::wstring(L"$scratch"), // expectedSchema + std::wstring(L"ODBCTest"), // expectedTable + std::wstring(L"double_max"), // expectedColumn + SQL_DOUBLE, // expectedDataType + std::wstring(L"DOUBLE"), // expectedTypeName + 53, // expectedColumnSize (precision bits from IEEE 754) + 8, // expectedBufferLength + 0, // expectedDecimalDigits + 2, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_DOUBLE, // expectedSqlDataType + NULL, // expectedDateTimeSub + 8, // expectedOctetCharLength + 5, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check 6th Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkRemoteSQLColumns(this->stmt, + std::wstring(L"$scratch"), // expectedSchema + std::wstring(L"ODBCTest"), // expectedTable + std::wstring(L"bit_true"), // expectedColumn + SQL_BIT, // expectedDataType + std::wstring(L"BOOLEAN"), // expectedTypeName + 0, // expectedColumnSize (limitation: remote server remote server + // returns 0, should be 1) + 1, // expectedBufferLength + 0, // expectedDecimalDigits + 0, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_BIT, // expectedSqlDataType + NULL, // expectedDateTimeSub + 1, // expectedOctetCharLength + 6, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // ODBC ver 2 returns SQL_DATE, SQL_TIME, and SQL_TIMESTAMP in the DATA_TYPE field + + // Check 7th Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkRemoteSQLColumns( + this->stmt, + std::wstring(L"$scratch"), // expectedSchema + std::wstring(L"ODBCTest"), // expectedTable + std::wstring(L"date_max"), // expectedColumn + SQL_DATE, // expectedDataType + std::wstring(L"DATE"), // expectedTypeName + 0, // expectedColumnSize (limitation: remote server returns 0, should be 10) + 10, // expectedBufferLength + 0, // expectedDecimalDigits + 0, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_DATETIME, // expectedSqlDataType + SQL_CODE_DATE, // expectedDateTimeSub + 6, // expectedOctetCharLength + 7, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check 8th Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkRemoteSQLColumns( + this->stmt, + std::wstring(L"$scratch"), // expectedSchema + std::wstring(L"ODBCTest"), // expectedTable + std::wstring(L"time_max"), // expectedColumn + SQL_TIME, // expectedDataType + std::wstring(L"TIME"), // expectedTypeName + 3, // expectedColumnSize (limitation: should be 9+fractional digits) + 12, // expectedBufferLength + 0, // expectedDecimalDigits + 0, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_DATETIME, // expectedSqlDataType + SQL_CODE_TIME, // expectedDateTimeSub + 6, // expectedOctetCharLength + 8, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check 9th Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkRemoteSQLColumns( + this->stmt, + std::wstring(L"$scratch"), // expectedSchema + std::wstring(L"ODBCTest"), // expectedTable + std::wstring(L"timestamp_max"), // expectedColumn + SQL_TIMESTAMP, // expectedDataType + std::wstring(L"TIMESTAMP"), // expectedTypeName + 3, // expectedColumnSize (limitation: should be 20+fractional digits) + 23, // expectedBufferLength + 0, // expectedDecimalDigits + 0, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_DATETIME, // expectedSqlDataType + SQL_CODE_TIMESTAMP, // expectedDateTimeSub + 16, // expectedOctetCharLength + 9, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // There is no more column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLColumnsColumnPattern) { + // Checks filtering table with column name pattern. + // Only check table and column name + this->connect(); + + SQLWCHAR tablePattern[] = L"%"; + SQLWCHAR columnPattern[] = L"id"; + + SQLRETURN ret = SQLColumns(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, tablePattern, + SQL_NTS, columnPattern, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check 1st Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkMockSQLColumns(this->stmt, + std::wstring(L"main"), // expectedCatalog + std::wstring(L"foreignTable"), // expectedTable + std::wstring(L"id"), // expectedColumn + SQL_BIGINT, // expectedDataType + std::wstring(L"BIGINT"), // expectedTypeName + 10, // expectedColumnSize (mock returns 10 instead of 19) + 8, // expectedBufferLength + 15, // expectedDecimalDigits (mock returns 15 instead of 0) + 10, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_BIGINT, // expectedSqlDataType + NULL, // expectedDateTimeSub + 8, // expectedOctetCharLength + 1, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // Check 2nd Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkMockSQLColumns(this->stmt, + std::wstring(L"main"), // expectedCatalog + std::wstring(L"intTable"), // expectedTable + std::wstring(L"id"), // expectedColumn + SQL_BIGINT, // expectedDataType + std::wstring(L"BIGINT"), // expectedTypeName + 10, // expectedColumnSize (mock returns 10 instead of 19) + 8, // expectedBufferLength + 15, // expectedDecimalDigits (mock returns 15 instead of 0) + 10, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_BIGINT, // expectedSqlDataType + NULL, // expectedDateTimeSub + 8, // expectedOctetCharLength + 1, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // There is no more column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLColumnsTableColumnPattern) { + // Checks filtering table with table and column name pattern. + // Only check table and column name + this->connect(); + + SQLWCHAR tablePattern[] = L"foreignTable"; + SQLWCHAR columnPattern[] = L"id"; + + SQLRETURN ret = SQLColumns(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, tablePattern, + SQL_NTS, columnPattern, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check 1st Column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkMockSQLColumns(this->stmt, + std::wstring(L"main"), // expectedCatalog + std::wstring(L"foreignTable"), // expectedTable + std::wstring(L"id"), // expectedColumn + SQL_BIGINT, // expectedDataType + std::wstring(L"BIGINT"), // expectedTypeName + 10, // expectedColumnSize (mock returns 10 instead of 19) + 8, // expectedBufferLength + 15, // expectedDecimalDigits (mock returns 15 instead of 0) + 10, // expectedNumPrecRadix + SQL_NULLABLE, // expectedNullable + SQL_BIGINT, // expectedSqlDataType + NULL, // expectedDateTimeSub + 8, // expectedOctetCharLength + 1, // expectedOrdinalPosition + std::wstring(L"YES")); // expectedIsNullable + + // There is no more column + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLColumnsInvalidTablePattern) { + this->connect(); + + SQLWCHAR tablePattern[] = L"non-existent-table"; + SQLWCHAR columnPattern[] = L"%"; + + SQLRETURN ret = SQLColumns(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, tablePattern, + SQL_NTS, columnPattern, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // There is no column from filter + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, SQLColAttributeTestInputData) { + this->connect(); + + std::wstring wsql = L"SELECT 1 as col1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + SQLUSMALLINT idx = 1; + std::vector characterAttr(ODBC_BUFFER_SIZE); + SQLSMALLINT characterAttrLen = 0; + SQLLEN numericAttr = 0; + + // All character values populated + ret = SQLColAttribute(this->stmt, idx, SQL_DESC_NAME, &characterAttr[0], + (SQLSMALLINT)characterAttr.size(), &characterAttrLen, 0); + EXPECT_EQ(ret, SQL_SUCCESS); + + // All numeric values populated + ret = SQLColAttribute(this->stmt, idx, SQL_DESC_COUNT, 0, 0, 0, &numericAttr); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Pass null values, driver should not throw error + ret = SQLColAttribute(this->stmt, idx, SQL_COLUMN_TABLE_NAME, 0, 0, 0, 0); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLColAttribute(this->stmt, idx, SQL_DESC_COUNT, 0, 0, 0, 0); + EXPECT_EQ(ret, SQL_SUCCESS); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, SQLColAttributeGetCharacterLen) { + this->connect(); + + std::wstring wsql = L"SELECT 1 as col1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + SQLSMALLINT characterAttrLen = 0; + + // Check length of character attribute + ret = SQLColAttribute(this->stmt, 1, SQL_DESC_BASE_COLUMN_NAME, 0, 0, &characterAttrLen, + 0); + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(characterAttrLen, 4 * ODBC::GetSqlWCharSize()); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, SQLColAttributeInvalidFieldId) { + this->connect(); + + std::wstring wsql = L"SELECT 1 as col1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + SQLUSMALLINT invalidFieldId = -100; + SQLUSMALLINT idx = 1; + std::vector characterAttr(ODBC_BUFFER_SIZE); + SQLSMALLINT characterAttrLen = 0; + SQLLEN numericAttr = 0; + + ret = SQLColAttribute(this->stmt, idx, invalidFieldId, &characterAttr[0], + (SQLSMALLINT)characterAttr.size(), &characterAttrLen, 0); + EXPECT_EQ(ret, SQL_ERROR); + // Verify invalid descriptor field identifier error state is returned + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_HY091); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, SQLColAttributeInvalidColId) { + this->connect(); + + std::wstring wsql = L"SELECT 1 as col1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + SQLUSMALLINT invalidColId = 2; + std::vector characterAttr(ODBC_BUFFER_SIZE); + SQLSMALLINT characterAttrLen = 0; + SQLLEN numericAttr = 0; + + ret = SQLColAttribute(this->stmt, invalidColId, SQL_DESC_BASE_COLUMN_NAME, + &characterAttr[0], (SQLSMALLINT)characterAttr.size(), + &characterAttrLen, 0); + EXPECT_EQ(ret, SQL_ERROR); + // Verify invalid descriptor index error state is returned + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_07009); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLColAttributeAllTypes) { + this->connect(); + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLColAttribute(this->stmt, 1, + std::wstring(L"bigint_col"), // expectedColmnName + SQL_BIGINT, // expectedDataType + SQL_BIGINT, // expectedConciseType + 20, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 8, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 8, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 10, // expectedNumPrecRadix + 8, // expectedOctetLength + SQL_PRED_NONE, // expectedSearchable + SQL_FALSE); // expectedUnsignedColumn + + checkSQLColAttribute(this->stmt, 2, + std::wstring(L"char_col"), // expectedColmnName + SQL_WVARCHAR, // expectedDataType + SQL_WVARCHAR, // expectedConciseType + 0, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 0, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 0, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 0, // expectedNumPrecRadix + 0, // expectedOctetLength + SQL_PRED_NONE, // expectedSearchable + SQL_TRUE); // expectedUnsignedColumn + + checkSQLColAttribute(this->stmt, 3, + std::wstring(L"varbinary_col"), // expectedColmnName + SQL_BINARY, // expectedDataType + SQL_BINARY, // expectedConciseType + 0, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 0, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 0, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 0, // expectedNumPrecRadix + 0, // expectedOctetLength + SQL_PRED_NONE, // expectedSearchable + SQL_TRUE); // expectedUnsignedColumn + + checkSQLColAttribute(this->stmt, 4, + std::wstring(L"double_col"), // expectedColmnName + SQL_DOUBLE, // expectedDataType + SQL_DOUBLE, // expectedConciseType + 24, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 8, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 8, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 2, // expectedNumPrecRadix + 8, // expectedOctetLength + SQL_PRED_NONE, // expectedSearchable + SQL_FALSE); // expectedUnsignedColumn + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLColAttributesAllTypesODBCVer2) { + // Tests ODBC 2.0 API SQLColAttributes + this->connect(SQL_OV_ODBC2); + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + checkSQLColAttributes(this->stmt, 1, + std::wstring(L"bigint_col"), // expectedColmnName + SQL_BIGINT, // expectedDataType + 20, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 8, // expectedLength + 8, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + SQL_PRED_NONE, // expectedSearchable + SQL_FALSE); // expectedUnsignedColumn + + checkSQLColAttributes(this->stmt, 2, + std::wstring(L"char_col"), // expectedColmnName + SQL_WVARCHAR, // expectedDataType + 0, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 0, // expectedLength + 0, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + SQL_PRED_NONE, // expectedSearchable + SQL_TRUE); // expectedUnsignedColumn + + checkSQLColAttributes(this->stmt, 3, + std::wstring(L"varbinary_col"), // expectedColmnName + SQL_BINARY, // expectedDataType + 0, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 0, // expectedLength + 0, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + SQL_PRED_NONE, // expectedSearchable + SQL_TRUE); // expectedUnsignedColumn + + checkSQLColAttributes(this->stmt, 4, + std::wstring(L"double_col"), // expectedColmnName + SQL_DOUBLE, // expectedDataType + 24, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 8, // expectedLength + 8, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + SQL_PRED_NONE, // expectedSearchable + SQL_FALSE); // expectedUnsignedColumn + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLColAttributeAllTypes) { + // Test assumes there is a table $scratch.ODBCTest in remote server + this->connect(); + + std::wstring wsql = L"SELECT * from $scratch.ODBCTest;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLColAttribute(this->stmt, 1, + std::wstring(L"sinteger_max"), // expectedColmnName + SQL_INTEGER, // expectedDataType + SQL_INTEGER, // expectedConciseType + 11, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 4, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 4, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 10, // expectedNumPrecRadix + 4, // expectedOctetLength + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE); // expectedUnsignedColumn + + checkSQLColAttribute(this->stmt, 2, + std::wstring(L"sbigint_max"), // expectedColmnName + SQL_BIGINT, // expectedDataType + SQL_BIGINT, // expectedConciseType + 20, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 8, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 8, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 10, // expectedNumPrecRadix + 8, // expectedOctetLength + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE); // expectedUnsignedColumn + + checkSQLColAttribute(this->stmt, 3, + std::wstring(L"decimal_positive"), // expectedColmnName + SQL_DECIMAL, // expectedDataType + SQL_DECIMAL, // expectedConciseType + 40, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 19, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 19, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 10, // expectedNumPrecRadix + 40, // expectedOctetLength + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE); // expectedUnsignedColumn + + checkSQLColAttribute(this->stmt, 4, + std::wstring(L"float_max"), // expectedColmnName + SQL_FLOAT, // expectedDataType + SQL_FLOAT, // expectedConciseType + 24, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 8, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 8, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 2, // expectedNumPrecRadix + 8, // expectedOctetLength + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE); // expectedUnsignedColumn + + checkSQLColAttribute(this->stmt, 5, + std::wstring(L"double_max"), // expectedColmnName + SQL_DOUBLE, // expectedDataType + SQL_DOUBLE, // expectedConciseType + 24, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 8, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 8, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 2, // expectedNumPrecRadix + 8, // expectedOctetLength + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE); // expectedUnsignedColumn + + checkSQLColAttribute(this->stmt, 6, + std::wstring(L"bit_true"), // expectedColmnName + SQL_BIT, // expectedDataType + SQL_BIT, // expectedConciseType + 1, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 1, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 1, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 0, // expectedNumPrecRadix + 1, // expectedOctetLength + SQL_SEARCHABLE, // expectedSearchable + SQL_TRUE); // expectedUnsignedColumn + + checkSQLColAttribute(this->stmt, 7, + std::wstring(L"date_max"), // expectedColmnName + SQL_DATETIME, // expectedDataType + SQL_TYPE_DATE, // expectedConciseType + 10, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 10, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 10, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 0, // expectedNumPrecRadix + 6, // expectedOctetLength + SQL_SEARCHABLE, // expectedSearchable + SQL_TRUE); // expectedUnsignedColumn + + checkSQLColAttribute(this->stmt, 8, + std::wstring(L"time_max"), // expectedColmnName + SQL_DATETIME, // expectedDataType + SQL_TYPE_TIME, // expectedConciseType + 12, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 12, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 12, // expectedColumnSize + 3, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 0, // expectedNumPrecRadix + 6, // expectedOctetLength + SQL_SEARCHABLE, // expectedSearchable + SQL_TRUE); // expectedUnsignedColumn + + checkSQLColAttribute(this->stmt, 9, + std::wstring(L"timestamp_max"), // expectedColmnName + SQL_DATETIME, // expectedDataType + SQL_TYPE_TIMESTAMP, // expectedConciseType + 23, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 23, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 23, // expectedColumnSize + 3, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 0, // expectedNumPrecRadix + 16, // expectedOctetLength + SQL_SEARCHABLE, // expectedSearchable + SQL_TRUE); // expectedUnsignedColumn + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLColAttributeAllTypesODBCVer2) { + // Test assumes there is a table $scratch.ODBCTest in remote server + this->connect(SQL_OV_ODBC2); + + std::wstring wsql = L"SELECT * from $scratch.ODBCTest;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLColAttribute(this->stmt, 1, + std::wstring(L"sinteger_max"), // expectedColmnName + SQL_INTEGER, // expectedDataType + SQL_INTEGER, // expectedConciseType + 11, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 4, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 4, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 10, // expectedNumPrecRadix + 4, // expectedOctetLength + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE); // expectedUnsignedColumn + + checkSQLColAttribute(this->stmt, 2, + std::wstring(L"sbigint_max"), // expectedColmnName + SQL_BIGINT, // expectedDataType + SQL_BIGINT, // expectedConciseType + 20, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 8, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 8, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 10, // expectedNumPrecRadix + 8, // expectedOctetLength + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE); // expectedUnsignedColumn + + checkSQLColAttribute(this->stmt, 3, + std::wstring(L"decimal_positive"), // expectedColmnName + SQL_DECIMAL, // expectedDataType + SQL_DECIMAL, // expectedConciseType + 40, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 19, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 19, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 10, // expectedNumPrecRadix + 40, // expectedOctetLength + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE); // expectedUnsignedColumn + + checkSQLColAttribute(this->stmt, 4, + std::wstring(L"float_max"), // expectedColmnName + SQL_FLOAT, // expectedDataType + SQL_FLOAT, // expectedConciseType + 24, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 8, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 8, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 2, // expectedNumPrecRadix + 8, // expectedOctetLength + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE); // expectedUnsignedColumn + + checkSQLColAttribute(this->stmt, 5, + std::wstring(L"double_max"), // expectedColmnName + SQL_DOUBLE, // expectedDataType + SQL_DOUBLE, // expectedConciseType + 24, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 8, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 8, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 2, // expectedNumPrecRadix + 8, // expectedOctetLength + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE); // expectedUnsignedColumn + + checkSQLColAttribute(this->stmt, 6, + std::wstring(L"bit_true"), // expectedColmnName + SQL_BIT, // expectedDataType + SQL_BIT, // expectedConciseType + 1, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 1, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 1, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 0, // expectedNumPrecRadix + 1, // expectedOctetLength + SQL_SEARCHABLE, // expectedSearchable + SQL_TRUE); // expectedUnsignedColumn + + checkSQLColAttribute(this->stmt, 7, + std::wstring(L"date_max"), // expectedColmnName + SQL_DATETIME, // expectedDataType + SQL_DATE, // expectedConciseType + 10, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 10, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 10, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 0, // expectedNumPrecRadix + 6, // expectedOctetLength + SQL_SEARCHABLE, // expectedSearchable + SQL_TRUE); // expectedUnsignedColumn + + checkSQLColAttribute(this->stmt, 8, + std::wstring(L"time_max"), // expectedColmnName + SQL_DATETIME, // expectedDataType + SQL_TIME, // expectedConciseType + 12, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 12, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 12, // expectedColumnSize + 3, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 0, // expectedNumPrecRadix + 6, // expectedOctetLength + SQL_SEARCHABLE, // expectedSearchable + SQL_TRUE); // expectedUnsignedColumn + + checkSQLColAttribute(this->stmt, 9, + std::wstring(L"timestamp_max"), // expectedColmnName + SQL_DATETIME, // expectedDataType + SQL_TIMESTAMP, // expectedConciseType + 23, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 23, // expectedLength + std::wstring(L""), // expectedLiteralPrefix + std::wstring(L""), // expectedLiteralSuffix + 23, // expectedColumnSize + 3, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + 0, // expectedNumPrecRadix + 16, // expectedOctetLength + SQL_SEARCHABLE, // expectedSearchable + SQL_TRUE); // expectedUnsignedColumn + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLColAttributesAllTypesODBCVer2) { + // Tests ODBC 2.0 API SQLColAttributes + // Test assumes there is a table $scratch.ODBCTest in remote server + this->connect(SQL_OV_ODBC2); + + std::wstring wsql = L"SELECT * from $scratch.ODBCTest;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLColAttributes(this->stmt, 1, + std::wstring(L"sinteger_max"), // expectedColmnName + SQL_INTEGER, // expectedDataType + 11, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 4, // expectedLength + 4, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE); // expectedUnsignedColumn + + checkSQLColAttributes(this->stmt, 2, + std::wstring(L"sbigint_max"), // expectedColmnName + SQL_BIGINT, // expectedDataType + 20, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 8, // expectedLength + 8, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE); // expectedUnsignedColumn + + checkSQLColAttributes(this->stmt, 3, + std::wstring(L"decimal_positive"), // expectedColmnName + SQL_DECIMAL, // expectedDataType + 40, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 19, // expectedLength + 19, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE); // expectedUnsignedColumn + + checkSQLColAttributes(this->stmt, 4, + std::wstring(L"float_max"), // expectedColmnName + SQL_FLOAT, // expectedDataType + 24, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 8, // expectedLength + 8, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE); // expectedUnsignedColumn + + checkSQLColAttributes(this->stmt, 5, + std::wstring(L"double_max"), // expectedColmnName + SQL_DOUBLE, // expectedDataType + 24, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 8, // expectedLength + 8, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE); // expectedUnsignedColumn + + checkSQLColAttributes(this->stmt, 6, + std::wstring(L"bit_true"), // expectedColmnName + SQL_BIT, // expectedDataType + 1, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 1, // expectedLength + 1, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + SQL_SEARCHABLE, // expectedSearchable + SQL_TRUE); // expectedUnsignedColumn + + checkSQLColAttributes(this->stmt, 7, + std::wstring(L"date_max"), // expectedColmnName + SQL_DATE, // expectedDataType + 10, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 10, // expectedLength + 10, // expectedColumnSize + 0, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + SQL_SEARCHABLE, // expectedSearchable + SQL_TRUE); // expectedUnsignedColumn + + checkSQLColAttributes(this->stmt, 8, + std::wstring(L"time_max"), // expectedColmnName + SQL_TIME, // expectedDataType + 12, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 12, // expectedLength + 12, // expectedColumnSize + 3, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + SQL_SEARCHABLE, // expectedSearchable + SQL_TRUE); // expectedUnsignedColumn + + checkSQLColAttributes(this->stmt, 9, + std::wstring(L"timestamp_max"), // expectedColmnName + SQL_TIMESTAMP, // expectedDataType + 23, // expectedDisplaySize + SQL_FALSE, // expectedPrecScale + 23, // expectedLength + 23, // expectedColumnSize + 3, // expectedColumnScale + SQL_NULLABLE, // expectedColumnNullability + SQL_SEARCHABLE, // expectedSearchable + SQL_TRUE); // expectedUnsignedColumn + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLColAttributeCaseSensitive) { + // Arrow limitation: returns SQL_FALSE for case sensitive column + this->connect(); + + std::wstring wsql = this->getQueryAllDataTypes(); + // Int column + checkSQLColAttributeNumeric(this->stmt, wsql, 1, SQL_DESC_CASE_SENSITIVE, SQL_FALSE); + SQLFreeStmt(this->stmt, SQL_CLOSE); + // Varchar column + checkSQLColAttributeNumeric(this->stmt, wsql, 28, SQL_DESC_CASE_SENSITIVE, SQL_FALSE); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLColAttributesCaseSensitive) { + // Arrow limitation: returns SQL_FALSE for case sensitive column + // Tests ODBC 2.0 API SQLColAttributes + this->connect(SQL_OV_ODBC2); + + std::wstring wsql = this->getQueryAllDataTypes(); + // Int column + checkSQLColAttributesNumeric(this->stmt, wsql, 1, SQL_COLUMN_CASE_SENSITIVE, SQL_FALSE); + SQLFreeStmt(this->stmt, SQL_CLOSE); + // Varchar column + checkSQLColAttributesNumeric(this->stmt, wsql, 28, SQL_COLUMN_CASE_SENSITIVE, + SQL_FALSE); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLColAttributeUniqueValue) { + // Mock server limitation: returns false for auto-increment column + this->connect(); + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + checkSQLColAttributeNumeric(this->stmt, wsql, 1, SQL_DESC_AUTO_UNIQUE_VALUE, SQL_FALSE); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLColAttributesAutoIncrement) { + // Tests ODBC 2.0 API SQLColAttributes + // Mock server limitation: returns false for auto-increment column + this->connect(SQL_OV_ODBC2); + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + checkSQLColAttributeNumeric(this->stmt, wsql, 1, SQL_COLUMN_AUTO_INCREMENT, SQL_FALSE); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLColAttributeBaseTableName) { + this->connect(); + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + checkSQLColAttributeString(this->stmt, wsql, 1, SQL_DESC_BASE_TABLE_NAME, + std::wstring(L"AllTypesTable")); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLColAttributesTableName) { + // Tests ODBC 2.0 API SQLColAttributes + this->connect(SQL_OV_ODBC2); + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + checkSQLColAttributesString(this->stmt, wsql, 1, SQL_COLUMN_TABLE_NAME, + std::wstring(L"AllTypesTable")); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLColAttributeCatalogName) { + // Mock server limitattion: mock doesn't return catalog for result metadata, + // and the defautl catalog should be 'main' + this->connect(); + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + checkSQLColAttributeString(this->stmt, wsql, 1, SQL_DESC_CATALOG_NAME, + std::wstring(L"")); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLColAttributeCatalogName) { + // Remote server does not have catalogs + this->connect(); + + std::wstring wsql = L"SELECT * from $scratch.ODBCTest;"; + checkSQLColAttributeString(this->stmt, wsql, 1, SQL_DESC_CATALOG_NAME, + std::wstring(L"")); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLColAttributesQualifierName) { + // Mock server limitattion: mock doesn't return catalog for result metadata, + // and the defautl catalog should be 'main' + // Tests ODBC 2.0 API SQLColAttributes + this->connect(SQL_OV_ODBC2); + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + checkSQLColAttributeString(this->stmt, wsql, 1, SQL_COLUMN_QUALIFIER_NAME, + std::wstring(L"")); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLColAttributesQualifierName) { + // Remote server does not have catalogs + // Tests ODBC 2.0 API SQLColAttributes + this->connect(SQL_OV_ODBC2); + + std::wstring wsql = L"SELECT * from $scratch.ODBCTest;"; + checkSQLColAttributeString(this->stmt, wsql, 1, SQL_COLUMN_QUALIFIER_NAME, + std::wstring(L"")); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLColAttributeCount) { + this->connect(); + + std::wstring wsql = this->getQueryAllDataTypes(); + // Pass 0 as column number, driver should ignore it + checkSQLColAttributeNumeric(this->stmt, wsql, 0, SQL_DESC_COUNT, 32); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLColAttributeLocalTypeName) { + this->connect(); + + std::wstring wsql = this->getQueryAllDataTypes(); + // Mock server doesn't have local type name + checkSQLColAttributeString(this->stmt, wsql, 1, SQL_DESC_LOCAL_TYPE_NAME, + std::wstring(L"")); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLColAttributeLocalTypeName) { + this->connect(); + + std::wstring wsql = this->getQueryAllDataTypes(); + checkSQLColAttributeString(this->stmt, wsql, 1, SQL_DESC_LOCAL_TYPE_NAME, + std::wstring(L"INTEGER")); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLColAttributeSchemaName) { + this->connect(); + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + // Mock server doesn't have schemas + checkSQLColAttributeString(this->stmt, wsql, 1, SQL_DESC_SCHEMA_NAME, + std::wstring(L"")); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLColAttributeSchemaName) { + // Test assumes there is a table $scratch.ODBCTest in remote server + this->connect(); + + std::wstring wsql = L"SELECT * from $scratch.ODBCTest;"; + // Remote server limitation: doesn't return schema name, expected schema name is + // $scratch + checkSQLColAttributeString(this->stmt, wsql, 1, SQL_DESC_SCHEMA_NAME, + std::wstring(L"")); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLColAttributesOwnerName) { + // Tests ODBC 2.0 API SQLColAttributes + this->connect(SQL_OV_ODBC2); + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + // Mock server doesn't have schemas + checkSQLColAttributesString(this->stmt, wsql, 1, SQL_COLUMN_OWNER_NAME, + std::wstring(L"")); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLColAttributesOwnerName) { + // Test assumes there is a table $scratch.ODBCTest in remote server + // Tests ODBC 2.0 API SQLColAttributes + this->connect(SQL_OV_ODBC2); + + std::wstring wsql = L"SELECT * from $scratch.ODBCTest;"; + // Remote server limitation: doesn't return schema name, expected schema name is + // $scratch + checkSQLColAttributesString(this->stmt, wsql, 1, SQL_COLUMN_OWNER_NAME, + std::wstring(L"")); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLColAttributeTableName) { + this->connect(); + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + checkSQLColAttributeString(this->stmt, wsql, 1, SQL_DESC_TABLE_NAME, + std::wstring(L"AllTypesTable")); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLColAttributeTypeName) { + this->connect(); + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + // Mock server doesn't return data source-dependent data type name + checkSQLColAttributeString(this->stmt, wsql, 1, SQL_DESC_TYPE_NAME, std::wstring(L"")); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLColAttributeTypeName) { + this->connect(); + + std::wstring wsql = L"SELECT * from $scratch.ODBCTest;"; + checkSQLColAttributeString(this->stmt, wsql, 1, SQL_DESC_TYPE_NAME, + std::wstring(L"INTEGER")); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLColAttributesTypeName) { + // Tests ODBC 2.0 API SQLColAttributes + this->connect(SQL_OV_ODBC2); + this->CreateTableAllDataType(); + + std::wstring wsql = L"SELECT * from AllTypesTable;"; + // Mock server doesn't return data source-dependent data type name + checkSQLColAttributesString(this->stmt, wsql, 1, SQL_COLUMN_TYPE_NAME, + std::wstring(L"")); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLColAttributesTypeName) { + // Tests ODBC 2.0 API SQLColAttributes + this->connect(SQL_OV_ODBC2); + + std::wstring wsql = L"SELECT * from $scratch.ODBCTest;"; + checkSQLColAttributesString(this->stmt, wsql, 1, SQL_COLUMN_TYPE_NAME, + std::wstring(L"INTEGER")); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLColAttributeUnnamed) { + this->connect(); + + std::wstring wsql = this->getQueryAllDataTypes(); + checkSQLColAttributeNumeric(this->stmt, wsql, 1, SQL_DESC_UNNAMED, SQL_NAMED); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLColAttributeUpdatable) { + this->connect(); + + std::wstring wsql = this->getQueryAllDataTypes(); + // Mock server and remote server do not return updatable information + checkSQLColAttributeNumeric(this->stmt, wsql, 1, SQL_DESC_UPDATABLE, + SQL_ATTR_READWRITE_UNKNOWN); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLColAttributesUpdatable) { + // Tests ODBC 2.0 API SQLColAttributes + this->connect(SQL_OV_ODBC2); + + std::wstring wsql = this->getQueryAllDataTypes(); + // Mock server and remote server do not return updatable information + checkSQLColAttributesNumeric(this->stmt, wsql, 1, SQL_COLUMN_UPDATABLE, + SQL_ATTR_READWRITE_UNKNOWN); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, SQLDescribeColValidateInput) { + this->connect(); + this->CreateTestTables(); + + SQLSMALLINT columnCount = 0; + SQLSMALLINT expectedValue = 3; + SQLWCHAR sqlQuery[] = L"SELECT * FROM TestTable LIMIT 1;"; + SQLINTEGER queryLength = static_cast(wcslen(sqlQuery)); + + SQLUSMALLINT bookmarkColumn = 0; + SQLUSMALLINT validColumn = 1; + SQLUSMALLINT outOfRangeColumn = 4; + SQLUSMALLINT negativeColumn = -1; + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT dataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + + SQLRETURN ret = SQLExecDirect(this->stmt, sqlQuery, queryLength); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Invalid descriptor index - Bookmarks are not supported + ret = SQLDescribeCol(this->stmt, bookmarkColumn, columnName, bufCharLen, &nameLength, + &dataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_07009); + + // Invalid descriptor index - index out of range + ret = SQLDescribeCol(this->stmt, outOfRangeColumn, columnName, bufCharLen, &nameLength, + &dataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_07009); + + // Invalid descriptor index - index out of range + ret = SQLDescribeCol(this->stmt, negativeColumn, columnName, bufCharLen, &nameLength, + &dataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_07009); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, SQLDescribeColQueryAllDataTypesMetadata) { + // Mock server has a limitation where only SQL_WVARCHAR column type values are returned + // from SELECT AS queries + this->connect(); + + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT columnDataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + size_t columnIndex = 0; + + std::wstring wsql = this->getQueryAllDataTypes(); + std::vector sql0(wsql.begin(), wsql.end()); + + SQLWCHAR* columnNames[] = { + (SQLWCHAR*)L"stiny_int_min", (SQLWCHAR*)L"stiny_int_max", + (SQLWCHAR*)L"utiny_int_min", (SQLWCHAR*)L"utiny_int_max", + (SQLWCHAR*)L"ssmall_int_min", (SQLWCHAR*)L"ssmall_int_max", + (SQLWCHAR*)L"usmall_int_min", (SQLWCHAR*)L"usmall_int_max", + (SQLWCHAR*)L"sinteger_min", (SQLWCHAR*)L"sinteger_max", + (SQLWCHAR*)L"uinteger_min", (SQLWCHAR*)L"uinteger_max", + (SQLWCHAR*)L"sbigint_min", (SQLWCHAR*)L"sbigint_max", + (SQLWCHAR*)L"ubigint_min", (SQLWCHAR*)L"ubigint_max", + (SQLWCHAR*)L"decimal_negative", (SQLWCHAR*)L"decimal_positive", + (SQLWCHAR*)L"float_min", (SQLWCHAR*)L"float_max", + (SQLWCHAR*)L"double_min", (SQLWCHAR*)L"double_max", + (SQLWCHAR*)L"bit_false", (SQLWCHAR*)L"bit_true", + (SQLWCHAR*)L"c_char", (SQLWCHAR*)L"c_wchar", + (SQLWCHAR*)L"c_wvarchar", (SQLWCHAR*)L"c_varchar", + (SQLWCHAR*)L"date_min", (SQLWCHAR*)L"date_max", + (SQLWCHAR*)L"timestamp_min", (SQLWCHAR*)L"timestamp_max"}; + SQLSMALLINT columnDataTypes[] = { + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR}; + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(columnNames) / sizeof(*columnNames); ++i) { + columnIndex = i + 1; + + ret = SQLDescribeCol(this->stmt, columnIndex, columnName, bufCharLen, &nameLength, + &columnDataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(nameLength, wcslen(columnNames[i])); + + std::wstring returned(columnName, columnName + nameLength); + EXPECT_EQ(returned, columnNames[i]); + EXPECT_EQ(columnDataType, columnDataTypes[i]); + EXPECT_EQ(columnSize, 1024); + EXPECT_EQ(decimalDigits, 0); + EXPECT_EQ(nullable, SQL_NULLABLE); + + nameLength = 0; + columnDataType = 0; + columnSize = 0; + decimalDigits = 0; + nullable = 0; + } + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, SQLDescribeColQueryAllDataTypesMetadata) { + this->connect(); + + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT columnDataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + size_t columnIndex = 0; + + std::wstring wsql = this->getQueryAllDataTypes(); + std::vector sql0(wsql.begin(), wsql.end()); + + SQLWCHAR* columnNames[] = { + (SQLWCHAR*)L"stiny_int_min", (SQLWCHAR*)L"stiny_int_max", + (SQLWCHAR*)L"utiny_int_min", (SQLWCHAR*)L"utiny_int_max", + (SQLWCHAR*)L"ssmall_int_min", (SQLWCHAR*)L"ssmall_int_max", + (SQLWCHAR*)L"usmall_int_min", (SQLWCHAR*)L"usmall_int_max", + (SQLWCHAR*)L"sinteger_min", (SQLWCHAR*)L"sinteger_max", + (SQLWCHAR*)L"uinteger_min", (SQLWCHAR*)L"uinteger_max", + (SQLWCHAR*)L"sbigint_min", (SQLWCHAR*)L"sbigint_max", + (SQLWCHAR*)L"ubigint_min", (SQLWCHAR*)L"ubigint_max", + (SQLWCHAR*)L"decimal_negative", (SQLWCHAR*)L"decimal_positive", + (SQLWCHAR*)L"float_min", (SQLWCHAR*)L"float_max", + (SQLWCHAR*)L"double_min", (SQLWCHAR*)L"double_max", + (SQLWCHAR*)L"bit_false", (SQLWCHAR*)L"bit_true", + (SQLWCHAR*)L"c_char", (SQLWCHAR*)L"c_wchar", + (SQLWCHAR*)L"c_wvarchar", (SQLWCHAR*)L"c_varchar", + (SQLWCHAR*)L"date_min", (SQLWCHAR*)L"date_max", + (SQLWCHAR*)L"timestamp_min", (SQLWCHAR*)L"timestamp_max"}; + SQLSMALLINT columnDataTypes[] = { + SQL_INTEGER, SQL_INTEGER, SQL_INTEGER, SQL_INTEGER, SQL_INTEGER, + SQL_INTEGER, SQL_INTEGER, SQL_INTEGER, SQL_INTEGER, SQL_INTEGER, + SQL_BIGINT, SQL_BIGINT, SQL_BIGINT, SQL_BIGINT, SQL_BIGINT, + SQL_WVARCHAR, SQL_DECIMAL, SQL_DECIMAL, SQL_FLOAT, SQL_FLOAT, + SQL_DOUBLE, SQL_DOUBLE, SQL_BIT, SQL_BIT, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_TYPE_DATE, SQL_TYPE_DATE, + SQL_TYPE_TIMESTAMP, SQL_TYPE_TIMESTAMP}; + SQLULEN columnSizes[] = {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, + 8, 8, 8, 8, 65536, 19, 19, 8, 8, 8, 8, + 1, 1, 65536, 65536, 65536, 65536, 10, 10, 23, 23}; + SQLULEN columnDecimalDigits[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 10, 23, 23}; + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(columnNames) / sizeof(*columnNames); ++i) { + columnIndex = i + 1; + + ret = SQLDescribeCol(this->stmt, columnIndex, columnName, bufCharLen, &nameLength, + &columnDataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(nameLength, wcslen(columnNames[i])); + + std::wstring returned(columnName, columnName + nameLength); + EXPECT_EQ(returned, columnNames[i]); + EXPECT_EQ(columnDataType, columnDataTypes[i]); + EXPECT_EQ(columnSize, columnSizes[i]); + EXPECT_EQ(decimalDigits, columnDecimalDigits[i]); + EXPECT_EQ(nullable, SQL_NULLABLE); + + nameLength = 0; + columnDataType = 0; + columnSize = 0; + decimalDigits = 0; + nullable = 0; + } + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, SQLDescribeColODBCTestTableMetadata) { + // Test assumes there is a table $scratch.ODBCTest in remote server + this->connect(); + + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT columnDataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + size_t columnIndex = 0; + + SQLWCHAR sqlQuery[] = L"SELECT * from $scratch.ODBCTest LIMIT 1;"; + SQLINTEGER queryLength = static_cast(wcslen(sqlQuery)); + + SQLWCHAR* columnNames[] = {(SQLWCHAR*)L"sinteger_max", (SQLWCHAR*)L"sbigint_max", + (SQLWCHAR*)L"decimal_positive", (SQLWCHAR*)L"float_max", + (SQLWCHAR*)L"double_max", (SQLWCHAR*)L"bit_true", + (SQLWCHAR*)L"date_max", (SQLWCHAR*)L"time_max", + (SQLWCHAR*)L"timestamp_max"}; + SQLSMALLINT columnDataTypes[] = {SQL_INTEGER, SQL_BIGINT, SQL_DECIMAL, + SQL_FLOAT, SQL_DOUBLE, SQL_BIT, + SQL_TYPE_DATE, SQL_TYPE_TIME, SQL_TYPE_TIMESTAMP}; + SQLULEN columnSizes[] = {4, 8, 19, 8, 8, 1, 10, 12, 23}; + SQLULEN columnDecimalDigits[] = {0, 0, 0, 0, 0, 0, 10, 12, 23}; + + SQLRETURN ret = SQLExecDirect(this->stmt, sqlQuery, queryLength); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(columnNames) / sizeof(*columnNames); ++i) { + columnIndex = i + 1; + + ret = SQLDescribeCol(this->stmt, columnIndex, columnName, bufCharLen, &nameLength, + &columnDataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(nameLength, wcslen(columnNames[i])); + + std::wstring returned(columnName, columnName + nameLength); + EXPECT_EQ(returned, columnNames[i]); + EXPECT_EQ(columnDataType, columnDataTypes[i]); + EXPECT_EQ(columnSize, columnSizes[i]); + EXPECT_EQ(decimalDigits, columnDecimalDigits[i]); + EXPECT_EQ(nullable, SQL_NULLABLE); + + nameLength = 0; + columnDataType = 0; + columnSize = 0; + decimalDigits = 0; + nullable = 0; + } + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, SQLDescribeColODBCTestTableMetadataODBC2) { + // Test assumes there is a table $scratch.ODBCTest in remote server + this->connect(SQL_OV_ODBC2); + + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT columnDataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + size_t columnIndex = 0; + + SQLWCHAR sqlQuery[] = L"SELECT * from $scratch.ODBCTest LIMIT 1;"; + SQLINTEGER queryLength = static_cast(wcslen(sqlQuery)); + + SQLWCHAR* columnNames[] = {(SQLWCHAR*)L"sinteger_max", (SQLWCHAR*)L"sbigint_max", + (SQLWCHAR*)L"decimal_positive", (SQLWCHAR*)L"float_max", + (SQLWCHAR*)L"double_max", (SQLWCHAR*)L"bit_true", + (SQLWCHAR*)L"date_max", (SQLWCHAR*)L"time_max", + (SQLWCHAR*)L"timestamp_max"}; + SQLSMALLINT columnDataTypes[] = {SQL_INTEGER, SQL_BIGINT, SQL_DECIMAL, + SQL_FLOAT, SQL_DOUBLE, SQL_BIT, + SQL_DATE, SQL_TIME, SQL_TIMESTAMP}; + SQLULEN columnSizes[] = {4, 8, 19, 8, 8, 1, 10, 12, 23}; + SQLULEN columnDecimalDigits[] = {0, 0, 0, 0, 0, 0, 10, 12, 23}; + + SQLRETURN ret = SQLExecDirect(this->stmt, sqlQuery, queryLength); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(columnNames) / sizeof(*columnNames); ++i) { + columnIndex = i + 1; + + ret = SQLDescribeCol(this->stmt, columnIndex, columnName, bufCharLen, &nameLength, + &columnDataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(nameLength, wcslen(columnNames[i])); + + std::wstring returned(columnName, columnName + nameLength); + EXPECT_EQ(returned, columnNames[i]); + EXPECT_EQ(columnDataType, columnDataTypes[i]); + EXPECT_EQ(columnSize, columnSizes[i]); + EXPECT_EQ(decimalDigits, columnDecimalDigits[i]); + EXPECT_EQ(nullable, SQL_NULLABLE); + + nameLength = 0; + columnDataType = 0; + columnSize = 0; + decimalDigits = 0; + nullable = 0; + } + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, SQLDescribeColAllTypesTableMetadata) { + this->connect(); + this->CreateTableAllDataType(); + + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT columnDataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + size_t columnIndex = 0; + + SQLWCHAR sqlQuery[] = L"SELECT * from AllTypesTable LIMIT 1;"; + SQLINTEGER queryLength = static_cast(wcslen(sqlQuery)); + + SQLWCHAR* columnNames[] = {(SQLWCHAR*)L"bigint_col", (SQLWCHAR*)L"char_col", + (SQLWCHAR*)L"varbinary_col", (SQLWCHAR*)L"double_col"}; + SQLSMALLINT columnDataTypes[] = {SQL_BIGINT, SQL_WVARCHAR, SQL_BINARY, SQL_DOUBLE}; + SQLULEN columnSizes[] = {8, 0, 0, 8}; + + SQLRETURN ret = SQLExecDirect(this->stmt, sqlQuery, queryLength); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(columnNames) / sizeof(*columnNames); ++i) { + columnIndex = i + 1; + + ret = SQLDescribeCol(this->stmt, columnIndex, columnName, bufCharLen, &nameLength, + &columnDataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(nameLength, wcslen(columnNames[i])); + + std::wstring returned(columnName, columnName + nameLength); + EXPECT_EQ(returned, columnNames[i]); + EXPECT_EQ(columnDataType, columnDataTypes[i]); + EXPECT_EQ(columnSize, columnSizes[i]); + EXPECT_EQ(decimalDigits, 0); + EXPECT_EQ(nullable, SQL_NULLABLE); + + nameLength = 0; + columnDataType = 0; + columnSize = 0; + decimalDigits = 0; + nullable = 0; + } + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, SQLDescribeColUnicodeTableMetadata) { + this->connect(); + this->CreateUnicodeTable(); + + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT columnDataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + size_t columnIndex = 1; + + SQLWCHAR sqlQuery[] = L"SELECT * from 数据 LIMIT 1;"; + SQLINTEGER queryLength = static_cast(wcslen(sqlQuery)); + + SQLWCHAR expectedColumnName[] = L"资料"; + SQLSMALLINT expectedColumnDataType = SQL_WVARCHAR; + SQLULEN expectedColumnSize = 0; + + SQLRETURN ret = SQLExecDirect(this->stmt, sqlQuery, queryLength); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLDescribeCol(this->stmt, columnIndex, columnName, bufCharLen, &nameLength, + &columnDataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(nameLength, wcslen(expectedColumnName)); + + std::wstring returned(columnName, columnName + nameLength); + EXPECT_EQ(returned, expectedColumnName); + EXPECT_EQ(columnDataType, expectedColumnDataType); + EXPECT_EQ(columnSize, expectedColumnSize); + EXPECT_EQ(decimalDigits, 0); + EXPECT_EQ(nullable, SQL_NULLABLE); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, SQLColumnsGetMetadataBySQLDescribeCol) { + this->connect(); + + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT columnDataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + size_t columnIndex = 0; + + SQLWCHAR* columnNames[] = { + (SQLWCHAR*)L"TABLE_CAT", (SQLWCHAR*)L"TABLE_SCHEM", + (SQLWCHAR*)L"TABLE_NAME", (SQLWCHAR*)L"COLUMN_NAME", + (SQLWCHAR*)L"DATA_TYPE", (SQLWCHAR*)L"TYPE_NAME", + (SQLWCHAR*)L"COLUMN_SIZE", (SQLWCHAR*)L"BUFFER_LENGTH", + (SQLWCHAR*)L"DECIMAL_DIGITS", (SQLWCHAR*)L"NUM_PREC_RADIX", + (SQLWCHAR*)L"NULLABLE", (SQLWCHAR*)L"REMARKS", + (SQLWCHAR*)L"COLUMN_DEF", (SQLWCHAR*)L"SQL_DATA_TYPE", + (SQLWCHAR*)L"SQL_DATETIME_SUB", (SQLWCHAR*)L"CHAR_OCTET_LENGTH", + (SQLWCHAR*)L"ORDINAL_POSITION", (SQLWCHAR*)L"IS_NULLABLE"}; + SQLSMALLINT columnDataTypes[] = { + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_SMALLINT, SQL_WVARCHAR, + SQL_INTEGER, SQL_INTEGER, SQL_SMALLINT, SQL_SMALLINT, SQL_SMALLINT, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_SMALLINT, SQL_SMALLINT, SQL_INTEGER, SQL_INTEGER, SQL_WVARCHAR}; + SQLULEN columnSizes[] = {1024, 1024, 1024, 1024, 2, 1024, 4, 4, 2, + 2, 2, 1024, 1024, 2, 2, 4, 4, 1024}; + + SQLRETURN ret = SQLColumns(this->stmt, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(columnNames) / sizeof(*columnNames); ++i) { + columnIndex = i + 1; + + ret = SQLDescribeCol(this->stmt, columnIndex, columnName, bufCharLen, &nameLength, + &columnDataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(nameLength, wcslen(columnNames[i])); + + std::wstring returned(columnName, columnName + nameLength); + EXPECT_EQ(returned, columnNames[i]); + EXPECT_EQ(columnDataType, columnDataTypes[i]); + EXPECT_EQ(columnSize, columnSizes[i]); + EXPECT_EQ(decimalDigits, 0); + EXPECT_EQ(nullable, SQL_NULLABLE); + + nameLength = 0; + columnDataType = 0; + columnSize = 0; + decimalDigits = 0; + nullable = 0; + } + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, SQLColumnsGetMetadataBySQLDescribeColODBC2) { + this->connect(SQL_OV_ODBC2); + + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT columnDataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + size_t columnIndex = 0; + + SQLWCHAR* columnNames[] = {(SQLWCHAR*)L"TABLE_QUALIFIER", + (SQLWCHAR*)L"TABLE_OWNER", + (SQLWCHAR*)L"TABLE_NAME", + (SQLWCHAR*)L"COLUMN_NAME", + (SQLWCHAR*)L"DATA_TYPE", + (SQLWCHAR*)L"TYPE_NAME", + (SQLWCHAR*)L"PRECISION", + (SQLWCHAR*)L"LENGTH", + (SQLWCHAR*)L"SCALE", + (SQLWCHAR*)L"RADIX", + (SQLWCHAR*)L"NULLABLE", + (SQLWCHAR*)L"REMARKS", + (SQLWCHAR*)L"COLUMN_DEF", + (SQLWCHAR*)L"SQL_DATA_TYPE", + (SQLWCHAR*)L"SQL_DATETIME_SUB", + (SQLWCHAR*)L"CHAR_OCTET_LENGTH", + (SQLWCHAR*)L"ORDINAL_POSITION", + (SQLWCHAR*)L"IS_NULLABLE"}; + SQLSMALLINT columnDataTypes[] = { + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_SMALLINT, SQL_WVARCHAR, + SQL_INTEGER, SQL_INTEGER, SQL_SMALLINT, SQL_SMALLINT, SQL_SMALLINT, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_SMALLINT, SQL_SMALLINT, SQL_INTEGER, SQL_INTEGER, SQL_WVARCHAR}; + SQLULEN columnSizes[] = {1024, 1024, 1024, 1024, 2, 1024, 4, 4, 2, + 2, 2, 1024, 1024, 2, 2, 4, 4, 1024}; + + SQLRETURN ret = SQLColumns(this->stmt, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(columnNames) / sizeof(*columnNames); ++i) { + columnIndex = i + 1; + + ret = SQLDescribeCol(this->stmt, columnIndex, columnName, bufCharLen, &nameLength, + &columnDataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(nameLength, wcslen(columnNames[i])); + + std::wstring returned(columnName, columnName + nameLength); + EXPECT_EQ(returned, columnNames[i]); + EXPECT_EQ(columnDataType, columnDataTypes[i]); + EXPECT_EQ(columnSize, columnSizes[i]); + EXPECT_EQ(decimalDigits, 0); + EXPECT_EQ(nullable, SQL_NULLABLE); + + nameLength = 0; + columnDataType = 0; + columnSize = 0; + decimalDigits = 0; + nullable = 0; + } + + this->disconnect(); +} +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/connection_attr_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/connection_attr_test.cc new file mode 100644 index 00000000000..b725711f56e --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/connection_attr_test.cc @@ -0,0 +1,565 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#ifdef _WIN32 +# include +#endif + +#include +#include +#include + +#include "gtest/gtest.h" + +namespace arrow::flight::sql::odbc { + +#ifdef SQL_ATTR_ASYNC_DBC_EVENT +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetConnectAttrAsyncDbcEventUnsupported) { + this->connect(); + + SQLRETURN ret = SQLSetConnectAttr(this->conn, SQL_ATTR_ASYNC_DBC_EVENT, 0, 0); + + EXPECT_EQ(ret, SQL_ERROR); + // Driver Manager on Windows returns error code HY118 + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_HY118); + + this->disconnect(); +} +#endif + +#ifdef SQL_ATTR_ASYNC_ENABLE +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetConnectAttrAyncEnableUnsupported) { + this->connect(); + + SQLRETURN ret = SQLSetConnectAttr(this->conn, SQL_ATTR_ASYNC_ENABLE, 0, 0); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_HYC00); + + this->disconnect(); +} +#endif + +#ifdef SQL_ATTR_ASYNC_DBC_PCALLBACK +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetConnectAttrAyncDbcPcCallbackUnsupported) { + this->connect(); + + SQLRETURN ret = SQLSetConnectAttr(this->conn, SQL_ATTR_ASYNC_DBC_PCALLBACK, 0, 0); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_HYC00); + + this->disconnect(); +} +#endif + +#ifdef SQL_ATTR_ASYNC_DBC_PCONTEXT +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetConnectAttrAyncDbcPcContextUnsupported) { + this->connect(); + + SQLRETURN ret = SQLSetConnectAttr(this->conn, SQL_ATTR_ASYNC_DBC_PCONTEXT, 0, 0); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_HYC00); + + this->disconnect(); +} +#endif + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetConnectAttrAutoIpdReadOnly) { + this->connect(); + + // Verify read-only attribute cannot be set + SQLRETURN ret = SQLSetConnectAttr(this->conn, SQL_ATTR_AUTO_IPD, 0, 0); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_HY092); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetConnectAttrConnectionDeadReadOnly) { + this->connect(); + + // Verify read-only attribute cannot be set + SQLRETURN ret = SQLSetConnectAttr(this->conn, SQL_ATTR_CONNECTION_DEAD, 0, 0); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_HY092); + + this->disconnect(); +} + +#ifdef SQL_ATTR_DBC_INFO_TOKEN +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetConnectAttrDbcInfoTokenUnsupported) { + this->connect(); + + SQLRETURN ret = SQLSetConnectAttr(this->conn, SQL_ATTR_DBC_INFO_TOKEN, 0, 0); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_HYC00); + + this->disconnect(); +} +#endif + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetConnectAttrEnlistInDtcUnsupported) { + this->connect(); + + SQLRETURN ret = SQLSetConnectAttr(this->conn, SQL_ATTR_ENLIST_IN_DTC, 0, 0); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_HYC00); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetConnectAttrOdbcCursorsDMOnly) { + this->allocEnvConnHandles(); + + // Verify DM-only attribute is settable via Driver Manager + SQLRETURN ret = SQLSetConnectAttr(this->conn, SQL_ATTR_ODBC_CURSORS, + reinterpret_cast(SQL_CUR_USE_DRIVER), 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + std::string connect_str = this->getConnectionString(); + this->connectWithString(connect_str); + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetConnectAttrQuietModeReadOnly) { + this->connect(); + + // Verify read-only attribute cannot be set + SQLRETURN ret = SQLSetConnectAttr(this->conn, SQL_ATTR_QUIET_MODE, 0, 0); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_HY092); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetConnectAttrTraceDMOnly) { + this->connect(); + + // Verify DM-only attribute is settable via Driver Manager + SQLRETURN ret = SQLSetConnectAttr(this->conn, SQL_ATTR_TRACE, + reinterpret_cast(SQL_OPT_TRACE_OFF), 0); + EXPECT_EQ(ret, SQL_SUCCESS); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetConnectAttrTracefileDMOnly) { + this->connect(); + + // Verify DM-only attribute is handled by Driver Manager + + // Use placeholder value as we want the call to fail, or else + // the driver manager will produce a trace file. + std::wstring trace_file = L"invalid/file/path"; + std::vector trace_file0(trace_file.begin(), trace_file.end()); + SQLRETURN ret = SQLSetConnectAttr(this->conn, SQL_ATTR_TRACEFILE, &trace_file0[0], + static_cast(trace_file0.size())); + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_HY000); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetConnectAttrTranslateLabDMOnly) { + this->connect(); + + // Verify DM-only attribute is handled by Driver Manager + SQLRETURN ret = SQLSetConnectAttr(this->conn, SQL_ATTR_TRANSLATE_LIB, 0, 0); + EXPECT_EQ(ret, SQL_ERROR); + // Checks for invalid argument return error + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_HY024); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetConnectAttrTranslateOptionUnsupported) { + this->connect(); + + SQLRETURN ret = SQLSetConnectAttr(this->conn, SQL_ATTR_TRANSLATE_OPTION, 0, 0); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_HYC00); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetConnectAttrTxnIsolationUnsupported) { + this->connect(); + + SQLRETURN ret = + SQLSetConnectAttr(this->conn, SQL_ATTR_TXN_ISOLATION, + reinterpret_cast(SQL_TXN_READ_UNCOMMITTED), 0); + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_HYC00); + + this->disconnect(); +} + +#ifdef SQL_ATTR_DBC_INFO_TOKEN +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetConnectAttrDbcInfoTokenSetOnly) { + this->connect(); + + // Verify that set-only attribute cannot be read + SQLPOINTER ptr = NULL; + SQLRETURN ret = SQLGetConnectAttr(this->conn, SQL_ATTR_DBC_INFO_TOKEN, ptr, 0, 0); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_HY092); + + this->disconnect(); +} +#endif + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetConnectAttrOdbcCursorsDMOnly) { + this->connect(); + + // Verify that DM-only attribute is handled by driver manager + SQLULEN cursor_attr; + SQLRETURN ret = + SQLGetConnectAttr(this->conn, SQL_ATTR_ODBC_CURSORS, &cursor_attr, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(cursor_attr, SQL_CUR_USE_DRIVER); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetConnectAttrTraceDMOnly) { + this->connect(); + + // Verify that DM-only attribute is handled by driver manager + SQLUINTEGER trace; + SQLRETURN ret = SQLGetConnectAttr(this->conn, SQL_ATTR_TRACE, &trace, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(trace, SQL_OPT_TRACE_OFF); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetConnectAttrTraceFileDMOnly) { + this->connect(); + + // Verify that DM-only attribute is handled by driver manager + SQLWCHAR outstr[ODBC_BUFFER_SIZE]; + SQLINTEGER outstrlen; + SQLRETURN ret = SQLGetConnectAttr(this->conn, SQL_ATTR_TRACEFILE, outstr, + ODBC_BUFFER_SIZE, &outstrlen); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Length is returned in bytes for SQLGetConnectAttr, + // we want the number of characters + outstrlen /= driver::odbcabstraction::GetSqlWCharSize(); + std::string out_connection_string = + ODBC::SqlWcharToString(outstr, static_cast(outstrlen)); + EXPECT_TRUE(!out_connection_string.empty()); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetConnectAttrTranslateLibUnsupported) { + this->connect(); + + SQLWCHAR outstr[ODBC_BUFFER_SIZE]; + SQLINTEGER outstrlen; + SQLRETURN ret = SQLGetConnectAttr(this->conn, SQL_ATTR_TRANSLATE_LIB, outstr, + ODBC_BUFFER_SIZE, &outstrlen); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_HYC00); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetConnectAttrTranslateOptionUnsupported) { + this->connect(); + + SQLINTEGER option; + SQLRETURN ret = SQLGetConnectAttr(this->conn, SQL_ATTR_TRANSLATE_OPTION, &option, 0, 0); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_HYC00); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetConnectAttrTxnIsolationUnsupported) { + this->connect(); + + SQLINTEGER isolation; + SQLRETURN ret = SQLGetConnectAttr(this->conn, SQL_ATTR_TXN_ISOLATION, &isolation, 0, 0); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_HYC00); + + this->disconnect(); +} + +#ifdef SQL_ATTR_ASYNC_DBC_FUNCTIONS_ENABLE +TYPED_TEST(FlightSQLODBCTestBase, + TestSQLGetConnectAttrAsyncDbcFunctionsEnableUnsupported) { + this->connect(); + + // Verifies that the Windows driver manager returns HY114 for unsupported functionality + SQLUINTEGER enable; + SQLRETURN ret = + SQLGetConnectAttr(this->conn, SQL_ATTR_ASYNC_DBC_FUNCTIONS_ENABLE, &enable, 0, 0); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_HY114); + + this->disconnect(); +} +#endif + +// Tests for supported attributes + +#ifdef SQL_ATTR_ASYNC_DBC_EVENT +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetConnectAttrAsyncDbcEventDefault) { + this->connect(); + + SQLPOINTER ptr = NULL; + SQLRETURN ret = SQLGetConnectAttr(this->conn, SQL_ATTR_ASYNC_DBC_EVENT, ptr, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(ptr, reinterpret_cast(NULL)); + + this->disconnect(); +} +#endif + +#ifdef SQL_ATTR_ASYNC_DBC_PCALLBACK +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetConnectAttrAsyncDbcPcallbackDefault) { + this->connect(); + + SQLPOINTER ptr = NULL; + SQLRETURN ret = SQLGetConnectAttr(this->conn, SQL_ATTR_ASYNC_DBC_PCALLBACK, ptr, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(ptr, reinterpret_cast(NULL)); + + this->disconnect(); +} +#endif + +#ifdef SQL_ATTR_ASYNC_DBC_PCONTEXT +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetConnectAttrAsyncDbcPcontextDefault) { + this->connect(); + + SQLPOINTER ptr = NULL; + SQLRETURN ret = SQLGetConnectAttr(this->conn, SQL_ATTR_ASYNC_DBC_PCONTEXT, ptr, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(ptr, reinterpret_cast(NULL)); + + this->disconnect(); +} +#endif + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetConnectAttrAsyncEnableDefault) { + this->connect(); + + SQLULEN enable; + SQLRETURN ret = SQLGetConnectAttr(this->conn, SQL_ATTR_ASYNC_ENABLE, &enable, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(enable, SQL_ASYNC_ENABLE_OFF); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetConnectAttrAutoIpdDefault) { + this->connect(); + + SQLUINTEGER ipd; + SQLRETURN ret = SQLGetConnectAttr(this->conn, SQL_ATTR_AUTO_IPD, &ipd, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(ipd, static_cast(SQL_FALSE)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetConnectAttrAutocommitDefault) { + this->connect(); + + SQLUINTEGER auto_commit; + SQLRETURN ret = SQLGetConnectAttr(this->conn, SQL_ATTR_AUTOCOMMIT, &auto_commit, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(auto_commit, SQL_AUTOCOMMIT_ON); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetConnectAttrEnlistInDtcDefault) { + this->connect(); + + SQLPOINTER ptr = NULL; + SQLRETURN ret = SQLGetConnectAttr(this->conn, SQL_ATTR_ENLIST_IN_DTC, ptr, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(ptr, reinterpret_cast(NULL)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetConnectAttrQuietModeDefault) { + this->connect(); + + HWND ptr = NULL; + SQLRETURN ret = SQLGetConnectAttr(this->conn, SQL_ATTR_QUIET_MODE, ptr, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(ptr, reinterpret_cast(NULL)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetConnectAttrAccessModeValid) { + this->connect(); + + // The driver always returns SQL_MODE_READ_WRITE + + // Check default value first + SQLUINTEGER mode = -1; + SQLRETURN ret = SQLGetConnectAttr(this->conn, SQL_ATTR_ACCESS_MODE, &mode, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(mode, SQL_MODE_READ_WRITE); + + ret = SQLSetConnectAttr(this->conn, SQL_ATTR_ACCESS_MODE, + reinterpret_cast(SQL_MODE_READ_WRITE), 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + mode = -1; + + ret = SQLGetConnectAttr(this->conn, SQL_ATTR_ACCESS_MODE, &mode, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(mode, SQL_MODE_READ_WRITE); + + // Attempt to set to SQL_MODE_READ_ONLY, driver should return warning and not error + ret = SQLSetConnectAttr(this->conn, SQL_ATTR_ACCESS_MODE, + reinterpret_cast(SQL_MODE_READ_ONLY), 0); + + EXPECT_EQ(ret, SQL_SUCCESS_WITH_INFO); + + // Verify warning status + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_01S02); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetConnectAttrConnectionTimeoutValid) { + this->connect(); + + // Check default value first + SQLUINTEGER timeout = -1; + SQLRETURN ret = + SQLGetConnectAttr(this->conn, SQL_ATTR_CONNECTION_TIMEOUT, &timeout, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(timeout, 0); + + ret = SQLSetConnectAttr(this->conn, SQL_ATTR_CONNECTION_TIMEOUT, + reinterpret_cast(42), 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + timeout = -1; + + ret = SQLGetConnectAttr(this->conn, SQL_ATTR_CONNECTION_TIMEOUT, &timeout, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(timeout, 42); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetConnectAttrLoginTimeoutValid) { + this->connect(); + + // Check default value first + SQLUINTEGER timeout = -1; + SQLRETURN ret = SQLGetConnectAttr(this->conn, SQL_ATTR_LOGIN_TIMEOUT, &timeout, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(timeout, 0); + + ret = SQLSetConnectAttr(this->conn, SQL_ATTR_LOGIN_TIMEOUT, + reinterpret_cast(42), 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + timeout = -1; + + ret = SQLGetConnectAttr(this->conn, SQL_ATTR_LOGIN_TIMEOUT, &timeout, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(timeout, 42); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetConnectAttrPacketSizeValid) { + this->connect(); + + // The driver always returns 0. PACKET_SIZE value is unused by the driver. + + // Check default value first + SQLUINTEGER size = -1; + SQLRETURN ret = SQLGetConnectAttr(this->conn, SQL_ATTR_PACKET_SIZE, &size, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(size, 0); + + ret = SQLSetConnectAttr(this->conn, SQL_ATTR_PACKET_SIZE, + reinterpret_cast(0), 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + size = -1; + + ret = SQLGetConnectAttr(this->conn, SQL_ATTR_PACKET_SIZE, &size, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(size, 0); + + // Attempt to set to non-zero value, driver should return warning and not error + ret = SQLSetConnectAttr(this->conn, SQL_ATTR_PACKET_SIZE, + reinterpret_cast(2), 0); + + EXPECT_EQ(ret, SQL_SUCCESS_WITH_INFO); + + // Verify warning status + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_01S02); + + this->disconnect(); +} + +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/connection_info_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/connection_info_test.cc new file mode 100644 index 00000000000..39bf7e1440b --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/connection_info_test.cc @@ -0,0 +1,1487 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#ifdef _WIN32 +# include +#endif + +#include +#include +#include + +#include "gtest/gtest.h" + +namespace arrow::flight::sql::odbc { + +// Helper Functions + +// Validate unsigned short SQLUSMALLINT return value +void validate(SQLHDBC connection, SQLUSMALLINT infoType, SQLUSMALLINT expected_value) { + SQLUSMALLINT info_value; + SQLSMALLINT message_length; + + SQLRETURN ret = SQLGetInfo(connection, infoType, &info_value, 0, &message_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(info_value, expected_value); +} + +// Validate unsigned long SQLUINTEGER return value +void validate(SQLHDBC connection, SQLUSMALLINT infoType, SQLUINTEGER expected_value) { + SQLUINTEGER info_value; + SQLSMALLINT message_length; + + SQLRETURN ret = SQLGetInfo(connection, infoType, &info_value, 0, &message_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(info_value, expected_value); +} + +// Validate unsigned length SQLULEN return value +void validate(SQLHDBC connection, SQLUSMALLINT infoType, SQLULEN expected_value) { + SQLULEN info_value; + SQLSMALLINT message_length; + + SQLRETURN ret = SQLGetInfo(connection, infoType, &info_value, 0, &message_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(info_value, expected_value); +} + +// Validate wchar string SQLWCHAR return value +void validate(SQLHDBC connection, SQLUSMALLINT infoType, SQLWCHAR* expected_value) { + SQLWCHAR info_value[ODBC_BUFFER_SIZE] = L""; + SQLSMALLINT message_length; + + SQLRETURN ret = + SQLGetInfo(connection, infoType, info_value, ODBC_BUFFER_SIZE, &message_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(*info_value, *expected_value); +} + +// Validate unsigned long SQLUINTEGER return value is greater than +void validateGreaterThan(SQLHDBC connection, SQLUSMALLINT infoType, + SQLUINTEGER compared_value) { + SQLUINTEGER info_value; + SQLSMALLINT message_length; + + SQLRETURN ret = SQLGetInfo(connection, infoType, &info_value, 0, &message_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(info_value, compared_value); +} + +// Validate unsigned length SQLULEN return value is greater than +void validateGreaterThan(SQLHDBC connection, SQLUSMALLINT infoType, + SQLULEN compared_value) { + SQLULEN info_value; + SQLSMALLINT message_length; + + SQLRETURN ret = SQLGetInfo(connection, infoType, &info_value, 0, &message_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(info_value, compared_value); +} + +// Validate wchar string SQLWCHAR return value is not empty +void validateNotEmptySQLWCHAR(SQLHDBC connection, SQLUSMALLINT infoType, + bool allowTruncation) { + SQLWCHAR info_value[ODBC_BUFFER_SIZE] = L""; + SQLSMALLINT message_length; + + SQLRETURN ret = + SQLGetInfo(connection, infoType, info_value, ODBC_BUFFER_SIZE, &message_length); + + if (allowTruncation && ret == SQL_SUCCESS_WITH_INFO) { + EXPECT_EQ(ret, SQL_SUCCESS_WITH_INFO); + } else { + EXPECT_EQ(ret, SQL_SUCCESS); + } + + EXPECT_GT(wcslen(info_value), 0); +} + +// Driver Information + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoActiveEnvironments) { + this->connect(); + + validate(this->conn, SQL_ACTIVE_ENVIRONMENTS, static_cast(0)); + + this->disconnect(); +} + +#ifdef SQL_ASYNC_DBC_FUNCTIONS +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoAsyncDbcFunctions) { + this->connect(); + + validate(this->conn, SQL_ASYNC_DBC_FUNCTIONS, + static_cast(SQL_ASYNC_DBC_NOT_CAPABLE)); + + this->disconnect(); +} +#endif + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoAsyncMode) { + this->connect(); + + validate(this->conn, SQL_ASYNC_MODE, static_cast(SQL_AM_NONE)); + + this->disconnect(); +} + +#ifdef SQL_ASYNC_NOTIFICATION +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoAsyncNotification) { + this->connect(); + + validate(this->conn, SQL_ASYNC_NOTIFICATION, + static_cast(SQL_ASYNC_NOTIFICATION_NOT_CAPABLE)); + + this->disconnect(); +} +#endif + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoBatchRowCount) { + this->connect(); + + validate(this->conn, SQL_BATCH_ROW_COUNT, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoBatchSupport) { + this->connect(); + + validate(this->conn, SQL_BATCH_SUPPORT, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDataSourceName) { + this->connect(); + + validate(this->conn, SQL_DATA_SOURCE_NAME, (SQLWCHAR*)L""); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDriverAwarePoolingSupported) { + // A driver does not need to implement SQL_DRIVER_AWARE_POOLING_SUPPORTED and the + // Driver Manager will not honor to the driver's return value. + this->connect(); + + validate(this->conn, SQL_DRIVER_AWARE_POOLING_SUPPORTED, + static_cast(SQL_DRIVER_AWARE_POOLING_NOT_CAPABLE)); + + this->disconnect(); +} + +// These information types are implemented by the Driver Manager alone. +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDriverHdbc) { + this->connect(); + + // Value returned from driver manager is the connection address + validateGreaterThan(this->conn, SQL_DRIVER_HDBC, static_cast(0)); + + this->disconnect(); +} + +// These information types are implemented by the Driver Manager alone. +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDriverHdesc) { + // TODO This is failing due to no descriptor being created + // enable after SQL_HANDLE_DESC is supported + GTEST_SKIP(); + this->connect(); + + validate(this->conn, SQL_DRIVER_HDESC, static_cast(0)); + + this->disconnect(); +} + +// These information types are implemented by the Driver Manager alone. +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDriverHenv) { + this->connect(); + + // Value returned from driver manager is the env address + validateGreaterThan(this->conn, SQL_DRIVER_HENV, static_cast(0)); + + this->disconnect(); +} + +// These information types are implemented by the Driver Manager alone. +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDriverHlib) { + this->connect(); + + validateGreaterThan(this->conn, SQL_DRIVER_HLIB, static_cast(0)); + + this->disconnect(); +} + +// These information types are implemented by the Driver Manager alone. +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDriverHstmt) { + this->connect(); + + // Value returned from driver manager is the stmt address + SQLHSTMT local_stmt = this->stmt; + SQLRETURN ret = SQLGetInfo(this->conn, SQL_DRIVER_HSTMT, &local_stmt, 0, 0); + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_GT(local_stmt, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDriverName) { + this->connect(); + + validate(this->conn, SQL_DRIVER_NAME, (SQLWCHAR*)L"Arrow Flight ODBC Driver"); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDriverOdbcVer) { + this->connect(); + + validate(this->conn, SQL_DRIVER_ODBC_VER, (SQLWCHAR*)L"03.80"); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDriverVer) { + this->connect(); + + validate(this->conn, SQL_DRIVER_VER, (SQLWCHAR*)L"00.09.0000.0"); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDynamicCursorAttributes1) { + this->connect(); + + validate(this->conn, SQL_DYNAMIC_CURSOR_ATTRIBUTES1, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDynamicCursorAttributes2) { + this->connect(); + + validate(this->conn, SQL_DYNAMIC_CURSOR_ATTRIBUTES2, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoForwardOnlyCursorAttributes1) { + this->connect(); + + validate(this->conn, SQL_FORWARD_ONLY_CURSOR_ATTRIBUTES1, + static_cast(SQL_CA1_NEXT)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoForwardOnlyCursorAttributes2) { + this->connect(); + + validate(this->conn, SQL_FORWARD_ONLY_CURSOR_ATTRIBUTES2, + static_cast(SQL_CA2_READ_ONLY_CONCURRENCY)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoFileUsage) { + this->connect(); + + validate(this->conn, SQL_FILE_USAGE, static_cast(SQL_FILE_NOT_SUPPORTED)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoGetDataExtensions) { + this->connect(); + + validate(this->conn, SQL_GETDATA_EXTENSIONS, + static_cast(SQL_GD_ANY_COLUMN | SQL_GD_ANY_ORDER)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoSchemaViews) { + this->connect(); + + validate(this->conn, SQL_INFO_SCHEMA_VIEWS, + static_cast(SQL_ISV_TABLES | SQL_ISV_COLUMNS | SQL_ISV_VIEWS)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoKeysetCursorAttributes1) { + this->connect(); + + validate(this->conn, SQL_KEYSET_CURSOR_ATTRIBUTES1, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoKeysetCursorAttributes2) { + this->connect(); + + validate(this->conn, SQL_KEYSET_CURSOR_ATTRIBUTES2, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoMaxAsyncConcurrentStatements) { + this->connect(); + + validate(this->conn, SQL_MAX_ASYNC_CONCURRENT_STATEMENTS, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoMaxConcurrentActivities) { + this->connect(); + + validate(this->conn, SQL_MAX_CONCURRENT_ACTIVITIES, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoMaxDriverConnections) { + this->connect(); + + validate(this->conn, SQL_MAX_DRIVER_CONNECTIONS, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoOdbcInterfaceConformance) { + this->connect(); + + validate(this->conn, SQL_ODBC_INTERFACE_CONFORMANCE, + static_cast(SQL_OIC_CORE)); + + this->disconnect(); +} + +// case SQL_ODBC_STANDARD_CLI_CONFORMANCE: - mentioned in SQLGetInfo spec with no +// description and there is no constant for this. +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoOdbcStandardCliConformance) { + // Type commented out in odbc_connection.cc + GTEST_SKIP(); + this->connect(); + + // Type does not exist in sql.h + // validate(this->conn, SQL_ODBC_STANDARD_CLI_CONFORMANCE, + // static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoOdbcVer) { + // This is implemented only in the Driver Manager. + this->connect(); + + validate(this->conn, SQL_ODBC_VER, (SQLWCHAR*)L"03.80.0000"); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoParamArrayRowCounts) { + this->connect(); + + validate(this->conn, SQL_PARAM_ARRAY_ROW_COUNTS, + static_cast(SQL_PARC_NO_BATCH)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoParamArraySelects) { + this->connect(); + + validate(this->conn, SQL_PARAM_ARRAY_SELECTS, + static_cast(SQL_PAS_NO_SELECT)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoRowUpdates) { + this->connect(); + + validate(this->conn, SQL_ROW_UPDATES, (SQLWCHAR*)L"N"); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoSearchPatternEscape) { + this->connect(); + + validate(this->conn, SQL_SEARCH_PATTERN_ESCAPE, (SQLWCHAR*)L"\\"); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoServerName) { + this->connect(); + + validateNotEmptySQLWCHAR(this->conn, SQL_SERVER_NAME, false); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoStaticCursorAttributes1) { + this->connect(); + + validate(this->conn, SQL_STATIC_CURSOR_ATTRIBUTES1, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoStaticCursorAttributes2) { + this->connect(); + + validate(this->conn, SQL_STATIC_CURSOR_ATTRIBUTES2, static_cast(0)); + + this->disconnect(); +} + +// DBMS Product Information + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDatabaseName) { + this->connect(); + + validate(this->conn, SQL_DATABASE_NAME, (SQLWCHAR*)L""); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDbmsName) { + this->connect(); + + validateNotEmptySQLWCHAR(this->conn, SQL_DBMS_NAME, false); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDbmsVer) { + this->connect(); + + validateNotEmptySQLWCHAR(this->conn, SQL_DBMS_VER, false); + + this->disconnect(); +} + +// Data Source Information + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoAccessibleProcedures) { + this->connect(); + + validate(this->conn, SQL_ACCESSIBLE_PROCEDURES, (SQLWCHAR*)L"N"); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoAccessibleTables) { + this->connect(); + + validate(this->conn, SQL_ACCESSIBLE_TABLES, (SQLWCHAR*)L"Y"); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoBookmarkPersistence) { + this->connect(); + + validate(this->conn, SQL_BOOKMARK_PERSISTENCE, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoCatalogTerm) { + this->connect(); + + validate(this->conn, SQL_CATALOG_TERM, (SQLWCHAR*)L""); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoCollationSeq) { + this->connect(); + + validate(this->conn, SQL_COLLATION_SEQ, (SQLWCHAR*)L""); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoConcatNullBehavior) { + this->connect(); + + validate(this->conn, SQL_CONCAT_NULL_BEHAVIOR, static_cast(SQL_CB_NULL)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoCursorCommitBehavior) { + this->connect(); + + validate(this->conn, SQL_CURSOR_COMMIT_BEHAVIOR, + static_cast(SQL_CB_CLOSE)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoCursorRollbackBehavior) { + this->connect(); + + validate(this->conn, SQL_CURSOR_ROLLBACK_BEHAVIOR, + static_cast(SQL_CB_CLOSE)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoCursorSensitivity) { + this->connect(); + + validate(this->conn, SQL_CURSOR_SENSITIVITY, static_cast(SQL_UNSPECIFIED)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDataSourceReadOnly) { + this->connect(); + + validate(this->conn, SQL_DATA_SOURCE_READ_ONLY, (SQLWCHAR*)L"N"); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDefaultTxnIsolation) { + this->connect(); + + validate(this->conn, SQL_DEFAULT_TXN_ISOLATION, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDescribeParameter) { + this->connect(); + + validate(this->conn, SQL_DESCRIBE_PARAMETER, (SQLWCHAR*)L"N"); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoMultResultSets) { + this->connect(); + + validate(this->conn, SQL_MULT_RESULT_SETS, (SQLWCHAR*)L"N"); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoMultipleActiveTxn) { + this->connect(); + + validate(this->conn, SQL_MULTIPLE_ACTIVE_TXN, (SQLWCHAR*)L"N"); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoNeedLongDataLen) { + this->connect(); + + validate(this->conn, SQL_NEED_LONG_DATA_LEN, (SQLWCHAR*)L"N"); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoNullCollation) { + this->connect(); + + validate(this->conn, SQL_NULL_COLLATION, static_cast(SQL_NC_START)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoProcedureTerm) { + this->connect(); + + validate(this->conn, SQL_PROCEDURE_TERM, (SQLWCHAR*)L""); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoSchemaTerm) { + this->connect(); + + validate(this->conn, SQL_SCHEMA_TERM, (SQLWCHAR*)L"schema"); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoScrollOptions) { + this->connect(); + + validate(this->conn, SQL_SCROLL_OPTIONS, static_cast(SQL_SO_FORWARD_ONLY)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoTableTerm) { + this->connect(); + + validate(this->conn, SQL_TABLE_TERM, (SQLWCHAR*)L"table"); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoTxnCapable) { + this->connect(); + + validate(this->conn, SQL_TXN_CAPABLE, static_cast(SQL_TC_NONE)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoTxnIsolationOption) { + this->connect(); + + validate(this->conn, SQL_TXN_ISOLATION_OPTION, static_cast(0)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoUserName) { + this->connect(); + + validate(this->conn, SQL_USER_NAME, (SQLWCHAR*)L""); + + this->disconnect(); +} + +// Supported SQL + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoAggregateFunctions) { + this->connect(); + + validate( + this->conn, SQL_AGGREGATE_FUNCTIONS, + static_cast(SQL_AF_ALL | SQL_AF_AVG | SQL_AF_COUNT | SQL_AF_DISTINCT | + SQL_AF_MAX | SQL_AF_MIN | SQL_AF_SUM)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoAlterDomain) { + this->connect(); + + validate(this->conn, SQL_ALTER_DOMAIN, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoAlterSchema) { + // Type commented out in odbc_connection.cc + GTEST_SKIP(); + this->connect(); + + // Type does not exist in sql.h + // validate(this->conn, SQL_ALTER_SCHEMA, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoAlterTable) { + this->connect(); + + validate(this->conn, SQL_ALTER_TABLE, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoAnsiSqlDatetimeLiterals) { + // Type commented out in odbc_connection.cc + GTEST_SKIP(); + this->connect(); + + // Type does not exist in sql.h + // validate(this->conn, SQL_ANSI_SQL_DATETIME_LITERALS, (SQLWCHAR*)L""); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoCatalogLocation) { + this->connect(); + + validate(this->conn, SQL_CATALOG_LOCATION, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoCatalogName) { + this->connect(); + + validate(this->conn, SQL_CATALOG_NAME, (SQLWCHAR*)L"N"); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoCatalogNameSeparator) { + this->connect(); + + validate(this->conn, SQL_CATALOG_NAME_SEPARATOR, (SQLWCHAR*)L""); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoCatalogUsage) { + this->connect(); + + validate(this->conn, SQL_CATALOG_USAGE, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoColumnAlias) { + this->connect(); + + validate(this->conn, SQL_COLUMN_ALIAS, (SQLWCHAR*)L"Y"); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoCorrelationName) { + this->connect(); + + validate(this->conn, SQL_CORRELATION_NAME, static_cast(SQL_CN_NONE)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoCreateAssertion) { + this->connect(); + + validate(this->conn, SQL_CREATE_ASSERTION, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoCreateCharacterSet) { + this->connect(); + + validate(this->conn, SQL_CREATE_CHARACTER_SET, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoCreateCollation) { + this->connect(); + + validate(this->conn, SQL_CREATE_COLLATION, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoCreateDomain) { + this->connect(); + + validate(this->conn, SQL_CREATE_DOMAIN, static_cast(0)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoCreateSchema) { + this->connect(); + + validate(this->conn, SQL_CREATE_SCHEMA, static_cast(1)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoCreateTable) { + this->connect(); + + validate(this->conn, SQL_CREATE_TABLE, static_cast(1)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoCreateTranslation) { + this->connect(); + + validate(this->conn, SQL_CREATE_TRANSLATION, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDdlIndex) { + this->connect(); + + validate(this->conn, SQL_DDL_INDEX, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDropAssertion) { + this->connect(); + + validate(this->conn, SQL_DROP_ASSERTION, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDropCharacterSet) { + this->connect(); + + validate(this->conn, SQL_DROP_CHARACTER_SET, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDropCollation) { + this->connect(); + + validate(this->conn, SQL_DROP_COLLATION, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDropDomain) { + this->connect(); + + validate(this->conn, SQL_DROP_DOMAIN, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDropSchema) { + this->connect(); + + validate(this->conn, SQL_DROP_SCHEMA, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDropTable) { + this->connect(); + + validate(this->conn, SQL_DROP_TABLE, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDropTranslation) { + this->connect(); + + validate(this->conn, SQL_DROP_TRANSLATION, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoDropView) { + this->connect(); + + validate(this->conn, SQL_DROP_VIEW, static_cast(0)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoExpressionsInOrderby) { + this->connect(); + + validate(this->conn, SQL_EXPRESSIONS_IN_ORDERBY, (SQLWCHAR*)L"N"); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoGroupBy) { + this->connect(); + + validate(this->conn, SQL_GROUP_BY, + static_cast(SQL_GB_GROUP_BY_CONTAINS_SELECT)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoIdentifierCase) { + this->connect(); + + validate(this->conn, SQL_IDENTIFIER_CASE, static_cast(SQL_IC_MIXED)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoIdentifierQuoteChar) { + this->connect(); + + validate(this->conn, SQL_IDENTIFIER_QUOTE_CHAR, (SQLWCHAR*)L"\""); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoIndexKeywords) { + this->connect(); + + validate(this->conn, SQL_INDEX_KEYWORDS, static_cast(SQL_IK_NONE)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoInsertStatement) { + this->connect(); + + validate(this->conn, SQL_INSERT_STATEMENT, + static_cast(SQL_IS_INSERT_LITERALS | SQL_IS_INSERT_SEARCHED | + SQL_IS_SELECT_INTO)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoIntegrity) { + this->connect(); + + validate(this->conn, SQL_INTEGRITY, (SQLWCHAR*)L"N"); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoKeywords) { + this->connect(); + + validateNotEmptySQLWCHAR(this->conn, SQL_KEYWORDS, true); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoLikeEscapeClause) { + this->connect(); + + validate(this->conn, SQL_LIKE_ESCAPE_CLAUSE, (SQLWCHAR*)L"Y"); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoNonNullableColumns) { + this->connect(); + + validate(this->conn, SQL_NON_NULLABLE_COLUMNS, static_cast(SQL_NNC_NULL)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoOjCapabilities) { + this->connect(); + + validate(this->conn, SQL_OJ_CAPABILITIES, + static_cast(SQL_OJ_LEFT | SQL_OJ_RIGHT | SQL_OJ_FULL)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoOrderByColumnsInSelect) { + this->connect(); + + validate(this->conn, SQL_ORDER_BY_COLUMNS_IN_SELECT, (SQLWCHAR*)L"Y"); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoOuterJoins) { + this->connect(); + + validate(this->conn, SQL_OUTER_JOINS, (SQLWCHAR*)L"N"); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoProcedures) { + this->connect(); + + validate(this->conn, SQL_PROCEDURES, (SQLWCHAR*)L"N"); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoQuotedIdentifierCase) { + this->connect(); + + validate(this->conn, SQL_QUOTED_IDENTIFIER_CASE, + static_cast(SQL_IC_MIXED)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoSchemaUsage) { + this->connect(); + + validate(this->conn, SQL_SCHEMA_USAGE, static_cast(SQL_SU_DML_STATEMENTS)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoSpecialCharacters) { + this->connect(); + + validate(this->conn, SQL_SPECIAL_CHARACTERS, (SQLWCHAR*)L""); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoSqlConformance) { + this->connect(); + + validate(this->conn, SQL_SQL_CONFORMANCE, static_cast(SQL_SC_SQL92_ENTRY)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoSubqueries) { + this->connect(); + + validate(this->conn, SQL_SUBQUERIES, + static_cast(SQL_SQ_CORRELATED_SUBQUERIES | SQL_SQ_COMPARISON | + SQL_SQ_EXISTS | SQL_SQ_IN | SQL_SQ_QUANTIFIED)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoUnion) { + this->connect(); + + validate(this->conn, SQL_UNION, + static_cast(SQL_U_UNION | SQL_U_UNION_ALL)); + + this->disconnect(); +} + +// SQL Limits + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoMaxBinaryLiteralLen) { + this->connect(); + + validate(this->conn, SQL_MAX_BINARY_LITERAL_LEN, static_cast(0)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoMaxCatalogNameLen) { + this->connect(); + + validate(this->conn, SQL_MAX_CATALOG_NAME_LEN, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoMaxCharLiteralLen) { + this->connect(); + + validate(this->conn, SQL_MAX_CHAR_LITERAL_LEN, static_cast(0)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoMaxColumnNameLen) { + this->connect(); + + validate(this->conn, SQL_MAX_COLUMN_NAME_LEN, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoMaxColumnsInGroupBy) { + this->connect(); + + validate(this->conn, SQL_MAX_COLUMNS_IN_GROUP_BY, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoMaxColumnsInIndex) { + this->connect(); + + validate(this->conn, SQL_MAX_COLUMNS_IN_INDEX, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoMaxColumnsInOrderBy) { + this->connect(); + + validate(this->conn, SQL_MAX_COLUMNS_IN_ORDER_BY, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoMaxColumnsInSelect) { + this->connect(); + + validate(this->conn, SQL_MAX_COLUMNS_IN_SELECT, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoMaxColumnsInTable) { + this->connect(); + + validate(this->conn, SQL_MAX_COLUMNS_IN_TABLE, static_cast(0)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoMaxCursorNameLen) { + this->connect(); + + validate(this->conn, SQL_MAX_CURSOR_NAME_LEN, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoMaxIdentifierLen) { + this->connect(); + + validate(this->conn, SQL_MAX_IDENTIFIER_LEN, static_cast(65535)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoMaxIndexSize) { + this->connect(); + + validate(this->conn, SQL_MAX_INDEX_SIZE, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoMaxProcedureNameLen) { + this->connect(); + + validate(this->conn, SQL_MAX_PROCEDURE_NAME_LEN, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoMaxRowSize) { + this->connect(); + + validate(this->conn, SQL_MAX_ROW_SIZE, (SQLWCHAR*)L""); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoMaxRowSizeIncludesLong) { + this->connect(); + + validate(this->conn, SQL_MAX_ROW_SIZE_INCLUDES_LONG, (SQLWCHAR*)L"N"); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoMaxSchemaNameLen) { + this->connect(); + + validate(this->conn, SQL_MAX_SCHEMA_NAME_LEN, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoMaxStatementLen) { + this->connect(); + + validate(this->conn, SQL_MAX_STATEMENT_LEN, static_cast(0)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoMaxTableNameLen) { + this->connect(); + + validate(this->conn, SQL_MAX_TABLE_NAME_LEN, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoMaxTablesInSelect) { + this->connect(); + + validate(this->conn, SQL_MAX_TABLES_IN_SELECT, static_cast(0)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoMaxUserNameLen) { + this->connect(); + + validate(this->conn, SQL_MAX_USER_NAME_LEN, static_cast(0)); + + this->disconnect(); +} + +// Scalar Function Information + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoConvertFunctions) { + this->connect(); + + validate(this->conn, SQL_CONVERT_FUNCTIONS, static_cast(0)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoNumericFunctions) { + this->connect(); + + validate(this->conn, SQL_NUMERIC_FUNCTIONS, static_cast(4058942)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoStringFunctions) { + this->connect(); + + validate(this->conn, SQL_STRING_FUNCTIONS, + static_cast(SQL_FN_STR_LTRIM | SQL_FN_STR_LENGTH | + SQL_FN_STR_REPLACE | SQL_FN_STR_RTRIM)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoSystemFunctions) { + this->connect(); + + validate(this->conn, SQL_SYSTEM_FUNCTIONS, + static_cast(SQL_FN_SYS_IFNULL | SQL_FN_SYS_USERNAME)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoTimedateAddIntervals) { + this->connect(); + + validate(this->conn, SQL_TIMEDATE_ADD_INTERVALS, + static_cast(SQL_FN_TSI_FRAC_SECOND | SQL_FN_TSI_SECOND | + SQL_FN_TSI_MINUTE | SQL_FN_TSI_HOUR | SQL_FN_TSI_DAY | + SQL_FN_TSI_WEEK | SQL_FN_TSI_MONTH | + SQL_FN_TSI_QUARTER | SQL_FN_TSI_YEAR)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoTimedateDiffIntervals) { + this->connect(); + + validate(this->conn, SQL_TIMEDATE_DIFF_INTERVALS, + static_cast(SQL_FN_TSI_FRAC_SECOND | SQL_FN_TSI_SECOND | + SQL_FN_TSI_MINUTE | SQL_FN_TSI_HOUR | SQL_FN_TSI_DAY | + SQL_FN_TSI_WEEK | SQL_FN_TSI_MONTH | + SQL_FN_TSI_QUARTER | SQL_FN_TSI_YEAR)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoTimedateFunctions) { + this->connect(); + + validate(this->conn, SQL_TIMEDATE_FUNCTIONS, + static_cast( + SQL_FN_TD_CURRENT_DATE | SQL_FN_TD_CURRENT_TIME | + SQL_FN_TD_CURRENT_TIMESTAMP | SQL_FN_TD_CURDATE | SQL_FN_TD_CURTIME | + SQL_FN_TD_DAYNAME | SQL_FN_TD_DAYOFMONTH | SQL_FN_TD_DAYOFWEEK | + SQL_FN_TD_DAYOFYEAR | SQL_FN_TD_EXTRACT | SQL_FN_TD_HOUR | + SQL_FN_TD_MINUTE | SQL_FN_TD_MONTH | SQL_FN_TD_MONTHNAME | SQL_FN_TD_NOW | + SQL_FN_TD_QUARTER | SQL_FN_TD_SECOND | SQL_FN_TD_TIMESTAMPADD | + SQL_FN_TD_TIMESTAMPDIFF | SQL_FN_TD_WEEK | SQL_FN_TD_YEAR)); + + this->disconnect(); +} + +// Conversion Information + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoConvertBigint) { + this->connect(); + + validate(this->conn, SQL_CONVERT_BIGINT, static_cast(8)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoConvertBinary) { + this->connect(); + + validate(this->conn, SQL_CONVERT_BINARY, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoConvertBit) { + this->connect(); + + validate(this->conn, SQL_CONVERT_BIT, static_cast(0)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoConvertChar) { + this->connect(); + + validate(this->conn, SQL_CONVERT_CHAR, static_cast(0)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoConvertDate) { + this->connect(); + + validate(this->conn, SQL_CONVERT_DATE, static_cast(0)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoConvertDecimal) { + this->connect(); + + validate(this->conn, SQL_CONVERT_DECIMAL, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoConvertDouble) { + this->connect(); + + validate(this->conn, SQL_CONVERT_DOUBLE, static_cast(0)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoConvertFloat) { + this->connect(); + + validate(this->conn, SQL_CONVERT_FLOAT, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoConvertInteger) { + this->connect(); + + validate(this->conn, SQL_CONVERT_INTEGER, static_cast(0)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoConvertIntervalDayTime) { + this->connect(); + + validate(this->conn, SQL_CONVERT_INTERVAL_DAY_TIME, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoConvertIntervalYearMonth) { + this->connect(); + + validate(this->conn, SQL_CONVERT_INTERVAL_YEAR_MONTH, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoConvertLongvarbinary) { + this->connect(); + + validate(this->conn, SQL_CONVERT_LONGVARBINARY, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoConvertLongvarchar) { + this->connect(); + + validate(this->conn, SQL_CONVERT_LONGVARCHAR, static_cast(0)); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetInfoConvertNumeric) { + this->connect(); + + validate(this->conn, SQL_CONVERT_NUMERIC, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoConvertReal) { + this->connect(); + + validate(this->conn, SQL_CONVERT_REAL, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoConvertSmallint) { + this->connect(); + + validate(this->conn, SQL_CONVERT_SMALLINT, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoConvertTime) { + this->connect(); + + validate(this->conn, SQL_CONVERT_TIME, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoConvertTimestamp) { + this->connect(); + + validate(this->conn, SQL_CONVERT_TIMESTAMP, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoConvertTinyint) { + this->connect(); + + validate(this->conn, SQL_CONVERT_TINYINT, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoConvertVarbinary) { + this->connect(); + + validate(this->conn, SQL_CONVERT_VARBINARY, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetInfoConvertVarchar) { + this->connect(); + + validate(this->conn, SQL_CONVERT_VARCHAR, static_cast(0)); + + this->disconnect(); +} + +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc new file mode 100644 index 00000000000..5bf737d5518 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc @@ -0,0 +1,1026 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#ifdef _WIN32 +# include +#endif + +#include +#include +#include + +#include "google/protobuf/message_lite.h" +#include "gtest/gtest.h" + +namespace arrow::flight::sql::odbc { + +TEST(SQLAllocHandle, TestSQLAllocHandleEnv) { + // ODBC Environment + SQLHENV env; + + // Allocate an environment handle + SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env); + + EXPECT_TRUE(env != NULL); +} + +TEST(SQLAllocEnv, TestSQLAllocEnv) { + // ODBC Environment + SQLHENV env; + + // Allocate an environment handle + SQLRETURN return_value = SQLAllocEnv(&env); + + EXPECT_TRUE(return_value == SQL_SUCCESS); +} + +TEST(SQLAllocHandle, TestSQLAllocHandleConnect) { + // ODBC Environment + SQLHENV env; + SQLHDBC conn; + + // Allocate an environment handle + SQLRETURN return_value = SQLAllocEnv(&env); + + EXPECT_TRUE(return_value == SQL_SUCCESS); + + // Allocate a connection using alloc handle + SQLRETURN return_alloc_handle = SQLAllocHandle(SQL_HANDLE_DBC, env, &conn); + + EXPECT_TRUE(return_alloc_handle == SQL_SUCCESS); +} + +TEST(SQLAllocConnect, TestSQLAllocHandleConnect) { + // ODBC Environment + SQLHENV env; + SQLHDBC conn; + + // Allocate an environment handle + SQLRETURN return_value = SQLAllocEnv(&env); + + EXPECT_TRUE(return_value == SQL_SUCCESS); + + // Allocate a connection using alloc handle + SQLRETURN return_alloc_connect = SQLAllocConnect(env, &conn); + + EXPECT_TRUE(return_alloc_connect == SQL_SUCCESS); +} + +TEST(SQLFreeHandle, TestSQLFreeHandleEnv) { + // ODBC Environment + SQLHENV env; + + // Allocate an environment handle + SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env); + + // Free an environment handle + SQLRETURN return_value = SQLFreeHandle(SQL_HANDLE_ENV, env); + + EXPECT_TRUE(return_value == SQL_SUCCESS); +} + +TEST(SQLFreeEnv, TestSQLFreeEnv) { + // ODBC Environment + SQLHENV env; + + // Allocate an environment handle + SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env); + + // Free an environment handle + SQLRETURN return_value = SQLFreeEnv(env); + + EXPECT_TRUE(return_value == SQL_SUCCESS); +} + +TEST(SQLFreeHandle, TestSQLFreeHandleConnect) { + // ODBC Environment + SQLHENV env; + SQLHDBC conn; + + // Allocate an environment handle + SQLRETURN return_value = SQLAllocEnv(&env); + + EXPECT_TRUE(return_value == SQL_SUCCESS); + + // Allocate a connection using alloc handle + SQLRETURN return_alloc_handle = SQLAllocHandle(SQL_HANDLE_DBC, env, &conn); + + EXPECT_TRUE(return_alloc_handle == SQL_SUCCESS); + + // Free the created connection using free handle + SQLRETURN return_free_handle = SQLFreeHandle(SQL_HANDLE_DBC, conn); + + EXPECT_TRUE(return_free_handle == SQL_SUCCESS); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestFreeNullHandles) { + // Verifies attempt to free invalid handle does not cause segfault + // Attempt to free null statement handle + SQLRETURN ret = SQLFreeHandle(SQL_HANDLE_STMT, this->stmt); + + EXPECT_EQ(ret, SQL_INVALID_HANDLE); + + // Attempt to free null connection handle + ret = SQLFreeHandle(SQL_HANDLE_DBC, this->conn); + + EXPECT_EQ(ret, SQL_INVALID_HANDLE); + + // Attempt to free null environment handle + ret = SQLFreeHandle(SQL_HANDLE_ENV, this->env); + + EXPECT_EQ(ret, SQL_INVALID_HANDLE); +} + +TEST(SQLFreeConnect, TestSQLFreeConnect) { + // ODBC Environment + SQLHENV env; + SQLHDBC conn; + + // Allocate an environment handle + SQLRETURN return_env = SQLAllocEnv(&env); + + EXPECT_TRUE(return_env == SQL_SUCCESS); + + // Allocate a connection using alloc handle + SQLRETURN return_alloc_handle = SQLAllocHandle(SQL_HANDLE_DBC, env, &conn); + + EXPECT_TRUE(return_alloc_handle == SQL_SUCCESS); + + // Free the created connection using free connect + SQLRETURN return_free_connect = SQLFreeConnect(conn); + + EXPECT_TRUE(return_free_connect == SQL_SUCCESS); +} + +TEST(SQLGetEnvAttr, TestSQLGetEnvAttrODBCVersion) { + // ODBC Environment + SQLHENV env; + + SQLINTEGER version; + + // Allocate an environment handle + SQLRETURN return_env = SQLAllocEnv(&env); + + EXPECT_TRUE(return_env == SQL_SUCCESS); + + SQLRETURN return_get = SQLGetEnvAttr(env, SQL_ATTR_ODBC_VERSION, &version, 0, 0); + + EXPECT_TRUE(return_get == SQL_SUCCESS); + + EXPECT_EQ(version, SQL_OV_ODBC2); +} + +TEST(SQLSetEnvAttr, TestSQLSetEnvAttrODBCVersionValid) { + // ODBC Environment + SQLHENV env; + + // Allocate an environment handle + SQLRETURN return_env = SQLAllocEnv(&env); + + EXPECT_TRUE(return_env == SQL_SUCCESS); + + // Attempt to set to unsupported version + SQLRETURN return_set = + SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, reinterpret_cast(SQL_OV_ODBC2), 0); + + EXPECT_TRUE(return_set == SQL_SUCCESS); +} + +TEST(SQLSetEnvAttr, TestSQLSetEnvAttrODBCVersionInvalid) { + // ODBC Environment + SQLHENV env; + + // Allocate an environment handle + SQLRETURN return_env = SQLAllocEnv(&env); + + EXPECT_TRUE(return_env == SQL_SUCCESS); + + // Attempt to set to unsupported version + SQLRETURN return_set = + SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, reinterpret_cast(1), 0); + + EXPECT_TRUE(return_set == SQL_ERROR); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetEnvAttrOutputNTS) { + this->connect(); + + SQLINTEGER output_nts; + + SQLRETURN return_get = SQLGetEnvAttr(this->env, SQL_ATTR_OUTPUT_NTS, &output_nts, 0, 0); + + EXPECT_TRUE(return_get == SQL_SUCCESS); + + EXPECT_EQ(output_nts, SQL_TRUE); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetEnvAttrGetLength) { + // Test is disabled because call to SQLGetEnvAttr is handled by the driver manager on + // Windows. This test case can be potentially used on macOS/Linux + GTEST_SKIP(); + + this->connect(); + + SQLINTEGER length; + + SQLRETURN return_get = + SQLGetEnvAttr(this->env, SQL_ATTR_ODBC_VERSION, nullptr, 0, &length); + + EXPECT_TRUE(return_get == SQL_SUCCESS); + + EXPECT_EQ(length, sizeof(SQLINTEGER)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetEnvAttrNullValuePointer) { + // Test is disabled because call to SQLGetEnvAttr is handled by the driver manager on + // Windows. This test case can be potentially used on macOS/Linux + GTEST_SKIP(); + this->connect(); + + SQLRETURN return_get = + SQLGetEnvAttr(this->env, SQL_ATTR_ODBC_VERSION, nullptr, 0, nullptr); + + EXPECT_TRUE(return_get == SQL_ERROR); + + this->disconnect(); +} + +TEST(SQLSetEnvAttr, TestSQLSetEnvAttrOutputNTSValid) { + // ODBC Environment + SQLHENV env; + + // Allocate an environment handle + SQLRETURN return_env = SQLAllocEnv(&env); + + EXPECT_TRUE(return_env == SQL_SUCCESS); + + // Attempt to set to output nts to supported version + SQLRETURN return_set = + SQLSetEnvAttr(env, SQL_ATTR_OUTPUT_NTS, reinterpret_cast(SQL_TRUE), 0); + + EXPECT_TRUE(return_set == SQL_SUCCESS); +} + +TEST(SQLSetEnvAttr, TestSQLSetEnvAttrOutputNTSInvalid) { + // ODBC Environment + SQLHENV env; + + // Allocate an environment handle + SQLRETURN return_env = SQLAllocEnv(&env); + + EXPECT_TRUE(return_env == SQL_SUCCESS); + + // Attempt to set to output nts to unsupported false + SQLRETURN return_set = + SQLSetEnvAttr(env, SQL_ATTR_OUTPUT_NTS, reinterpret_cast(SQL_FALSE), 0); + + EXPECT_TRUE(return_set == SQL_ERROR); +} + +TEST(SQLSetEnvAttr, TestSQLSetEnvAttrNullValuePointer) { + // ODBC Environment + SQLHENV env; + + // Allocate an environment handle + SQLRETURN return_env = SQLAllocEnv(&env); + + EXPECT_TRUE(return_env == SQL_SUCCESS); + + // Attempt to set using bad data pointer + SQLRETURN return_set = SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, nullptr, 0); + + EXPECT_TRUE(return_set == SQL_ERROR); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLDriverConnect) { + // ODBC Environment + SQLHENV env; + SQLHDBC conn; + + // Allocate an environment handle + SQLRETURN ret = SQLAllocEnv(&env); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Allocate a connection using alloc handle + ret = SQLAllocHandle(SQL_HANDLE_DBC, env, &conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Connect string + std::string connect_str = this->getConnectionString(); + ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str, + arrow::util::UTF8ToWideString(connect_str)); + std::vector connect_str0(wconnect_str.begin(), wconnect_str.end()); + + SQLWCHAR outstr[ODBC_BUFFER_SIZE] = L""; + SQLSMALLINT outstrlen; + + // Connecting to ODBC server. + ret = SQLDriverConnect(conn, NULL, &connect_str0[0], + static_cast(connect_str0.size()), outstr, + ODBC_BUFFER_SIZE, &outstrlen, SQL_DRIVER_NOPROMPT); + + if (ret != SQL_SUCCESS) { + std::cerr << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn) << std::endl; + } + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check that outstr has same content as connect_str + std::string out_connection_string = ODBC::SqlWcharToString(outstr, outstrlen); + Connection::ConnPropertyMap out_properties; + Connection::ConnPropertyMap in_properties; + ODBC::ODBCConnection::getPropertiesFromConnString(out_connection_string, + out_properties); + ODBC::ODBCConnection::getPropertiesFromConnString(connect_str, in_properties); + EXPECT_TRUE(compareConnPropertyMap(out_properties, in_properties)); + + // Disconnect from ODBC + ret = SQLDisconnect(conn); + + if (ret != SQL_SUCCESS) { + std::cerr << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn) << std::endl; + } + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free connection handle + ret = SQLFreeHandle(SQL_HANDLE_DBC, conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free environment handle + ret = SQLFreeHandle(SQL_HANDLE_ENV, env); + + EXPECT_EQ(ret, SQL_SUCCESS); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLDriverConnectDsn) { + // ODBC Environment + SQLHENV env; + SQLHDBC conn; + + // Allocate an environment handle + SQLRETURN ret = SQLAllocEnv(&env); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Allocate a connection using alloc handle + ret = SQLAllocHandle(SQL_HANDLE_DBC, env, &conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Connect string + std::string connect_str = this->getConnectionString(); + + // Write connection string content into a DSN, + // must succeed before continuing + ASSERT_TRUE(writeDSN(connect_str)); + + std::string dsn(TEST_DSN); + ASSERT_OK_AND_ASSIGN(std::wstring wdsn, arrow::util::UTF8ToWideString(dsn)); + + // Update connection string to use DSN to connect + connect_str = std::string("DSN=") + std::string(TEST_DSN) + + std::string(";driver={Apache Arrow Flight SQL ODBC Driver};"); + ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str, + arrow::util::UTF8ToWideString(connect_str)); + std::vector connect_str0(wconnect_str.begin(), wconnect_str.end()); + + SQLWCHAR outstr[ODBC_BUFFER_SIZE] = L""; + SQLSMALLINT outstrlen; + + // Connecting to ODBC server. + ret = SQLDriverConnect(conn, NULL, &connect_str0[0], + static_cast(connect_str0.size()), outstr, + ODBC_BUFFER_SIZE, &outstrlen, SQL_DRIVER_NOPROMPT); + + if (ret != SQL_SUCCESS) { + std::cerr << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn) << std::endl; + } + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Remove DSN + EXPECT_TRUE(UnregisterDsn(wdsn)); + + // Disconnect from ODBC + ret = SQLDisconnect(conn); + + if (ret != SQL_SUCCESS) { + std::cerr << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn) << std::endl; + } + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free connection handle + ret = SQLFreeHandle(SQL_HANDLE_DBC, conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free environment handle + ret = SQLFreeHandle(SQL_HANDLE_ENV, env); + + EXPECT_EQ(ret, SQL_SUCCESS); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLDriverConnectInvalidUid) { + // ODBC Environment + SQLHENV env; + SQLHDBC conn; + + // Allocate an environment handle + SQLRETURN ret = SQLAllocEnv(&env); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Allocate a connection using alloc handle + ret = SQLAllocHandle(SQL_HANDLE_DBC, env, &conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Invalid connect string + std::string connect_str = getInvalidConnectionString(); + + ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str, + arrow::util::UTF8ToWideString(connect_str)); + std::vector connect_str0(wconnect_str.begin(), wconnect_str.end()); + + SQLWCHAR outstr[ODBC_BUFFER_SIZE]; + SQLSMALLINT outstrlen; + + // Connecting to ODBC server. + ret = SQLDriverConnect(conn, NULL, &connect_str0[0], + static_cast(connect_str0.size()), outstr, + ODBC_BUFFER_SIZE, &outstrlen, SQL_DRIVER_NOPROMPT); + + EXPECT_TRUE(ret == SQL_ERROR); + + VerifyOdbcErrorState(SQL_HANDLE_DBC, conn, error_state_28000); + + std::string out_connection_string = ODBC::SqlWcharToString(outstr, outstrlen); + EXPECT_TRUE(out_connection_string.empty()); + + // Free connection handle + ret = SQLFreeHandle(SQL_HANDLE_DBC, conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free environment handle + ret = SQLFreeHandle(SQL_HANDLE_ENV, env); + + EXPECT_EQ(ret, SQL_SUCCESS); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLConnect) { + // ODBC Environment + SQLHENV env; + SQLHDBC conn; + + // Allocate an environment handle + SQLRETURN ret = SQLAllocEnv(&env); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Allocate a connection using alloc handle + ret = SQLAllocHandle(SQL_HANDLE_DBC, env, &conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Connect string + std::string connect_str = this->getConnectionString(); + + // Write connection string content into a DSN, + // must succeed before continuing + std::string uid(""), pwd(""); + ASSERT_TRUE(writeDSN(connect_str)); + + std::string dsn(TEST_DSN); + ASSERT_OK_AND_ASSIGN(std::wstring wdsn, arrow::util::UTF8ToWideString(dsn)); + ASSERT_OK_AND_ASSIGN(std::wstring wuid, arrow::util::UTF8ToWideString(uid)); + ASSERT_OK_AND_ASSIGN(std::wstring wpwd, arrow::util::UTF8ToWideString(pwd)); + std::vector dsn0(wdsn.begin(), wdsn.end()); + std::vector uid0(wuid.begin(), wuid.end()); + std::vector pwd0(wpwd.begin(), wpwd.end()); + + // Connecting to ODBC server. Empty uid and pwd should be ignored. + ret = SQLConnect(conn, dsn0.data(), static_cast(dsn0.size()), uid0.data(), + static_cast(uid0.size()), pwd0.data(), + static_cast(pwd0.size())); + + if (ret != SQL_SUCCESS) { + std::cerr << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn) << std::endl; + } + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Remove DSN + EXPECT_TRUE(UnregisterDsn(wdsn)); + + // Disconnect from ODBC + ret = SQLDisconnect(conn); + + if (ret != SQL_SUCCESS) { + std::cerr << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn) << std::endl; + } + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free connection handle + ret = SQLFreeHandle(SQL_HANDLE_DBC, conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free environment handle + ret = SQLFreeHandle(SQL_HANDLE_ENV, env); + + EXPECT_EQ(ret, SQL_SUCCESS); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLConnectInputUidPwd) { + // ODBC Environment + SQLHENV env; + SQLHDBC conn; + + // Allocate an environment handle + SQLRETURN ret = SQLAllocEnv(&env); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Allocate a connection using alloc handle + ret = SQLAllocHandle(SQL_HANDLE_DBC, env, &conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Connect string + std::string connect_str = getConnectionString(); + + // Retrieve valid uid and pwd, assumes TEST_CONNECT_STR contains uid and pwd + Connection::ConnPropertyMap properties; + ODBC::ODBCConnection::getPropertiesFromConnString(connect_str, properties); + std::string uid_key("uid"); + std::string pwd_key("pwd"); + std::string uid = properties[uid_key]; + std::string pwd = properties[pwd_key]; + + // Write connection string content without uid and pwd into a DSN, + // must succeed before continuing + properties.erase(uid_key); + properties.erase(pwd_key); + ASSERT_TRUE(writeDSN(properties)); + + std::string dsn(TEST_DSN); + ASSERT_OK_AND_ASSIGN(std::wstring wdsn, arrow::util::UTF8ToWideString(dsn)); + ASSERT_OK_AND_ASSIGN(std::wstring wuid, arrow::util::UTF8ToWideString(uid)); + ASSERT_OK_AND_ASSIGN(std::wstring wpwd, arrow::util::UTF8ToWideString(pwd)); + std::vector dsn0(wdsn.begin(), wdsn.end()); + std::vector uid0(wuid.begin(), wuid.end()); + std::vector pwd0(wpwd.begin(), wpwd.end()); + + // Connecting to ODBC server. + ret = SQLConnect(conn, dsn0.data(), static_cast(dsn0.size()), uid0.data(), + static_cast(uid0.size()), pwd0.data(), + static_cast(pwd0.size())); + + if (ret != SQL_SUCCESS) { + std::cerr << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn) << std::endl; + } + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Remove DSN + EXPECT_TRUE(UnregisterDsn(wdsn)); + + // Disconnect from ODBC + ret = SQLDisconnect(conn); + + if (ret != SQL_SUCCESS) { + std::cerr << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn) << std::endl; + } + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free connection handle + ret = SQLFreeHandle(SQL_HANDLE_DBC, conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free environment handle + ret = SQLFreeHandle(SQL_HANDLE_ENV, env); + + EXPECT_EQ(ret, SQL_SUCCESS); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLConnectInvalidUid) { + // ODBC Environment + SQLHENV env; + SQLHDBC conn; + + // Allocate an environment handle + SQLRETURN ret = SQLAllocEnv(&env); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Allocate a connection using alloc handle + ret = SQLAllocHandle(SQL_HANDLE_DBC, env, &conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Connect string + std::string connect_str = getConnectionString(); + + // Retrieve valid uid and pwd, assumes TEST_CONNECT_STR contains uid and pwd + Connection::ConnPropertyMap properties; + ODBC::ODBCConnection::getPropertiesFromConnString(connect_str, properties); + std::string uid = properties[std::string("uid")]; + std::string pwd = properties[std::string("pwd")]; + + // Append invalid uid to connection string + connect_str += std::string("uid=non_existent_id;"); + + // Write connection string content into a DSN, + // must succeed before continuing + ASSERT_TRUE(writeDSN(connect_str)); + + std::string dsn(TEST_DSN); + ASSERT_OK_AND_ASSIGN(std::wstring wdsn, arrow::util::UTF8ToWideString(dsn)); + ASSERT_OK_AND_ASSIGN(std::wstring wuid, arrow::util::UTF8ToWideString(uid)); + ASSERT_OK_AND_ASSIGN(std::wstring wpwd, arrow::util::UTF8ToWideString(pwd)); + std::vector dsn0(wdsn.begin(), wdsn.end()); + std::vector uid0(wuid.begin(), wuid.end()); + std::vector pwd0(wpwd.begin(), wpwd.end()); + + // Connecting to ODBC server. + ret = SQLConnect(conn, dsn0.data(), static_cast(dsn0.size()), uid0.data(), + static_cast(uid0.size()), pwd0.data(), + static_cast(pwd0.size())); + + // UID specified in DSN will take precedence, + // so connection still fails despite passing valid uid in SQLConnect call + EXPECT_TRUE(ret == SQL_ERROR); + + VerifyOdbcErrorState(SQL_HANDLE_DBC, conn, error_state_28000); + + // Remove DSN + EXPECT_TRUE(UnregisterDsn(wdsn)); + + // Free connection handle + ret = SQLFreeHandle(SQL_HANDLE_DBC, conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free environment handle + ret = SQLFreeHandle(SQL_HANDLE_ENV, env); + + EXPECT_EQ(ret, SQL_SUCCESS); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLConnectDSNPrecedence) { + // ODBC Environment + SQLHENV env; + SQLHDBC conn; + + // Allocate an environment handle + SQLRETURN ret = SQLAllocEnv(&env); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Allocate a connection using alloc handle + ret = SQLAllocHandle(SQL_HANDLE_DBC, env, &conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Connect string + std::string connect_str = getConnectionString(); + + // Write connection string content into a DSN, + // must succeed before continuing + + // Pass incorrect uid and password to SQLConnect, they will be ignored. + // Assumes TEST_CONNECT_STR contains uid and pwd + std::string uid("non_existent_id"), pwd("non_existent_password"); + ASSERT_TRUE(writeDSN(connect_str)); + + std::string dsn(TEST_DSN); + ASSERT_OK_AND_ASSIGN(std::wstring wdsn, arrow::util::UTF8ToWideString(dsn)); + ASSERT_OK_AND_ASSIGN(std::wstring wuid, arrow::util::UTF8ToWideString(uid)); + ASSERT_OK_AND_ASSIGN(std::wstring wpwd, arrow::util::UTF8ToWideString(pwd)); + std::vector dsn0(wdsn.begin(), wdsn.end()); + std::vector uid0(wuid.begin(), wuid.end()); + std::vector pwd0(wpwd.begin(), wpwd.end()); + + // Connecting to ODBC server. + ret = SQLConnect(conn, dsn0.data(), static_cast(dsn0.size()), uid0.data(), + static_cast(uid0.size()), pwd0.data(), + static_cast(pwd0.size())); + + if (ret != SQL_SUCCESS) { + std::cerr << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn) << std::endl; + } + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Remove DSN + EXPECT_TRUE(UnregisterDsn(wdsn)); + + // Disconnect from ODBC + ret = SQLDisconnect(conn); + + if (ret != SQL_SUCCESS) { + std::cerr << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn) << std::endl; + } + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free connection handle + ret = SQLFreeHandle(SQL_HANDLE_DBC, conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free environment handle + ret = SQLFreeHandle(SQL_HANDLE_ENV, env); + + EXPECT_EQ(ret, SQL_SUCCESS); +} + +TEST(SQLDisconnect, TestSQLDisconnectWithoutConnection) { + // ODBC Environment + SQLHENV env; + SQLHDBC conn; + + // Allocate an environment handle + SQLRETURN ret = SQLAllocEnv(&env); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Allocate a connection using alloc handle + ret = SQLAllocHandle(SQL_HANDLE_DBC, env, &conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Attempt to disconnect without a connection, expect to fail + ret = SQLDisconnect(conn); + + EXPECT_TRUE(ret == SQL_ERROR); + + // Expect ODBC driver manager to return error state + VerifyOdbcErrorState(SQL_HANDLE_DBC, conn, error_state_08003); + + // Free connection handle + ret = SQLFreeHandle(SQL_HANDLE_DBC, conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free environment handle + ret = SQLFreeHandle(SQL_HANDLE_ENV, env); + + EXPECT_EQ(ret, SQL_SUCCESS); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestConnect) { + // Verifies connect and disconnect works on its own + this->connect(); + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLAllocFreeStmt) { + this->connect(); + SQLHSTMT statement; + + // Allocate a statement using alloc statement + SQLRETURN ret = SQLAllocStmt(this->conn, &statement); + + EXPECT_EQ(ret, SQL_SUCCESS); + + SQLWCHAR sql_buffer[ODBC_BUFFER_SIZE] = L"SELECT 1"; + ret = SQLExecDirect(statement, sql_buffer, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Close statement handle + ret = SQLFreeStmt(statement, SQL_CLOSE); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free statement handle + ret = SQLFreeStmt(statement, SQL_DROP); + + EXPECT_EQ(ret, SQL_SUCCESS); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestCloseConnectionWithOpenStatement) { + // ODBC Environment + SQLHENV env; + SQLHDBC conn; + SQLHSTMT statement; + + // Allocate an environment handle + SQLRETURN ret = SQLAllocEnv(&env); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Allocate a connection using alloc handle + ret = SQLAllocHandle(SQL_HANDLE_DBC, env, &conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Connect string + std::string connect_str = this->getConnectionString(); + ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str, + arrow::util::UTF8ToWideString(connect_str)); + std::vector connect_str0(wconnect_str.begin(), wconnect_str.end()); + + SQLWCHAR outstr[ODBC_BUFFER_SIZE] = L""; + SQLSMALLINT outstrlen; + + // Connecting to ODBC server. + ret = SQLDriverConnect(conn, NULL, &connect_str0[0], + static_cast(connect_str0.size()), outstr, + ODBC_BUFFER_SIZE, &outstrlen, SQL_DRIVER_NOPROMPT); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Allocate a statement using alloc statement + ret = SQLAllocStmt(conn, &statement); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Disconnect from ODBC without closing the statement first + ret = SQLDisconnect(conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free connection handle + ret = SQLFreeHandle(SQL_HANDLE_DBC, conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free environment handle + ret = SQLFreeHandle(SQL_HANDLE_ENV, env); + + EXPECT_EQ(ret, SQL_SUCCESS); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLAllocFreeDesc) { + this->connect(); + SQLHDESC descriptor; + + // Allocate a descriptor using alloc handle + SQLRETURN ret = SQLAllocHandle(SQL_HANDLE_DESC, this->conn, &descriptor); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free descriptor handle + ret = SQLFreeHandle(SQL_HANDLE_DESC, descriptor); + + EXPECT_EQ(ret, SQL_SUCCESS); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrDescriptor) { + this->connect(); + + SQLHDESC apd_descriptor, ard_descriptor; + + // Allocate an APD descriptor using alloc handle + SQLRETURN ret = SQLAllocHandle(SQL_HANDLE_DESC, this->conn, &apd_descriptor); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Allocate an ARD descriptor using alloc handle + ret = SQLAllocHandle(SQL_HANDLE_DESC, this->conn, &ard_descriptor); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Save implicitly allocated internal APD and ARD descriptor pointers + SQLPOINTER internal_apd, internal_ard = nullptr; + + ret = SQLGetStmtAttr(this->stmt, SQL_ATTR_APP_PARAM_DESC, &internal_apd, + sizeof(internal_apd), 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLGetStmtAttr(this->stmt, SQL_ATTR_APP_ROW_DESC, &internal_ard, + sizeof(internal_ard), 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Set APD descriptor to explicitly allocated handle + ret = SQLSetStmtAttr(this->stmt, SQL_ATTR_APP_PARAM_DESC, + reinterpret_cast(apd_descriptor), 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Set ARD descriptor to explicitly allocated handle + ret = SQLSetStmtAttr(this->stmt, SQL_ATTR_APP_ROW_DESC, + reinterpret_cast(ard_descriptor), 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Verify APD and ARD descriptors are set to explicitly allocated pointers + SQLPOINTER value = nullptr; + + ret = SQLGetStmtAttr(this->stmt, SQL_ATTR_APP_PARAM_DESC, &value, sizeof(value), 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(value, apd_descriptor); + + ret = SQLGetStmtAttr(this->stmt, SQL_ATTR_APP_ROW_DESC, &value, sizeof(value), 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(value, ard_descriptor); + + // Free explicitly allocated APD and ARD descriptor handles + ret = SQLFreeHandle(SQL_HANDLE_DESC, apd_descriptor); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFreeHandle(SQL_HANDLE_DESC, ard_descriptor); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Verify APD and ARD descriptors has been reverted to implicit descriptors + value = nullptr; + + ret = SQLGetStmtAttr(this->stmt, SQL_ATTR_APP_PARAM_DESC, &value, sizeof(value), 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(value, internal_apd); + + ret = SQLGetStmtAttr(this->stmt, SQL_ATTR_APP_ROW_DESC, &value, sizeof(value), 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(value, internal_ard); + + this->disconnect(); +} + +} // namespace arrow::flight::sql::odbc + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/cpp/src/arrow/flight/sql/odbc/tests/errors_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/errors_test.cc new file mode 100644 index 00000000000..276c16a113f --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/errors_test.cc @@ -0,0 +1,731 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#ifdef _WIN32 +# include +#endif + +#include +#include +#include + +#include "gtest/gtest.h" + +namespace arrow::flight::sql::odbc { + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetDiagFieldWForConnectFailure) { + // ODBC Environment + SQLHENV env; + SQLHDBC conn; + + // Allocate an environment handle + SQLRETURN ret = SQLAllocEnv(&env); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Allocate a connection using alloc handle + ret = SQLAllocHandle(SQL_HANDLE_DBC, env, &conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Invalid connect string + std::string connect_str = this->getInvalidConnectionString(); + + ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str, + arrow::util::UTF8ToWideString(connect_str)); + std::vector connect_str0(wconnect_str.begin(), wconnect_str.end()); + + SQLWCHAR outstr[ODBC_BUFFER_SIZE]; + SQLSMALLINT outstrlen; + + // Connecting to ODBC server. + ret = SQLDriverConnect(conn, NULL, &connect_str0[0], + static_cast(connect_str0.size()), outstr, + ODBC_BUFFER_SIZE, &outstrlen, SQL_DRIVER_NOPROMPT); + + EXPECT_EQ(ret, SQL_ERROR); + + // Retrieve all supported header level and record level data + SQLSMALLINT HEADER_LEVEL = 0; + SQLSMALLINT RECORD_1 = 1; + + // SQL_DIAG_NUMBER + SQLINTEGER diag_number; + SQLSMALLINT diag_number_length; + + ret = SQLGetDiagField(SQL_HANDLE_DBC, conn, HEADER_LEVEL, SQL_DIAG_NUMBER, &diag_number, + sizeof(SQLINTEGER), &diag_number_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(diag_number, 1); + + // SQL_DIAG_SERVER_NAME + SQLWCHAR server_name[ODBC_BUFFER_SIZE]; + SQLSMALLINT server_name_length; + + ret = SQLGetDiagField(SQL_HANDLE_DBC, conn, RECORD_1, SQL_DIAG_SERVER_NAME, server_name, + ODBC_BUFFER_SIZE, &server_name_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // SQL_DIAG_MESSAGE_TEXT + SQLWCHAR message_text[ODBC_BUFFER_SIZE]; + SQLSMALLINT message_text_length; + + ret = SQLGetDiagField(SQL_HANDLE_DBC, conn, RECORD_1, SQL_DIAG_MESSAGE_TEXT, + message_text, ODBC_BUFFER_SIZE, &message_text_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(message_text_length, 100); + + // SQL_DIAG_NATIVE + SQLINTEGER diag_native; + SQLSMALLINT diag_native_length; + + ret = SQLGetDiagField(SQL_HANDLE_DBC, conn, RECORD_1, SQL_DIAG_NATIVE, &diag_native, + sizeof(diag_native), &diag_native_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(diag_native, 200); + + // SQL_DIAG_SQLSTATE + const SQLSMALLINT sql_state_size = 6; + SQLWCHAR sql_state[sql_state_size]; + SQLSMALLINT sql_state_length; + ret = SQLGetDiagField(SQL_HANDLE_DBC, conn, RECORD_1, SQL_DIAG_SQLSTATE, sql_state, + sql_state_size * driver::odbcabstraction::GetSqlWCharSize(), + &sql_state_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(std::wstring(sql_state), std::wstring(L"28000")); + + // Free connection handle + ret = SQLFreeHandle(SQL_HANDLE_DBC, conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free environment handle + ret = SQLFreeHandle(SQL_HANDLE_ENV, env); + + EXPECT_EQ(ret, SQL_SUCCESS); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetDiagFieldWForConnectFailureNTS) { + // Test is disabled because driver manager on Windows does not pass through SQL_NTS + // This test case can be potentially used on macOS/Linux + GTEST_SKIP(); + // ODBC Environment + SQLHENV env; + SQLHDBC conn; + + // Allocate an environment handle + SQLRETURN ret = SQLAllocEnv(&env); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Allocate a connection using alloc handle + ret = SQLAllocHandle(SQL_HANDLE_DBC, env, &conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Invalid connect string + std::string connect_str = this->getInvalidConnectionString(); + + ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str, + arrow::util::UTF8ToWideString(connect_str)); + std::vector connect_str0(wconnect_str.begin(), wconnect_str.end()); + + SQLWCHAR outstr[ODBC_BUFFER_SIZE]; + SQLSMALLINT outstrlen; + + // Connecting to ODBC server. + ret = SQLDriverConnect(conn, NULL, &connect_str0[0], + static_cast(connect_str0.size()), outstr, + ODBC_BUFFER_SIZE, &outstrlen, SQL_DRIVER_NOPROMPT); + + EXPECT_EQ(ret, SQL_ERROR); + + // Retrieve all supported header level and record level data + SQLSMALLINT RECORD_1 = 1; + + // SQL_DIAG_MESSAGE_TEXT SQL_NTS + SQLWCHAR message_text[ODBC_BUFFER_SIZE]; + SQLSMALLINT message_text_length; + + message_text[ODBC_BUFFER_SIZE - 1] = '\0'; + + ret = SQLGetDiagField(SQL_HANDLE_DBC, conn, RECORD_1, SQL_DIAG_MESSAGE_TEXT, + message_text, SQL_NTS, &message_text_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(message_text_length, 100); + + // Free connection handle + ret = SQLFreeHandle(SQL_HANDLE_DBC, conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free environment handle + ret = SQLFreeHandle(SQL_HANDLE_ENV, env); + + EXPECT_EQ(ret, SQL_SUCCESS); +} + +TYPED_TEST(FlightSQLODBCTestBase, + TestSQLGetDiagFieldWForDescriptorFailureFromDriverManager) { + this->connect(); + SQLHDESC descriptor; + + // Allocate a descriptor using alloc handle + SQLRETURN ret = SQLAllocHandle(SQL_HANDLE_DESC, this->conn, &descriptor); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLGetDescField(descriptor, 1, SQL_DESC_DATETIME_INTERVAL_CODE, 0, 0, 0); + + EXPECT_EQ(ret, SQL_ERROR); + + // Retrieve all supported header level and record level data + SQLSMALLINT HEADER_LEVEL = 0; + SQLSMALLINT RECORD_1 = 1; + + // SQL_DIAG_NUMBER + SQLINTEGER diag_number; + SQLSMALLINT diag_number_length; + + ret = SQLGetDiagField(SQL_HANDLE_DESC, descriptor, HEADER_LEVEL, SQL_DIAG_NUMBER, + &diag_number, sizeof(SQLINTEGER), &diag_number_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(diag_number, 1); + + // SQL_DIAG_SERVER_NAME + SQLWCHAR server_name[ODBC_BUFFER_SIZE]; + SQLSMALLINT server_name_length; + + ret = SQLGetDiagField(SQL_HANDLE_DESC, descriptor, RECORD_1, SQL_DIAG_SERVER_NAME, + server_name, ODBC_BUFFER_SIZE, &server_name_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // SQL_DIAG_MESSAGE_TEXT + SQLWCHAR message_text[ODBC_BUFFER_SIZE]; + SQLSMALLINT message_text_length; + + ret = SQLGetDiagField(SQL_HANDLE_DESC, descriptor, RECORD_1, SQL_DIAG_MESSAGE_TEXT, + message_text, ODBC_BUFFER_SIZE, &message_text_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(message_text_length, 100); + + // SQL_DIAG_NATIVE + SQLINTEGER diag_native; + SQLSMALLINT diag_native_length; + + ret = SQLGetDiagField(SQL_HANDLE_DESC, descriptor, RECORD_1, SQL_DIAG_NATIVE, + &diag_native, sizeof(diag_native), &diag_native_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(diag_native, 0); + + // SQL_DIAG_SQLSTATE + const SQLSMALLINT sql_state_size = 6; + SQLWCHAR sql_state[sql_state_size]; + SQLSMALLINT sql_state_length; + ret = SQLGetDiagField( + SQL_HANDLE_DESC, descriptor, RECORD_1, SQL_DIAG_SQLSTATE, sql_state, + sql_state_size * driver::odbcabstraction::GetSqlWCharSize(), &sql_state_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(std::wstring(sql_state), std::wstring(L"IM001")); + + // Free descriptor handle + ret = SQLFreeHandle(SQL_HANDLE_DESC, descriptor); + + EXPECT_EQ(ret, SQL_SUCCESS); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, + TestSQLGetDiagRecForDescriptorFailureFromDriverManager) { + this->connect(); + SQLHDESC descriptor; + + // Allocate a descriptor using alloc handle + SQLRETURN ret = SQLAllocHandle(SQL_HANDLE_DESC, this->conn, &descriptor); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLGetDescField(descriptor, 1, SQL_DESC_DATETIME_INTERVAL_CODE, 0, 0, 0); + + EXPECT_EQ(ret, SQL_ERROR); + + SQLWCHAR sql_state[6]; + SQLINTEGER native_error; + SQLWCHAR message[ODBC_BUFFER_SIZE]; + SQLSMALLINT message_length; + + ret = SQLGetDiagRec(SQL_HANDLE_DESC, descriptor, 1, sql_state, &native_error, message, + ODBC_BUFFER_SIZE, &message_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(message_length, 60); + + EXPECT_EQ(native_error, 0); + + // API not implemented error from driver manager + EXPECT_EQ(std::wstring(sql_state), std::wstring(L"IM001")); + + EXPECT_TRUE(!std::wstring(message).empty()); + + // Free descriptor handle + ret = SQLFreeHandle(SQL_HANDLE_DESC, descriptor); + + EXPECT_EQ(ret, SQL_SUCCESS); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetDiagRecForConnectFailure) { + // ODBC Environment + SQLHENV env; + SQLHDBC conn; + + // Allocate an environment handle + SQLRETURN ret = SQLAllocEnv(&env); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Allocate a connection using alloc handle + ret = SQLAllocHandle(SQL_HANDLE_DBC, env, &conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Invalid connect string + std::string connect_str = this->getInvalidConnectionString(); + + ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str, + arrow::util::UTF8ToWideString(connect_str)); + std::vector connect_str0(wconnect_str.begin(), wconnect_str.end()); + + SQLWCHAR outstr[ODBC_BUFFER_SIZE]; + SQLSMALLINT outstrlen; + + // Connecting to ODBC server. + ret = SQLDriverConnect(conn, NULL, &connect_str0[0], + static_cast(connect_str0.size()), outstr, + ODBC_BUFFER_SIZE, &outstrlen, SQL_DRIVER_NOPROMPT); + + EXPECT_EQ(ret, SQL_ERROR); + + SQLWCHAR sql_state[6]; + SQLINTEGER native_error; + SQLWCHAR message[ODBC_BUFFER_SIZE]; + SQLSMALLINT message_length; + + ret = SQLGetDiagRec(SQL_HANDLE_DBC, conn, 1, sql_state, &native_error, message, + ODBC_BUFFER_SIZE, &message_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(message_length, 120); + + EXPECT_EQ(native_error, 200); + + EXPECT_EQ(std::wstring(sql_state), std::wstring(L"28000")); + + EXPECT_TRUE(!std::wstring(message).empty()); + + // Free connection handle + ret = SQLFreeHandle(SQL_HANDLE_DBC, conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free environment handle + ret = SQLFreeHandle(SQL_HANDLE_ENV, env); + + EXPECT_EQ(ret, SQL_SUCCESS); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetDiagRecInputData) { + // SQLGetDiagRec does not post diagnostic records for itself. + this->connect(); + + SQLWCHAR sql_state[6]; + SQLINTEGER native_error; + SQLWCHAR message[ODBC_BUFFER_SIZE]; + SQLSMALLINT message_length; + + // Pass invalid record number + SQLRETURN ret = SQLGetDiagRec(SQL_HANDLE_DBC, this->conn, 0, sql_state, &native_error, + message, ODBC_BUFFER_SIZE, &message_length); + + EXPECT_EQ(ret, SQL_ERROR); + + // Pass valid record number with null inputs + ret = SQLGetDiagRec(SQL_HANDLE_DBC, this->conn, 1, 0, 0, 0, 0, 0); + + EXPECT_EQ(ret, SQL_NO_DATA); + + // Invalid handle + ret = SQLGetDiagRec(0, 0, 0, 0, 0, 0, 0, 0); + + EXPECT_EQ(ret, SQL_INVALID_HANDLE); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLErrorInputData) { + // Test ODBC 2.0 API SQLError. Driver manager maps SQLError to SQLGetDiagRec. + // SQLError does not post diagnostic records for itself. + this->connect(); + + // Pass valid handles with null inputs + SQLRETURN ret = SQLError(this->env, 0, 0, 0, 0, 0, 0, 0); + + EXPECT_EQ(ret, SQL_NO_DATA); + + ret = SQLError(0, this->conn, 0, 0, 0, 0, 0, 0); + + EXPECT_EQ(ret, SQL_NO_DATA); + + ret = SQLError(0, 0, this->stmt, 0, 0, 0, 0, 0); + + EXPECT_EQ(ret, SQL_NO_DATA); + + // Invalid handle + ret = SQLError(0, 0, 0, 0, 0, 0, 0, 0); + + EXPECT_EQ(ret, SQL_INVALID_HANDLE); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLErrorEnvErrorFromDriverManager) { + // Test ODBC 2.0 API SQLError. + // Known Windows Driver Manager (DM) behavior: + // When application passes buffer length greater than SQL_MAX_MESSAGE_LENGTH (512), + // DM passes 512 as buffer length to SQLError. + this->connect(); + + // Attempt to set environment attribute after connection handle allocation + SQLRETURN ret = SQLSetEnvAttr(this->env, SQL_ATTR_ODBC_VERSION, + reinterpret_cast(SQL_OV_ODBC2), 0); + + EXPECT_EQ(ret, SQL_ERROR); + + SQLWCHAR sql_state[6] = {0}; + SQLINTEGER native_error = 0; + SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0}; + SQLSMALLINT message_length = 0; + ret = SQLError(this->env, 0, 0, sql_state, &native_error, message, + SQL_MAX_MESSAGE_LENGTH, &message_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(message_length, 50); + + EXPECT_EQ(native_error, 0); + + // Function sequence error state from driver manager + EXPECT_EQ(std::wstring(sql_state), std::wstring(L"HY010")); + + EXPECT_TRUE(!std::wstring(message).empty()); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLErrorConnError) { + // Test ODBC 2.0 API SQLError. + // Known Windows Driver Manager (DM) behavior: + // When application passes buffer length greater than SQL_MAX_MESSAGE_LENGTH (512), + // DM passes 512 as buffer length to SQLError. + this->connect(); + + // Attempt to set unsupported attribute + SQLRETURN ret = SQLGetConnectAttr(this->conn, SQL_ATTR_TXN_ISOLATION, 0, 0, 0); + + EXPECT_EQ(ret, SQL_ERROR); + + SQLWCHAR sql_state[6] = {0}; + SQLINTEGER native_error = 0; + SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0}; + SQLSMALLINT message_length = 0; + ret = SQLError(0, this->conn, 0, sql_state, &native_error, message, + SQL_MAX_MESSAGE_LENGTH, &message_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(message_length, 60); + + EXPECT_EQ(native_error, 100); + + // optional feature not supported error state + EXPECT_EQ(std::wstring(sql_state), std::wstring(L"HYC00")); + + EXPECT_TRUE(!std::wstring(message).empty()); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLErrorStmtError) { + // Test ODBC 2.0 API SQLError. + // Known Windows Driver Manager (DM) behavior: + // When application passes buffer length greater than SQL_MAX_MESSAGE_LENGTH (512), + // DM passes 512 as buffer length to SQLError. + this->connect(); + + std::wstring wsql = L"1"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + + EXPECT_EQ(ret, SQL_ERROR); + + SQLWCHAR sql_state[6] = {0}; + SQLINTEGER native_error = 0; + SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0}; + SQLSMALLINT message_length = 0; + ret = SQLError(0, 0, this->stmt, sql_state, &native_error, message, + SQL_MAX_MESSAGE_LENGTH, &message_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(message_length, 70); + + EXPECT_EQ(native_error, 100); + + EXPECT_EQ(std::wstring(sql_state), std::wstring(L"HY000")); + + EXPECT_TRUE(!std::wstring(message).empty()); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLErrorStmtWarning) { + // Test ODBC 2.0 API SQLError. + this->connect(); + + std::wstring wsql = L"SELECT 'VERY LONG STRING here' AS string_col;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + const int len = 17; + SQLCHAR char_val[len]; + SQLLEN buf_len = sizeof(SQLCHAR) * len; + SQLLEN ind; + + ret = SQLGetData(this->stmt, 1, SQL_C_CHAR, &char_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS_WITH_INFO); + + SQLWCHAR sql_state[6] = {0}; + SQLINTEGER native_error = 0; + SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0}; + SQLSMALLINT message_length = 0; + ret = SQLError(0, 0, this->stmt, sql_state, &native_error, message, + SQL_MAX_MESSAGE_LENGTH, &message_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(message_length, 50); + + EXPECT_EQ(native_error, 1000100); + + // Verify string truncation warning is reported + EXPECT_EQ(std::wstring(sql_state), std::wstring(L"01004")); + + EXPECT_TRUE(!std::wstring(message).empty()); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLErrorEnvErrorODBCVer2FromDriverManager) { + // Test ODBC 2.0 API SQLError with ODBC ver 2. + // Known Windows Driver Manager (DM) behavior: + // When application passes buffer length greater than SQL_MAX_MESSAGE_LENGTH (512), + // DM passes 512 as buffer length to SQLError. + this->connect(SQL_OV_ODBC2); + + // Attempt to set environment attribute after connection handle allocation + SQLRETURN ret = SQLSetEnvAttr(this->env, SQL_ATTR_ODBC_VERSION, + reinterpret_cast(SQL_OV_ODBC2), 0); + + EXPECT_EQ(ret, SQL_ERROR); + + SQLWCHAR sql_state[6] = {0}; + SQLINTEGER native_error = 0; + SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0}; + SQLSMALLINT message_length = 0; + ret = SQLError(this->env, 0, 0, sql_state, &native_error, message, + SQL_MAX_MESSAGE_LENGTH, &message_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(message_length, 50); + + EXPECT_EQ(native_error, 0); + + // Function sequence error state from driver manager + EXPECT_EQ(std::wstring(sql_state), std::wstring(L"S1010")); + + EXPECT_TRUE(!std::wstring(message).empty()); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLErrorConnErrorODBCVer2) { + // Test ODBC 2.0 API SQLError with ODBC ver 2. + // Known Windows Driver Manager (DM) behavior: + // When application passes buffer length greater than SQL_MAX_MESSAGE_LENGTH (512), + // DM passes 512 as buffer length to SQLError. + this->connect(SQL_OV_ODBC2); + + // Attempt to set unsupported attribute + SQLRETURN ret = SQLGetConnectAttr(this->conn, SQL_ATTR_TXN_ISOLATION, 0, 0, 0); + + EXPECT_EQ(ret, SQL_ERROR); + + SQLWCHAR sql_state[6] = {0}; + SQLINTEGER native_error = 0; + SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0}; + SQLSMALLINT message_length = 0; + ret = SQLError(0, this->conn, 0, sql_state, &native_error, message, + SQL_MAX_MESSAGE_LENGTH, &message_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(message_length, 60); + + EXPECT_EQ(native_error, 100); + + // optional feature not supported error state. Driver Manager maps state to S1C00 + EXPECT_EQ(std::wstring(sql_state), std::wstring(L"S1C00")); + + EXPECT_TRUE(!std::wstring(message).empty()); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLErrorStmtErrorODBCVer2) { + // Test ODBC 2.0 API SQLError with ODBC ver 2. + // Known Windows Driver Manager (DM) behavior: + // When application passes buffer length greater than SQL_MAX_MESSAGE_LENGTH (512), + // DM passes 512 as buffer length to SQLError. + this->connect(SQL_OV_ODBC2); + + std::wstring wsql = L"1"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + + EXPECT_EQ(ret, SQL_ERROR); + + SQLWCHAR sql_state[6] = {0}; + SQLINTEGER native_error = 0; + SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0}; + SQLSMALLINT message_length = 0; + ret = SQLError(0, 0, this->stmt, sql_state, &native_error, message, + SQL_MAX_MESSAGE_LENGTH, &message_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(message_length, 70); + + EXPECT_EQ(native_error, 100); + + // Driver Manager maps error state to S1000 + EXPECT_EQ(std::wstring(sql_state), std::wstring(L"S1000")); + + EXPECT_TRUE(!std::wstring(message).empty()); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLErrorStmtWarningODBCVer2) { + // Test ODBC 2.0 API SQLError. + this->connect(SQL_OV_ODBC2); + + std::wstring wsql = L"SELECT 'VERY LONG STRING here' AS string_col;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + const int len = 17; + SQLCHAR char_val[len]; + SQLLEN buf_len = sizeof(SQLCHAR) * len; + SQLLEN ind; + + ret = SQLGetData(this->stmt, 1, SQL_C_CHAR, &char_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS_WITH_INFO); + + SQLWCHAR sql_state[6] = {0}; + SQLINTEGER native_error = 0; + SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0}; + SQLSMALLINT message_length = 0; + ret = SQLError(0, 0, this->stmt, sql_state, &native_error, message, + SQL_MAX_MESSAGE_LENGTH, &message_length); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(message_length, 50); + + EXPECT_EQ(native_error, 1000100); + + // Verify string truncation warning is reported + EXPECT_EQ(std::wstring(sql_state), std::wstring(L"01004")); + + EXPECT_TRUE(!std::wstring(message).empty()); +} + +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/get_functions_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/get_functions_test.cc new file mode 100644 index 00000000000..3eb54e49195 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/get_functions_test.cc @@ -0,0 +1,240 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#ifdef _WIN32 +# include +#endif + +#include +#include +#include + +#include "gtest/gtest.h" + +namespace arrow::flight::sql::odbc { + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetFunctionsAllFunctions) { + // Verify driver manager return values for SQLGetFunctions + this->connect(); + + SQLUSMALLINT api_exists[SQL_API_ODBC3_ALL_FUNCTIONS_SIZE]; + const std::vector supported_functions = { + SQL_API_SQLALLOCHANDLE, SQL_API_SQLBINDCOL, SQL_API_SQLGETDIAGFIELD, + SQL_API_SQLCANCEL, SQL_API_SQLCLOSECURSOR, SQL_API_SQLGETDIAGREC, + SQL_API_SQLCOLATTRIBUTE, SQL_API_SQLGETENVATTR, SQL_API_SQLCONNECT, + SQL_API_SQLGETINFO, SQL_API_SQLGETSTMTATTR, SQL_API_SQLDESCRIBECOL, + SQL_API_SQLGETTYPEINFO, SQL_API_SQLDISCONNECT, SQL_API_SQLNUMRESULTCOLS, + SQL_API_SQLPREPARE, SQL_API_SQLEXECDIRECT, SQL_API_SQLEXECUTE, SQL_API_SQLROWCOUNT, + SQL_API_SQLFETCH, SQL_API_SQLSETCONNECTATTR, SQL_API_SQLFETCHSCROLL, + SQL_API_SQLFREEHANDLE, SQL_API_SQLFREESTMT, SQL_API_SQLGETCONNECTATTR, + SQL_API_SQLSETENVATTR, SQL_API_SQLSETSTMTATTR, SQL_API_SQLGETDATA, + SQL_API_SQLCOLUMNS, SQL_API_SQLTABLES, SQL_API_SQLNATIVESQL, + SQL_API_SQLDRIVERCONNECT, SQL_API_SQLMORERESULTS, SQL_API_SQLPRIMARYKEYS, + SQL_API_SQLFOREIGNKEYS, + + // ODBC 2.0 APIs + SQL_API_SQLSETSTMTOPTION, SQL_API_SQLGETSTMTOPTION, SQL_API_SQLSETCONNECTOPTION, + SQL_API_SQLGETCONNECTOPTION, SQL_API_SQLALLOCCONNECT, SQL_API_SQLALLOCENV, + SQL_API_SQLALLOCSTMT, SQL_API_SQLFREEENV, SQL_API_SQLFREECONNECT, + + // Driver Manager APIs + SQL_API_SQLGETFUNCTIONS, SQL_API_SQLDRIVERS, SQL_API_SQLDATASOURCES}; + const std::vector unsupported_functions = { + SQL_API_SQLPUTDATA, SQL_API_SQLGETDESCFIELD, SQL_API_SQLGETDESCREC, + SQL_API_SQLCOPYDESC, SQL_API_SQLPARAMDATA, SQL_API_SQLENDTRAN, + SQL_API_SQLSETCURSORNAME, SQL_API_SQLSETDESCFIELD, SQL_API_SQLSETDESCREC, + SQL_API_SQLGETCURSORNAME, SQL_API_SQLSTATISTICS, SQL_API_SQLSPECIALCOLUMNS, + SQL_API_SQLBINDPARAMETER, SQL_API_SQLBROWSECONNECT, SQL_API_SQLNUMPARAMS, + SQL_API_SQLBULKOPERATIONS, SQL_API_SQLCOLUMNPRIVILEGES, SQL_API_SQLPROCEDURECOLUMNS, + SQL_API_SQLDESCRIBEPARAM, SQL_API_SQLPROCEDURES, SQL_API_SQLSETPOS, + SQL_API_SQLTABLEPRIVILEGES}; + SQLRETURN ret = SQLGetFunctions(this->conn, SQL_API_ODBC3_ALL_FUNCTIONS, api_exists); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (int api : supported_functions) { + EXPECT_EQ(SQL_FUNC_EXISTS(api_exists, api), SQL_TRUE); + } + + for (int api : unsupported_functions) { + EXPECT_EQ(SQL_FUNC_EXISTS(api_exists, api), SQL_FALSE); + } + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetFunctionsAllFunctionsODBCVer2) { + // Verify driver manager return values for SQLGetFunctions + this->connect(SQL_OV_ODBC2); + + // ODBC 2.0 SQLGetFunctions returns 100 elements according to spec + SQLUSMALLINT api_exists[100]; + const std::vector supported_functions = { + SQL_API_SQLCONNECT, SQL_API_SQLGETINFO, SQL_API_SQLDESCRIBECOL, + SQL_API_SQLGETTYPEINFO, SQL_API_SQLDISCONNECT, SQL_API_SQLNUMRESULTCOLS, + SQL_API_SQLPREPARE, SQL_API_SQLEXECDIRECT, SQL_API_SQLEXECUTE, SQL_API_SQLROWCOUNT, + SQL_API_SQLFETCH, SQL_API_SQLFREESTMT, SQL_API_SQLGETDATA, SQL_API_SQLCOLUMNS, + SQL_API_SQLTABLES, SQL_API_SQLNATIVESQL, SQL_API_SQLDRIVERCONNECT, + SQL_API_SQLMORERESULTS, SQL_API_SQLSETSTMTOPTION, SQL_API_SQLGETSTMTOPTION, + SQL_API_SQLSETCONNECTOPTION, SQL_API_SQLGETCONNECTOPTION, SQL_API_SQLALLOCCONNECT, + SQL_API_SQLALLOCENV, SQL_API_SQLALLOCSTMT, SQL_API_SQLFREEENV, + SQL_API_SQLFREECONNECT, SQL_API_SQLPRIMARYKEYS, SQL_API_SQLFOREIGNKEYS, + + // Driver Manager APIs + SQL_API_SQLGETFUNCTIONS, SQL_API_SQLDRIVERS, SQL_API_SQLDATASOURCES}; + const std::vector unsupported_functions = { + SQL_API_SQLPUTDATA, SQL_API_SQLPARAMDATA, SQL_API_SQLSETCURSORNAME, + SQL_API_SQLGETCURSORNAME, SQL_API_SQLSTATISTICS, SQL_API_SQLSPECIALCOLUMNS, + SQL_API_SQLBINDPARAMETER, SQL_API_SQLBROWSECONNECT, SQL_API_SQLNUMPARAMS, + SQL_API_SQLBULKOPERATIONS, SQL_API_SQLCOLUMNPRIVILEGES, SQL_API_SQLPROCEDURECOLUMNS, + SQL_API_SQLDESCRIBEPARAM, SQL_API_SQLPROCEDURES, SQL_API_SQLSETPOS, + SQL_API_SQLTABLEPRIVILEGES}; + SQLRETURN ret = SQLGetFunctions(this->conn, SQL_API_ALL_FUNCTIONS, api_exists); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (int api : supported_functions) { + EXPECT_EQ(api_exists[api], SQL_TRUE); + } + + for (int api : unsupported_functions) { + EXPECT_EQ(api_exists[api], SQL_FALSE); + } + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetFunctionsSupportedSingleAPI) { + this->connect(); + + const std::vector supported_functions = { + SQL_API_SQLALLOCHANDLE, SQL_API_SQLBINDCOL, SQL_API_SQLGETDIAGFIELD, + SQL_API_SQLCANCEL, SQL_API_SQLCLOSECURSOR, SQL_API_SQLGETDIAGREC, + SQL_API_SQLCOLATTRIBUTE, SQL_API_SQLGETENVATTR, SQL_API_SQLCONNECT, + SQL_API_SQLGETINFO, SQL_API_SQLGETSTMTATTR, SQL_API_SQLDESCRIBECOL, + SQL_API_SQLGETTYPEINFO, SQL_API_SQLDISCONNECT, SQL_API_SQLNUMRESULTCOLS, + SQL_API_SQLPREPARE, SQL_API_SQLEXECDIRECT, SQL_API_SQLEXECUTE, SQL_API_SQLROWCOUNT, + SQL_API_SQLFETCH, SQL_API_SQLSETCONNECTATTR, SQL_API_SQLFETCHSCROLL, + SQL_API_SQLFREEHANDLE, SQL_API_SQLFREESTMT, SQL_API_SQLGETCONNECTATTR, + SQL_API_SQLSETENVATTR, SQL_API_SQLSETSTMTATTR, SQL_API_SQLGETDATA, + SQL_API_SQLCOLUMNS, SQL_API_SQLTABLES, SQL_API_SQLNATIVESQL, + SQL_API_SQLDRIVERCONNECT, SQL_API_SQLMORERESULTS, SQL_API_SQLPRIMARYKEYS, + SQL_API_SQLFOREIGNKEYS, + + // ODBC 2.0 APIs + SQL_API_SQLSETSTMTOPTION, SQL_API_SQLGETSTMTOPTION, SQL_API_SQLSETCONNECTOPTION, + SQL_API_SQLGETCONNECTOPTION, SQL_API_SQLALLOCCONNECT, SQL_API_SQLALLOCENV, + SQL_API_SQLALLOCSTMT, SQL_API_SQLFREEENV, SQL_API_SQLFREECONNECT, + + // Driver Manager APIs + SQL_API_SQLGETFUNCTIONS, SQL_API_SQLDRIVERS, SQL_API_SQLDATASOURCES}; + SQLUSMALLINT api_exists; + for (SQLUSMALLINT api : supported_functions) { + SQLRETURN ret = SQLGetFunctions(this->conn, api, &api_exists); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(api_exists, SQL_TRUE); + + api_exists = -1; + } + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetFunctionsUnsupportedSingleAPI) { + this->connect(); + + const std::vector unsupported_functions = { + SQL_API_SQLPUTDATA, SQL_API_SQLGETDESCFIELD, SQL_API_SQLGETDESCREC, + SQL_API_SQLCOPYDESC, SQL_API_SQLPARAMDATA, SQL_API_SQLENDTRAN, + SQL_API_SQLSETCURSORNAME, SQL_API_SQLSETDESCFIELD, SQL_API_SQLSETDESCREC, + SQL_API_SQLGETCURSORNAME, SQL_API_SQLSTATISTICS, SQL_API_SQLSPECIALCOLUMNS, + SQL_API_SQLBINDPARAMETER, SQL_API_SQLBROWSECONNECT, SQL_API_SQLNUMPARAMS, + SQL_API_SQLBULKOPERATIONS, SQL_API_SQLCOLUMNPRIVILEGES, SQL_API_SQLPROCEDURECOLUMNS, + SQL_API_SQLDESCRIBEPARAM, SQL_API_SQLPROCEDURES, SQL_API_SQLSETPOS, + SQL_API_SQLTABLEPRIVILEGES}; + SQLUSMALLINT api_exists; + for (SQLUSMALLINT api : unsupported_functions) { + SQLRETURN ret = SQLGetFunctions(this->conn, api, &api_exists); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(api_exists, SQL_FALSE); + + api_exists = -1; + } + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetFunctionsSupportedSingleAPIODBCVer2) { + this->connect(SQL_OV_ODBC2); + + const std::vector supported_functions = { + SQL_API_SQLCONNECT, SQL_API_SQLGETINFO, SQL_API_SQLDESCRIBECOL, + SQL_API_SQLGETTYPEINFO, SQL_API_SQLDISCONNECT, SQL_API_SQLNUMRESULTCOLS, + SQL_API_SQLPREPARE, SQL_API_SQLEXECDIRECT, SQL_API_SQLEXECUTE, SQL_API_SQLROWCOUNT, + SQL_API_SQLFETCH, SQL_API_SQLFREESTMT, SQL_API_SQLGETDATA, SQL_API_SQLCOLUMNS, + SQL_API_SQLTABLES, SQL_API_SQLNATIVESQL, SQL_API_SQLDRIVERCONNECT, + SQL_API_SQLMORERESULTS, SQL_API_SQLSETSTMTOPTION, SQL_API_SQLGETSTMTOPTION, + SQL_API_SQLSETCONNECTOPTION, SQL_API_SQLGETCONNECTOPTION, SQL_API_SQLALLOCCONNECT, + SQL_API_SQLALLOCENV, SQL_API_SQLALLOCSTMT, SQL_API_SQLFREEENV, + SQL_API_SQLFREECONNECT, SQL_API_SQLPRIMARYKEYS, SQL_API_SQLFOREIGNKEYS, + + // Driver Manager APIs + SQL_API_SQLGETFUNCTIONS, SQL_API_SQLDRIVERS, SQL_API_SQLDATASOURCES}; + SQLUSMALLINT api_exists; + for (SQLUSMALLINT api : supported_functions) { + SQLRETURN ret = SQLGetFunctions(this->conn, api, &api_exists); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(api_exists, SQL_TRUE); + + api_exists = -1; + } + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetFunctionsUnsupportedSingleAPIODBCVer2) { + this->connect(SQL_OV_ODBC2); + + const std::vector unsupported_functions = { + SQL_API_SQLPUTDATA, SQL_API_SQLPARAMDATA, SQL_API_SQLSETCURSORNAME, + SQL_API_SQLGETCURSORNAME, SQL_API_SQLSTATISTICS, SQL_API_SQLSPECIALCOLUMNS, + SQL_API_SQLBINDPARAMETER, SQL_API_SQLBROWSECONNECT, SQL_API_SQLNUMPARAMS, + SQL_API_SQLBULKOPERATIONS, SQL_API_SQLCOLUMNPRIVILEGES, SQL_API_SQLPROCEDURECOLUMNS, + SQL_API_SQLDESCRIBEPARAM, SQL_API_SQLPROCEDURES, SQL_API_SQLSETPOS, + SQL_API_SQLTABLEPRIVILEGES}; + SQLUSMALLINT api_exists; + for (SQLUSMALLINT api : unsupported_functions) { + SQLRETURN ret = SQLGetFunctions(this->conn, api, &api_exists); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(api_exists, SQL_FALSE); + + api_exists = -1; + } + + this->disconnect(); +} + +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc new file mode 100644 index 00000000000..bb92bee0713 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc @@ -0,0 +1,489 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// For DSN registration. flight_sql_connection.h needs to included first due to conflicts +// with windows.h +#include "arrow/flight/sql/odbc/flight_sql/flight_sql_connection.h" + +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +// For DSN registration +#include "arrow/flight/sql/odbc/flight_sql/include/flight_sql/config/configuration.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_connection.h" + +namespace arrow::flight::sql::odbc { + +void FlightSQLODBCRemoteTestBase::allocEnvConnHandles(SQLINTEGER odbc_ver) { + // Allocate an environment handle + SQLRETURN ret = SQLAllocEnv(&env); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, + reinterpret_cast(static_cast(odbc_ver)), 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Allocate a connection using alloc handle + ret = SQLAllocHandle(SQL_HANDLE_DBC, env, &conn); + + EXPECT_EQ(ret, SQL_SUCCESS); +} + +void FlightSQLODBCRemoteTestBase::connect(SQLINTEGER odbc_ver) { + allocEnvConnHandles(odbc_ver); + std::string connect_str = getConnectionString(); + connectWithString(connect_str); +} + +void FlightSQLODBCRemoteTestBase::connectWithString(std::string connect_str) { + // Connect string + std::vector connect_str0(connect_str.begin(), connect_str.end()); + + SQLWCHAR outstr[ODBC_BUFFER_SIZE]; + SQLSMALLINT outstrlen; + + // Connecting to ODBC server. + SQLRETURN ret = SQLDriverConnect(conn, NULL, &connect_str0[0], + static_cast(connect_str0.size()), outstr, + ODBC_BUFFER_SIZE, &outstrlen, SQL_DRIVER_NOPROMPT); + + if (ret != SQL_SUCCESS) { + std::cerr << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn) << std::endl; + } + + // Assert connection is successful before we continue + ASSERT_TRUE(ret == SQL_SUCCESS); + + // Allocate a statement using alloc handle + ret = SQLAllocHandle(SQL_HANDLE_STMT, conn, &stmt); + + ASSERT_TRUE(ret == SQL_SUCCESS); +} + +void FlightSQLODBCRemoteTestBase::disconnect() { + // Close statement + SQLRETURN ret = SQLFreeHandle(SQL_HANDLE_STMT, stmt); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Disconnect from ODBC + ret = SQLDisconnect(conn); + + if (ret != SQL_SUCCESS) { + std::cerr << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn) << std::endl; + } + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free connection handle + ret = SQLFreeHandle(SQL_HANDLE_DBC, conn); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Free environment handle + ret = SQLFreeHandle(SQL_HANDLE_ENV, env); + + EXPECT_EQ(ret, SQL_SUCCESS); +} + +std::string FlightSQLODBCRemoteTestBase::getConnectionString() { + std::string connect_str = arrow::internal::GetEnvVar(TEST_CONNECT_STR).ValueOrDie(); + return connect_str; +} + +std::string FlightSQLODBCRemoteTestBase::getInvalidConnectionString() { + std::string connect_str = getConnectionString(); + // Append invalid uid to connection string + connect_str += std::string("uid=non_existent_id;"); + return connect_str; +} + +std::wstring FlightSQLODBCRemoteTestBase::getQueryAllDataTypes() { + std::wstring wsql = + LR"( SELECT + -- Numeric types + -128 as stiny_int_min, 127 as stiny_int_max, + 0 as utiny_int_min, 255 as utiny_int_max, + + -32768 as ssmall_int_min, 32767 as ssmall_int_max, + 0 as usmall_int_min, 65535 as usmall_int_max, + + CAST(-2147483648 AS INTEGER) AS sinteger_min, + CAST(2147483647 AS INTEGER) AS sinteger_max, + CAST(0 AS BIGINT) AS uinteger_min, + CAST(4294967295 AS BIGINT) AS uinteger_max, + + CAST(-9223372036854775808 AS BIGINT) AS sbigint_min, + CAST(9223372036854775807 AS BIGINT) AS sbigint_max, + CAST(0 AS BIGINT) AS ubigint_min, + --Use string to represent unsigned big int due to lack of support from + --remote test server + '18446744073709551615' AS ubigint_max, + + CAST(-999999999 AS DECIMAL(38, 0)) AS decimal_negative, + CAST(999999999 AS DECIMAL(38, 0)) AS decimal_positive, + + CAST(-3.40282347E38 AS FLOAT) AS float_min, CAST(3.40282347E38 AS FLOAT) AS float_max, + + CAST(-1.7976931348623157E308 AS DOUBLE) AS double_min, + CAST(1.7976931348623157E308 AS DOUBLE) AS double_max, + + --Boolean + CAST(false AS BOOLEAN) AS bit_false, + CAST(true AS BOOLEAN) AS bit_true, + + --Character types + 'Z' AS c_char, '你' AS c_wchar, + + '你好' AS c_wvarchar, + + 'XYZ' AS c_varchar, + + --Date / timestamp + CAST(DATE '1400-01-01' AS DATE) AS date_min, + CAST(DATE '9999-12-31' AS DATE) AS date_max, + + CAST(TIMESTAMP '1400-01-01 00:00:00' AS TIMESTAMP) AS timestamp_min, + CAST(TIMESTAMP '9999-12-31 23:59:59' AS TIMESTAMP) AS timestamp_max; + )"; + return wsql; +} + +void FlightSQLODBCRemoteTestBase::SetUp() { + if (arrow::internal::GetEnvVar(TEST_CONNECT_STR).ValueOr("").empty()) { + GTEST_SKIP() << "Skipping FlightSQLODBCRemoteTestBase test: TEST_CONNECT_STR not set"; + } +} + +std::string FindTokenInCallHeaders(const CallHeaders& incoming_headers) { + // Lambda function to compare characters without case sensitivity. + auto char_compare = [](const char& char1, const char& char2) { + return (::toupper(char1) == ::toupper(char2)); + }; + + std::string bearer_token(""); + auto authHeader = incoming_headers.find(kAuthHeader); + if (authHeader != incoming_headers.end()) { + const std::string auth_val(authHeader->second); + if (auth_val.size() > kBearerPrefix.length()) { + if (std::equal(auth_val.begin(), auth_val.begin() + kBearerPrefix.length(), + kBearerPrefix.begin(), char_compare)) { + bearer_token = auth_val.substr(kBearerPrefix.length()); + } + } + } + return bearer_token; +} + +void MockServerMiddleware::SendingHeaders(AddCallHeaders* outgoing_headers) { + std::string bearer_token = FindTokenInCallHeaders(incoming_headers_); + *isValid_ = (bearer_token == std::string(test_token)); +} + +Status MockServerMiddlewareFactory::StartCall( + const CallInfo& info, const ServerCallContext& context, + std::shared_ptr* middleware) { + std::string bearer_token = FindTokenInCallHeaders(context.incoming_headers()); + if (bearer_token == std::string(test_token)) { + *middleware = + std::make_shared(context.incoming_headers(), &isValid_); + } else { + return MakeFlightError(FlightStatusCode::Unauthenticated, + "Invalid token for mock server"); + } + + return Status::OK(); +} + +std::string FlightSQLODBCMockTestBase::getConnectionString() { + std::string connect_str( + "driver={Apache Arrow Flight SQL ODBC Driver};HOST=localhost;port=" + + std::to_string(port) + ";token=" + std::string(test_token) + + ";useEncryption=false;"); + return connect_str; +} + +std::string FlightSQLODBCMockTestBase::getInvalidConnectionString() { + std::string connect_str = getConnectionString(); + // Append invalid token to connection string + connect_str += std::string("token=invalid_token;"); + return connect_str; +} + +std::wstring FlightSQLODBCMockTestBase::getQueryAllDataTypes() { + std::wstring wsql = + LR"( SELECT + -- Numeric types + -128 AS stiny_int_min, 127 AS stiny_int_max, + 0 AS utiny_int_min, 255 AS utiny_int_max, + + -32768 AS ssmall_int_min, 32767 AS ssmall_int_max, + 0 AS usmall_int_min, 65535 AS usmall_int_max, + + CAST(-2147483648 AS INTEGER) AS sinteger_min, + CAST(2147483647 AS INTEGER) AS sinteger_max, + CAST(0 AS INTEGER) AS uinteger_min, + CAST(4294967295 AS INTEGER) AS uinteger_max, + + CAST(-9223372036854775808 AS INTEGER) AS sbigint_min, + CAST(9223372036854775807 AS INTEGER) AS sbigint_max, + CAST(0 AS INTEGER) AS ubigint_min, + -- stored as TEXT as SQLite doesn't support unsigned big int + '18446744073709551615' AS ubigint_max, + + CAST('-999999999' AS NUMERIC) AS decimal_negative, + CAST('999999999' AS NUMERIC) AS decimal_positive, + + CAST(-3.40282347E38 AS REAL) AS float_min, + CAST(3.40282347E38 AS REAL) AS float_max, + + CAST(-1.7976931348623157E308 AS REAL) AS double_min, + CAST(1.7976931348623157E308 AS REAL) AS double_max, + + -- Boolean + 0 AS bit_false, + 1 AS bit_true, + + -- Character types + 'Z' AS c_char, + '你' AS c_wchar, + '你好' AS c_wvarchar, + 'XYZ' AS c_varchar, + + DATE('1400-01-01') AS date_min, + DATE('9999-12-31') AS date_max, + + DATETIME('1400-01-01 00:00:00') AS timestamp_min, + DATETIME('9999-12-31 23:59:59') AS timestamp_max; + )"; + return wsql; +} + +void FlightSQLODBCMockTestBase::CreateTestTables() { + ASSERT_OK(server->ExecuteSql(R"( + CREATE TABLE TestTable ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + keyName varchar(100), + value int); + + INSERT INTO TestTable (keyName, value) VALUES ('One', 1); + INSERT INTO TestTable (keyName, value) VALUES ('Two', 0); + INSERT INTO TestTable (keyName, value) VALUES ('Three', -1); + )")); +} + +void FlightSQLODBCMockTestBase::CreateTableAllDataType() { + // Limitation on mock SQLite server: + // Only int64, float64, binary, and utf8 Arrow Types are supported by + // SQLiteFlightSqlServer::Impl::DoGetTables + ASSERT_OK(server->ExecuteSql(R"( + CREATE TABLE AllTypesTable( + bigint_col INTEGER PRIMARY KEY AUTOINCREMENT, + char_col varchar(100), + varbinary_col BLOB, + double_col REAL); + + INSERT INTO AllTypesTable ( + char_col, + varbinary_col, + double_col) VALUES ( + '1st Row', + X'31737420726F77', + 3.14159 + ); + )")); +} + +void FlightSQLODBCMockTestBase::CreateUnicodeTable() { + std::string unicodeSql = arrow::util::WideStringToUTF8( + LR"( + CREATE TABLE 数据( + 资料 varchar(100)); + + INSERT INTO 数据 (资料) VALUES ('第一行'); + INSERT INTO 数据 (资料) VALUES ('二行'); + INSERT INTO 数据 (资料) VALUES ('3rd Row'); + )") + .ValueOr(""); + ASSERT_OK(server->ExecuteSql(unicodeSql)); +} + +void FlightSQLODBCMockTestBase::SetUp() { + ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("0.0.0.0", 0)); + arrow::flight::FlightServerOptions options(location); + options.auth_handler = std::make_unique(); + options.middleware.push_back( + {"bearer-auth-server", std::make_shared()}); + ASSERT_OK_AND_ASSIGN(server, + arrow::flight::sql::example::SQLiteFlightSqlServer::Create()); + ASSERT_OK(server->Init(options)); + + port = server->port(); + ASSERT_OK_AND_ASSIGN(location, Location::ForGrpcTcp("localhost", port)); + ASSERT_OK_AND_ASSIGN(auto client, arrow::flight::FlightClient::Connect(location)); +} + +void FlightSQLODBCMockTestBase::TearDown() { ASSERT_OK(server->Shutdown()); } + +bool compareConnPropertyMap(Connection::ConnPropertyMap map1, + Connection::ConnPropertyMap map2) { + if (map1.size() != map2.size()) return false; + + for (const auto& [key, value] : map1) { + if (value != map2[key]) return false; + } + + return true; +} + +void VerifyOdbcErrorState(SQLSMALLINT handle_type, SQLHANDLE handle, + std::string_view expected_state) { + using ODBC::SqlWcharToString; + + SQLWCHAR sql_state[7] = {}; + SQLINTEGER native_code; + + SQLWCHAR message[ODBC_BUFFER_SIZE] = {}; + SQLSMALLINT reallen = 0; + + // On Windows, reallen is in bytes. On Linux, reallen is in chars. + // So, not using reallen + SQLGetDiagRec(handle_type, handle, 1, sql_state, &native_code, message, + ODBC_BUFFER_SIZE, &reallen); + + std::string res = SqlWcharToString(sql_state); + + EXPECT_EQ(res, expected_state); +} + +std::string GetOdbcErrorMessage(SQLSMALLINT handle_type, SQLHANDLE handle) { + using ODBC::SqlWcharToString; + + SQLWCHAR sql_state[7] = {}; + SQLINTEGER native_code; + + SQLWCHAR message[ODBC_BUFFER_SIZE] = {}; + SQLSMALLINT reallen = 0; + + // On Windows, reallen is in bytes. On Linux, reallen is in chars. + // So, not using reallen + SQLGetDiagRec(handle_type, handle, 1, sql_state, &native_code, message, + ODBC_BUFFER_SIZE, &reallen); + + std::string res = SqlWcharToString(sql_state); + + if (res.empty() || !message[0]) { + res = "Cannot find ODBC error message"; + } else { + res.append(": ").append(SqlWcharToString(message)); + } + + return res; +} + +bool writeDSN(std::string connection_str) { + Connection::ConnPropertyMap properties; + + ODBC::ODBCConnection::getPropertiesFromConnString(connection_str, properties); + return writeDSN(properties); +} + +bool writeDSN(Connection::ConnPropertyMap properties) { + using driver::flight_sql::FlightSqlConnection; + using driver::flight_sql::config::Configuration; + using driver::odbcabstraction::Connection; + using ODBC::ODBCConnection; + + Configuration config; + config.Set(FlightSqlConnection::DSN, std::string(TEST_DSN)); + + for (const auto& [key, value] : properties) { + config.Set(key, value); + } + + std::string driver = config.Get(FlightSqlConnection::DRIVER); + std::wstring wDriver = arrow::util::UTF8ToWideString(driver).ValueOr(L""); + return RegisterDsn(config, wDriver.c_str()); +} + +std::wstring ConvertToWString(const std::vector& strVal, SQLSMALLINT strLen) { + std::wstring attrStr; + if (strLen == 0) { + attrStr = std::wstring(&strVal[0]); + } else { + EXPECT_GT(strLen, 0); + EXPECT_LE(strLen, static_cast(ODBC_BUFFER_SIZE)); + attrStr = + std::wstring(strVal.begin(), strVal.begin() + strLen / ODBC::GetSqlWCharSize()); + } + return attrStr; +} + +void CheckStringColumnW(SQLHSTMT stmt, int colId, const std::wstring& expected) { + SQLWCHAR buf[1024]; + SQLLEN bufLen = sizeof(buf) * ODBC::GetSqlWCharSize(); + + SQLRETURN ret = SQLGetData(stmt, colId, SQL_C_WCHAR, buf, bufLen, &bufLen); + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(bufLen, 0); + + // returned bufLen is in bytes so convert to length in characters + size_t charCount = static_cast(bufLen) / ODBC::GetSqlWCharSize(); + std::wstring returned(buf, buf + charCount); + + EXPECT_EQ(returned, expected); +} + +void CheckNullColumnW(SQLHSTMT stmt, int colId) { + SQLWCHAR buf[1024]; + SQLLEN bufLen = sizeof(buf); + + SQLRETURN ret = SQLGetData(stmt, colId, SQL_C_WCHAR, buf, bufLen, &bufLen); + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(bufLen, SQL_NULL_DATA); +} + +void CheckIntColumn(SQLHSTMT stmt, int colId, const SQLINTEGER& expected) { + SQLINTEGER buf; + SQLLEN bufLen = sizeof(buf); + + SQLRETURN ret = SQLGetData(stmt, colId, SQL_C_LONG, &buf, sizeof(buf), &bufLen); + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(buf, expected); +} + +void CheckSmallIntColumn(SQLHSTMT stmt, int colId, const SQLSMALLINT& expected) { + SQLSMALLINT buf; + SQLLEN bufLen = sizeof(buf); + + SQLRETURN ret = SQLGetData(stmt, colId, SQL_C_SSHORT, &buf, sizeof(buf), &bufLen); + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(buf, expected); +} + +void ValidateFetch(SQLHSTMT stmt, SQLRETURN expectedReturn) { + SQLRETURN ret = SQLFetch(stmt); + + EXPECT_EQ(ret, expectedReturn); +} + +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h new file mode 100644 index 00000000000..dcd342a62c6 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h @@ -0,0 +1,237 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/testing/gtest_util.h" +#include "arrow/util/io_util.h" +#include "arrow/util/utf8.h" + +#include "arrow/flight/server_middleware.h" +#include "arrow/flight/sql/client.h" +#include "arrow/flight/sql/example/sqlite_server.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/encoding_utils.h" + +#ifdef _WIN32 +# include +#endif + +#include +#include +#include + +#include + +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_connection.h" + +// For DSN registration +#include "arrow/flight/sql/odbc/flight_sql/system_dsn.h" + +#define TEST_CONNECT_STR "ARROW_FLIGHT_SQL_ODBC_CONN" +#define TEST_DSN "Apache Arrow Flight SQL Test DSN" + +namespace arrow::flight::sql::odbc { +using driver::odbcabstraction::Connection; + +class FlightSQLODBCRemoteTestBase : public ::testing::Test { + public: + /// \brief Allocate environment and connection handles + void allocEnvConnHandles(SQLINTEGER odbc_ver = SQL_OV_ODBC3); + /// \brief Connect to Arrow Flight SQL server using connection string defined in + /// environment variable "ARROW_FLIGHT_SQL_ODBC_CONN", allocate statement handle. + /// Connects using ODBC Ver 3 by default + void connect(SQLINTEGER odbc_ver = SQL_OV_ODBC3); + /// \brief Connect to Arrow Flight SQL server using connection string + void connectWithString(std::string connection_str); + /// \brief Disconnect from server + void disconnect(); + /// \brief Get connection string from environment variable "ARROW_FLIGHT_SQL_ODBC_CONN" + std::string virtual getConnectionString(); + /// \brief Get invalid connection string based on connection string defined in + /// environment variable "ARROW_FLIGHT_SQL_ODBC_CONN" + std::string virtual getInvalidConnectionString(); + /// \brief Return a SQL query that selects all data types + std::wstring virtual getQueryAllDataTypes(); + + /** ODBC Environment. */ + SQLHENV env = 0; + + /** ODBC Connect. */ + SQLHDBC conn = 0; + + /** ODBC Statement. */ + SQLHSTMT stmt = 0; + + protected: + void SetUp() override; +}; + +static constexpr std::string_view kAuthHeader = "authorization"; +static constexpr std::string_view kBearerPrefix = "Bearer "; +static constexpr std::string_view test_token = "t0k3n"; + +std::string FindTokenInCallHeaders(const CallHeaders& incoming_headers); + +// A server middleware for validating incoming bearer header authentication. +class MockServerMiddleware : public ServerMiddleware { + public: + explicit MockServerMiddleware(const CallHeaders& incoming_headers, bool* isValid) + : isValid_(isValid) { + incoming_headers_ = incoming_headers; + } + + void SendingHeaders(AddCallHeaders* outgoing_headers) override; + + void CallCompleted(const Status& status) override {} + + std::string name() const override { return "MockServerMiddleware"; } + + private: + CallHeaders incoming_headers_; + bool* isValid_; +}; + +// Factory for base64 header authentication testing. +class MockServerMiddlewareFactory : public ServerMiddlewareFactory { + public: + MockServerMiddlewareFactory() : isValid_(false) {} + + Status StartCall(const CallInfo& info, const ServerCallContext& context, + std::shared_ptr* middleware) override; + + private: + bool isValid_; +}; + +class FlightSQLODBCMockTestBase : public FlightSQLODBCRemoteTestBase { + // Sets up a mock server for each test case + public: + /// \brief Get connection string for mock server + std::string getConnectionString() override; + /// \brief Get invalid connection string for mock server + std::string getInvalidConnectionString() override; + /// \brief Return a SQL query that selects all data types + std::wstring getQueryAllDataTypes() override; + + /// \brief Run a SQL query to create default table for table test cases + void CreateTestTables(); + + /// \brief run a SQL query to create a table with all data types + void CreateTableAllDataType(); + /// \brief run a SQL query to create a table with unicode name + void CreateUnicodeTable(); + + int port; + + protected: + void SetUp() override; + + void TearDown() override; + + private: + std::shared_ptr server; +}; + +template +class FlightSQLODBCTestBase : public T { + public: + using List = std::list; +}; + +using TestTypes = + ::testing::Types; +TYPED_TEST_SUITE(FlightSQLODBCTestBase, TestTypes); + +/** ODBC read buffer size. */ +enum { ODBC_BUFFER_SIZE = 1024 }; + +/// Compare ConnPropertyMap, key value is case-insensitive +bool compareConnPropertyMap(Connection::ConnPropertyMap map1, + Connection::ConnPropertyMap map2); + +/// Get error message from ODBC driver using SQLGetDiagRec +std::string GetOdbcErrorMessage(SQLSMALLINT handle_type, SQLHANDLE handle); + +static constexpr std::string_view error_state_01004 = "01004"; +static constexpr std::string_view error_state_01S07 = "01S07"; +static constexpr std::string_view error_state_01S02 = "01S02"; +static constexpr std::string_view error_state_07009 = "07009"; +static constexpr std::string_view error_state_08003 = "08003"; +static constexpr std::string_view error_state_22002 = "22002"; +static constexpr std::string_view error_state_24000 = "24000"; +static constexpr std::string_view error_state_28000 = "28000"; +static constexpr std::string_view error_state_HY000 = "HY000"; +static constexpr std::string_view error_state_HY004 = "HY004"; +static constexpr std::string_view error_state_HY009 = "HY009"; +static constexpr std::string_view error_state_HY010 = "HY010"; +static constexpr std::string_view error_state_HY017 = "HY017"; +static constexpr std::string_view error_state_HY024 = "HY024"; +static constexpr std::string_view error_state_HY090 = "HY090"; +static constexpr std::string_view error_state_HY091 = "HY091"; +static constexpr std::string_view error_state_HY092 = "HY092"; +static constexpr std::string_view error_state_HY106 = "HY106"; +static constexpr std::string_view error_state_HY114 = "HY114"; +static constexpr std::string_view error_state_HY118 = "HY118"; +static constexpr std::string_view error_state_HYC00 = "HYC00"; +static constexpr std::string_view error_state_S1004 = "S1004"; + +/// Verify ODBC Error State +void VerifyOdbcErrorState(SQLSMALLINT handle_type, SQLHANDLE handle, + std::string_view expected_state); + +/// \brief Write connection string into DSN +/// \param[in] connection_str the connection string. +/// \return true on success +bool writeDSN(std::string connection_str); + +/// \brief Write properties map into DSN +/// \param[in] properties map. +/// \return true on success +bool writeDSN(Connection::ConnPropertyMap properties); + +/// \brief Check wide char vector and convert into wstring +/// \param[in] strVal Vector of SQLWCHAR. +/// \param[in] strLen length of string, in bytes. +/// \return wstring +std::wstring ConvertToWString(const std::vector& strVal, SQLSMALLINT strLen); + +/// \brief Check wide string column. +/// \param[in] stmt Statement. +/// \param[in] colId Column ID to check. +/// \param[in] expected Expected value. +void CheckStringColumnW(SQLHSTMT stmt, int colId, const std::wstring& expected); + +/// \brief Check wide string column value is null. +/// \param[in] stmt Statement. +/// \param[in] colId Column ID to check. +void CheckNullColumnW(SQLHSTMT stmt, int colId); + +/// \brief Check int column. +/// \param[in] stmt Statement. +/// \param[in] colId Column ID to check. +/// \param[in] expected Expected value. +void CheckIntColumn(SQLHSTMT stmt, int colId, const SQLINTEGER& expected); + +/// \brief Check smallint column. +/// \param[in] stmt Statement. +/// \param[in] colId Column ID to check. +/// \param[in] expected Expected value. +void CheckSmallIntColumn(SQLHSTMT stmt, int colId, const SQLSMALLINT& expected); + +/// \brief Check sql return against expected. +/// \param[in] stmt Statement. +/// \param[in] expected Expected return. +void ValidateFetch(SQLHSTMT stmt, SQLRETURN expected); +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/statement_attr_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/statement_attr_test.cc new file mode 100644 index 00000000000..aff661dec00 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/statement_attr_test.cc @@ -0,0 +1,893 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_statement.h" +#include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/spi/statement.h" + +#ifdef _WIN32 +# include +#endif + +#include +#include +#include + +#include "gtest/gtest.h" + +namespace arrow::flight::sql::odbc { + +// Helper Functions + +// Validate SQLULEN return value +void validateGetStmtAttr(SQLHSTMT statement, SQLINTEGER attribute, + SQLULEN expected_value) { + SQLULEN value = 0; + SQLINTEGER stringLength = 0; + + SQLRETURN ret = + SQLGetStmtAttr(statement, attribute, &value, sizeof(value), &stringLength); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(value, expected_value); +} + +// Validate SQLLEN return value +void validateGetStmtAttr(SQLHSTMT statement, SQLINTEGER attribute, + SQLLEN expected_value) { + SQLLEN value = 0; + SQLINTEGER stringLength = 0; + + SQLRETURN ret = + SQLGetStmtAttr(statement, attribute, &value, sizeof(value), &stringLength); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(value, expected_value); +} + +// Validate SQLPOINTER return value +void validateGetStmtAttr(SQLHSTMT statement, SQLINTEGER attribute, + SQLPOINTER expected_value) { + SQLPOINTER value = nullptr; + SQLINTEGER stringLength = 0; + + SQLRETURN ret = + SQLGetStmtAttr(statement, attribute, &value, sizeof(value), &stringLength); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(value, expected_value); +} + +// Validate unsigned length SQLULEN return value is greater than +void validateGetStmtAttrGreaterThan(SQLHSTMT statement, SQLINTEGER attribute, + SQLULEN compared_value) { + SQLULEN value = 0; + SQLINTEGER stringLengthPtr; + + SQLRETURN ret = SQLGetStmtAttr(statement, attribute, &value, 0, &stringLengthPtr); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(value, compared_value); +} + +// Validate error return value and code +void validateGetStmtAttrErrorCode(SQLHSTMT statement, SQLINTEGER attribute, + std::string_view error_code) { + SQLULEN value = 0; + SQLINTEGER stringLengthPtr; + + SQLRETURN ret = SQLGetStmtAttr(statement, attribute, &value, 0, &stringLengthPtr); + + EXPECT_EQ(ret, SQL_ERROR); + + VerifyOdbcErrorState(SQL_HANDLE_STMT, statement, error_code); +} + +// Validate return value for call to SQLSetStmtAttr with SQLULEN +void validateSetStmtAttr(SQLHSTMT statement, SQLINTEGER attribute, SQLULEN new_value) { + SQLINTEGER stringLengthPtr = sizeof(SQLULEN); + + SQLRETURN ret = SQLSetStmtAttr( + statement, attribute, reinterpret_cast(new_value), stringLengthPtr); + + EXPECT_EQ(ret, SQL_SUCCESS); +} + +// Validate return value for call to SQLSetStmtAttr with SQLLEN +void validateSetStmtAttr(SQLHSTMT statement, SQLINTEGER attribute, SQLLEN new_value) { + SQLINTEGER stringLengthPtr = sizeof(SQLLEN); + + SQLRETURN ret = SQLSetStmtAttr( + statement, attribute, reinterpret_cast(new_value), stringLengthPtr); + + EXPECT_EQ(ret, SQL_SUCCESS); +} + +// Validate return value for call to SQLSetStmtAttr with SQLPOINTER +void validateSetStmtAttr(SQLHSTMT statement, SQLINTEGER attribute, SQLPOINTER value) { + SQLRETURN ret = SQLSetStmtAttr(statement, attribute, value, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); +} + +// Validate error return value and code +void validateSetStmtAttrErrorCode(SQLHSTMT statement, SQLINTEGER attribute, + SQLULEN new_value, std::string_view error_code) { + SQLINTEGER stringLengthPtr = sizeof(SQLULEN); + + SQLRETURN ret = SQLSetStmtAttr( + statement, attribute, reinterpret_cast(new_value), stringLengthPtr); + + EXPECT_EQ(ret, SQL_ERROR); + + VerifyOdbcErrorState(SQL_HANDLE_STMT, statement, error_code); +} + +// Test Cases + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrAppParamDesc) { + this->connect(); + + validateGetStmtAttrGreaterThan(this->stmt, SQL_ATTR_APP_PARAM_DESC, + static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrAppRowDesc) { + this->connect(); + + validateGetStmtAttrGreaterThan(this->stmt, SQL_ATTR_APP_ROW_DESC, + static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrAsyncEnable) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_ASYNC_ENABLE, + static_cast(SQL_ASYNC_ENABLE_OFF)); + + this->disconnect(); +} + +#ifdef SQL_ATTR_ASYNC_STMT_EVENT +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrAsyncStmtEventUnsupported) { + this->connect(); + + // Optional feature not implemented + validateGetStmtAttrErrorCode(this->stmt, SQL_ATTR_ASYNC_STMT_EVENT, error_state_HYC00); + + this->disconnect(); +} +#endif + +#ifdef SQL_ATTR_ASYNC_STMT_PCALLBACK +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrAsyncStmtPCCallbackUnsupported) { + this->connect(); + + // Optional feature not implemented + validateGetStmtAttrErrorCode(this->stmt, SQL_ATTR_ASYNC_STMT_PCALLBACK, + error_state_HYC00); + + this->disconnect(); +} +#endif + +#ifdef SQL_ATTR_ASYNC_STMT_PCONTEXT +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrAsyncStmtPCContextUnsupported) { + this->connect(); + + // Optional feature not implemented + validateGetStmtAttrErrorCode(this->stmt, SQL_ATTR_ASYNC_STMT_PCONTEXT, + error_state_HYC00); + + this->disconnect(); +} +#endif + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrConcurrency) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_CONCURRENCY, + static_cast(SQL_CONCUR_READ_ONLY)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrCursorScrollable) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_CURSOR_SCROLLABLE, + static_cast(SQL_NONSCROLLABLE)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrCursorSensitivity) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_CURSOR_SENSITIVITY, + static_cast(SQL_UNSPECIFIED)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrCursorType) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_CURSOR_TYPE, + static_cast(SQL_CURSOR_FORWARD_ONLY)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrEnableAutoIPD) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_ENABLE_AUTO_IPD, + static_cast(SQL_FALSE)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrFetchBookmarkPointer) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_FETCH_BOOKMARK_PTR, static_cast(NULL)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrIMPParamDesc) { + this->connect(); + + validateGetStmtAttrGreaterThan(this->stmt, SQL_ATTR_IMP_PARAM_DESC, + static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrIMPRowDesc) { + this->connect(); + + validateGetStmtAttrGreaterThan(this->stmt, SQL_ATTR_IMP_ROW_DESC, + static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrKeysetSize) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_KEYSET_SIZE, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrMaxLength) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_MAX_LENGTH, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrMaxRows) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_MAX_ROWS, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrMetadataID) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_METADATA_ID, static_cast(SQL_FALSE)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrNoscan) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_NOSCAN, static_cast(SQL_NOSCAN_OFF)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrParamBindOffsetPtr) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_PARAM_BIND_OFFSET_PTR, + static_cast(nullptr)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrParamBindType) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_PARAM_BIND_TYPE, + static_cast(SQL_PARAM_BIND_BY_COLUMN)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrParamOperationPtr) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_PARAM_OPERATION_PTR, + static_cast(nullptr)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrParamStatusPtr) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_PARAM_STATUS_PTR, + static_cast(nullptr)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrParamsProcessedPtr) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_PARAMS_PROCESSED_PTR, + static_cast(nullptr)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrParamsetSize) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_PARAMSET_SIZE, static_cast(1)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrQueryTimeout) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_QUERY_TIMEOUT, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrRetrieveData) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_RETRIEVE_DATA, + static_cast(SQL_RD_ON)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrRowArraySize) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_ROW_ARRAY_SIZE, static_cast(1)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrRowBindOffsetPtr) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_ROW_BIND_OFFSET_PTR, + static_cast(nullptr)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrRowBindType) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_ROW_BIND_TYPE, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrRowNumber) { + this->connect(); + + std::wstring wsql = L"SELECT 1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_SUCCESS); + + validateGetStmtAttr(this->stmt, SQL_ATTR_ROW_NUMBER, static_cast(1)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrRowOperationPtr) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_ROW_OPERATION_PTR, + static_cast(nullptr)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrRowStatusPtr) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_ROW_STATUS_PTR, + static_cast(nullptr)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrRowsFetchedPtr) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_ROWS_FETCHED_PTR, + static_cast(nullptr)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrSimulateCursor) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_SIMULATE_CURSOR, + static_cast(SQL_SC_UNIQUE)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrUseBookmarks) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ATTR_USE_BOOKMARKS, + static_cast(SQL_UB_OFF)); + + this->disconnect(); +} + +// This is a pre ODBC 3 attribute +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetStmtAttrRowsetSize) { + this->connect(); + + validateGetStmtAttr(this->stmt, SQL_ROWSET_SIZE, static_cast(1)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrAppParamDesc) { + SQLULEN app_param_desc = 0; + SQLINTEGER stringLengthPtr; + this->connect(); + + SQLRETURN ret = SQLGetStmtAttr(this->stmt, SQL_ATTR_APP_PARAM_DESC, &app_param_desc, 0, + &stringLengthPtr); + + EXPECT_EQ(ret, SQL_SUCCESS); + + validateSetStmtAttr(this->stmt, SQL_ATTR_APP_PARAM_DESC, static_cast(0)); + + validateSetStmtAttr(this->stmt, SQL_ATTR_APP_PARAM_DESC, + static_cast(app_param_desc)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrAppRowDesc) { + SQLULEN app_row_desc = 0; + SQLINTEGER stringLengthPtr; + this->connect(); + + SQLRETURN ret = SQLGetStmtAttr(this->stmt, SQL_ATTR_APP_ROW_DESC, &app_row_desc, 0, + &stringLengthPtr); + + EXPECT_EQ(ret, SQL_SUCCESS); + + validateSetStmtAttr(this->stmt, SQL_ATTR_APP_ROW_DESC, static_cast(0)); + + validateSetStmtAttr(this->stmt, SQL_ATTR_APP_ROW_DESC, + static_cast(app_row_desc)); + + this->disconnect(); +} + +#ifdef SQL_ATTR_ASYNC_ENABLE +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrAsyncEnableUnsupported) { + this->connect(); + + // Optional feature not implemented + validateSetStmtAttrErrorCode(this->stmt, SQL_ATTR_ASYNC_ENABLE, SQL_ASYNC_ENABLE_OFF, + error_state_HYC00); + + this->disconnect(); +} +#endif + +#ifdef SQL_ATTR_ASYNC_STMT_EVENT +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrAsyncStmtEventUnsupported) { + this->connect(); + + // Driver does not support asynchronous notification + validateSetStmtAttrErrorCode(this->stmt, SQL_ATTR_ASYNC_STMT_EVENT, 0, + error_state_HY118); + + this->disconnect(); +} +#endif + +#ifdef SQL_ATTR_ASYNC_STMT_PCALLBACK +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrAsyncStmtPCCallbackUnsupported) { + this->connect(); + + validateSetStmtAttrErrorCode(this->stmt, SQL_ATTR_ASYNC_STMT_PCALLBACK, 0, + error_state_HYC00); + + this->disconnect(); +} +#endif + +#ifdef SQL_ATTR_ASYNC_STMT_PCONTEXT +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrAsyncStmtPCContextUnsupported) { + this->connect(); + + // Optional feature not implemented + validateSetStmtAttrErrorCode(this->stmt, SQL_ATTR_ASYNC_STMT_PCONTEXT, 0, + error_state_HYC00); + + this->disconnect(); +} +#endif + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrConcurrency) { + this->connect(); + + validateSetStmtAttr(this->stmt, SQL_ATTR_CONCURRENCY, + static_cast(SQL_CONCUR_READ_ONLY)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrCursorScrollable) { + this->connect(); + + validateSetStmtAttr(this->stmt, SQL_ATTR_CURSOR_SCROLLABLE, + static_cast(SQL_NONSCROLLABLE)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrCursorSensitivity) { + this->connect(); + + validateSetStmtAttr(this->stmt, SQL_ATTR_CURSOR_SENSITIVITY, + static_cast(SQL_UNSPECIFIED)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrCursorType) { + this->connect(); + + validateSetStmtAttr(this->stmt, SQL_ATTR_CURSOR_TYPE, + static_cast(SQL_CURSOR_FORWARD_ONLY)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrEnableAutoIPD) { + this->connect(); + + validateSetStmtAttr(this->stmt, SQL_ATTR_ENABLE_AUTO_IPD, + static_cast(SQL_FALSE)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrFetchBookmarkPointer) { + this->connect(); + + validateSetStmtAttr(this->stmt, SQL_ATTR_FETCH_BOOKMARK_PTR, static_cast(NULL)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrIMPParamDesc) { + this->connect(); + + // Invalid use of an automatically allocated descriptor handle + validateSetStmtAttrErrorCode(this->stmt, SQL_ATTR_IMP_PARAM_DESC, + static_cast(0), error_state_HY017); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrIMPRowDesc) { + this->connect(); + + // Invalid use of an automatically allocated descriptor handle + validateSetStmtAttrErrorCode(this->stmt, SQL_ATTR_IMP_ROW_DESC, static_cast(0), + error_state_HY017); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrKeysetSizeUnsupported) { + this->connect(); + + validateSetStmtAttr(this->stmt, SQL_ATTR_KEYSET_SIZE, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrMaxLength) { + this->connect(); + + validateSetStmtAttr(this->stmt, SQL_ATTR_MAX_LENGTH, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrMaxRows) { + this->connect(); + + // Cannot set read-only attribute + validateSetStmtAttrErrorCode(this->stmt, SQL_ATTR_MAX_ROWS, static_cast(0), + error_state_HY092); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrMetadataID) { + this->connect(); + + validateSetStmtAttr(this->stmt, SQL_ATTR_METADATA_ID, static_cast(SQL_FALSE)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrNoscan) { + this->connect(); + + validateSetStmtAttr(this->stmt, SQL_ATTR_NOSCAN, static_cast(SQL_NOSCAN_OFF)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrParamBindOffsetPtr) { + this->connect(); + + SQLULEN offset = 1000; + + validateSetStmtAttr(this->stmt, SQL_ATTR_PARAM_BIND_OFFSET_PTR, + static_cast(&offset)); + + validateGetStmtAttr(this->stmt, SQL_ATTR_PARAM_BIND_OFFSET_PTR, + static_cast(&offset)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrParamBindType) { + this->connect(); + + validateSetStmtAttr(this->stmt, SQL_ATTR_PARAM_BIND_TYPE, + static_cast(SQL_PARAM_BIND_BY_COLUMN)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrParamOperationPtr) { + this->connect(); + + constexpr SQLULEN param_set_size = 4; + SQLUSMALLINT param_operations[param_set_size] = {SQL_PARAM_PROCEED, SQL_PARAM_IGNORE, + SQL_PARAM_PROCEED, SQL_PARAM_IGNORE}; + + validateSetStmtAttr(this->stmt, SQL_ATTR_PARAM_OPERATION_PTR, + static_cast(param_operations)); + + validateGetStmtAttr(this->stmt, SQL_ATTR_PARAM_OPERATION_PTR, + static_cast(param_operations)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrParamStatusPtr) { + this->connect(); + + // Driver does not support parameters, so just check array can be saved/retrieved + constexpr SQLULEN param_status_size = 4; + SQLUSMALLINT param_status[param_status_size] = {SQL_PARAM_PROCEED, SQL_PARAM_IGNORE, + SQL_PARAM_PROCEED, SQL_PARAM_IGNORE}; + + validateSetStmtAttr(this->stmt, SQL_ATTR_PARAM_STATUS_PTR, + static_cast(param_status)); + + validateGetStmtAttr(this->stmt, SQL_ATTR_PARAM_STATUS_PTR, + static_cast(param_status)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrParamsProcessedPtr) { + this->connect(); + + SQLULEN processed_count = 0; + + validateSetStmtAttr(this->stmt, SQL_ATTR_PARAMS_PROCESSED_PTR, + static_cast(&processed_count)); + + validateGetStmtAttr(this->stmt, SQL_ATTR_PARAMS_PROCESSED_PTR, + static_cast(&processed_count)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrParamsetSize) { + this->connect(); + + validateSetStmtAttr(this->stmt, SQL_ATTR_PARAMSET_SIZE, static_cast(1)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrQueryTimeout) { + this->connect(); + + validateSetStmtAttr(this->stmt, SQL_ATTR_QUERY_TIMEOUT, static_cast(1)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrRetrieveData) { + this->connect(); + + validateSetStmtAttr(this->stmt, SQL_ATTR_RETRIEVE_DATA, + static_cast(SQL_RD_ON)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrRowArraySize) { + this->connect(); + + validateSetStmtAttr(this->stmt, SQL_ATTR_ROW_ARRAY_SIZE, static_cast(1)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrRowBindOffsetPtr) { + this->connect(); + + SQLULEN offset = 1000; + + validateSetStmtAttr(this->stmt, SQL_ATTR_ROW_BIND_OFFSET_PTR, + static_cast(&offset)); + + validateGetStmtAttr(this->stmt, SQL_ATTR_ROW_BIND_OFFSET_PTR, + static_cast(&offset)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrRowBindType) { + this->connect(); + + validateSetStmtAttr(this->stmt, SQL_ATTR_ROW_BIND_TYPE, static_cast(0)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrRowNumber) { + this->connect(); + + // Cannot set read-only attribute + validateSetStmtAttrErrorCode(this->stmt, SQL_ATTR_ROW_NUMBER, static_cast(0), + error_state_HY092); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrRowOperationPtr) { + this->connect(); + + constexpr SQLULEN param_set_size = 4; + SQLUSMALLINT row_operations[param_set_size] = {SQL_ROW_PROCEED, SQL_ROW_IGNORE, + SQL_ROW_PROCEED, SQL_ROW_IGNORE}; + + validateSetStmtAttr(this->stmt, SQL_ATTR_ROW_OPERATION_PTR, + static_cast(row_operations)); + + validateGetStmtAttr(this->stmt, SQL_ATTR_ROW_OPERATION_PTR, + static_cast(row_operations)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrRowStatusPtr) { + this->connect(); + + constexpr SQLULEN row_status_size = 4; + SQLUSMALLINT values[4] = {0, 0, 0, 0}; + + validateSetStmtAttr(this->stmt, SQL_ATTR_ROW_STATUS_PTR, + static_cast(values)); + + validateGetStmtAttr(this->stmt, SQL_ATTR_ROW_STATUS_PTR, + static_cast(values)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrRowsFetchedPtr) { + this->connect(); + + SQLULEN rows_fetched = 1; + + validateSetStmtAttr(this->stmt, SQL_ATTR_ROWS_FETCHED_PTR, + static_cast(&rows_fetched)); + + validateGetStmtAttr(this->stmt, SQL_ATTR_ROWS_FETCHED_PTR, + static_cast(&rows_fetched)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrSimulateCursor) { + this->connect(); + + validateSetStmtAttr(this->stmt, SQL_ATTR_SIMULATE_CURSOR, + static_cast(SQL_SC_UNIQUE)); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrUseBookmarks) { + this->connect(); + + validateSetStmtAttr(this->stmt, SQL_ATTR_USE_BOOKMARKS, + static_cast(SQL_UB_OFF)); + + this->disconnect(); +} + +// This is a pre ODBC 3 attribute +TYPED_TEST(FlightSQLODBCTestBase, TestSQLSetStmtAttrRowsetSize) { + this->connect(); + + validateSetStmtAttr(this->stmt, SQL_ROWSET_SIZE, static_cast(1)); + + this->disconnect(); +} + +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc new file mode 100644 index 00000000000..0d255101db3 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc @@ -0,0 +1,2603 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#ifdef _WIN32 +# include +#endif + +#include +#include +#include + +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace arrow::flight::sql::odbc { +TYPED_TEST(FlightSQLODBCTestBase, TestSQLExecDirectSimpleQuery) { + this->connect(); + + std::wstring wsql = L"SELECT 1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + SQLINTEGER val; + + ret = SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Verify 1 is returned + EXPECT_EQ(val, 1); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_NO_DATA); + + ret = SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0); + + EXPECT_EQ(ret, SQL_ERROR); + // Invalid cursor state + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_24000); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLExecDirectInvalidQuery) { + this->connect(); + + std::wstring wsql = L"SELECT;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + + EXPECT_EQ(ret, SQL_ERROR); + // ODBC provides generic error code HY000 to all statement errors + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_HY000); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLExecuteSimpleQuery) { + this->connect(); + + std::wstring wsql = L"SELECT 1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = SQLPrepare(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLExecute(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Fetch data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + SQLINTEGER val; + + ret = SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Verify 1 is returned + EXPECT_EQ(val, 1); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_NO_DATA); + + ret = SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0); + + EXPECT_EQ(ret, SQL_ERROR); + // Invalid cursor state + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_24000); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLPrepareInvalidQuery) { + this->connect(); + + std::wstring wsql = L"SELECT;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = SQLPrepare(this->stmt, &sql0[0], static_cast(sql0.size())); + + EXPECT_EQ(ret, SQL_ERROR); + // ODBC provides generic error code HY000 to all statement errors + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_HY000); + + ret = SQLExecute(this->stmt); + // Verify function sequence error state is returned + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_HY010); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLExecDirectDataQuery) { + this->connect(); + + std::wstring wsql = this->getQueryAllDataTypes(); + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Numeric Types + + // Signed Tiny Int + int8_t stiny_int_val; + SQLLEN buf_len = sizeof(stiny_int_val); + SQLLEN ind; + + ret = SQLGetData(this->stmt, 1, SQL_C_STINYINT, &stiny_int_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(stiny_int_val, std::numeric_limits::min()); + + ret = SQLGetData(this->stmt, 2, SQL_C_STINYINT, &stiny_int_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(stiny_int_val, std::numeric_limits::max()); + + // Unsigned Tiny Int + uint8_t utiny_int_val; + buf_len = sizeof(utiny_int_val); + + ret = SQLGetData(this->stmt, 3, SQL_C_UTINYINT, &utiny_int_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(utiny_int_val, std::numeric_limits::min()); + + ret = SQLGetData(this->stmt, 4, SQL_C_UTINYINT, &utiny_int_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(utiny_int_val, std::numeric_limits::max()); + + // Signed Small Int + int16_t ssmall_int_val; + buf_len = sizeof(ssmall_int_val); + + ret = SQLGetData(this->stmt, 5, SQL_C_SSHORT, &ssmall_int_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(ssmall_int_val, std::numeric_limits::min()); + + ret = SQLGetData(this->stmt, 6, SQL_C_SSHORT, &ssmall_int_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(ssmall_int_val, std::numeric_limits::max()); + + // Unsigned Small Int + uint16_t usmall_int_val; + buf_len = sizeof(usmall_int_val); + + ret = SQLGetData(this->stmt, 7, SQL_C_USHORT, &usmall_int_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(usmall_int_val, std::numeric_limits::min()); + + ret = SQLGetData(this->stmt, 8, SQL_C_USHORT, &usmall_int_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(usmall_int_val, std::numeric_limits::max()); + + // Signed Integer + SQLINTEGER slong_val; + buf_len = sizeof(slong_val); + + ret = SQLGetData(this->stmt, 9, SQL_C_SLONG, &slong_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(slong_val, std::numeric_limits::min()); + + ret = SQLGetData(this->stmt, 10, SQL_C_SLONG, &slong_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(slong_val, std::numeric_limits::max()); + + // Unsigned Integer + SQLUINTEGER ulong_val; + buf_len = sizeof(ulong_val); + + ret = SQLGetData(this->stmt, 11, SQL_C_ULONG, &ulong_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(ulong_val, std::numeric_limits::min()); + + ret = SQLGetData(this->stmt, 12, SQL_C_ULONG, &ulong_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(ulong_val, std::numeric_limits::max()); + + // Signed Big Int + SQLBIGINT sbig_int_val; + buf_len = sizeof(sbig_int_val); + + ret = SQLGetData(this->stmt, 13, SQL_C_SBIGINT, &sbig_int_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(sbig_int_val, std::numeric_limits::min()); + + ret = SQLGetData(this->stmt, 14, SQL_C_SBIGINT, &sbig_int_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(sbig_int_val, std::numeric_limits::max()); + + // Unsigned Big Int + SQLUBIGINT ubig_int_val; + buf_len = sizeof(ubig_int_val); + + ret = SQLGetData(this->stmt, 15, SQL_C_UBIGINT, &ubig_int_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(ubig_int_val, std::numeric_limits::min()); + + ret = SQLGetData(this->stmt, 16, SQL_C_UBIGINT, &ubig_int_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(ubig_int_val, std::numeric_limits::max()); + + // Decimal + SQL_NUMERIC_STRUCT decimal_val; + memset(&decimal_val, 0, sizeof(decimal_val)); + buf_len = sizeof(SQL_NUMERIC_STRUCT); + + ret = SQLGetData(this->stmt, 17, SQL_C_NUMERIC, &decimal_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check for negative decimal_val value + EXPECT_EQ(decimal_val.sign, 0); + EXPECT_EQ(decimal_val.scale, 0); + EXPECT_EQ(decimal_val.precision, 38); + EXPECT_THAT(decimal_val.val, ::testing::ElementsAre(0xFF, 0xC9, 0x9A, 0x3B, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0)); + + memset(&decimal_val, 0, sizeof(decimal_val)); + ret = SQLGetData(this->stmt, 18, SQL_C_NUMERIC, &decimal_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check for positive decimal_val value + EXPECT_EQ(decimal_val.sign, 1); + EXPECT_EQ(decimal_val.scale, 0); + EXPECT_EQ(decimal_val.precision, 38); + EXPECT_THAT(decimal_val.val, ::testing::ElementsAre(0xFF, 0xC9, 0x9A, 0x3B, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0)); + + // Float + float float_val; + buf_len = sizeof(float_val); + + ret = SQLGetData(this->stmt, 19, SQL_C_FLOAT, &float_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Get minimum negative float value + EXPECT_EQ(float_val, -std::numeric_limits::max()); + + ret = SQLGetData(this->stmt, 20, SQL_C_FLOAT, &float_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(float_val, std::numeric_limits::max()); + + // Double + SQLDOUBLE double_val; + buf_len = sizeof(double_val); + + ret = SQLGetData(this->stmt, 21, SQL_C_DOUBLE, &double_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Get minimum negative double value + EXPECT_EQ(double_val, -std::numeric_limits::max()); + + ret = SQLGetData(this->stmt, 22, SQL_C_DOUBLE, &double_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(double_val, std::numeric_limits::max()); + + // Bit + bool bit_val; + buf_len = sizeof(bit_val); + + ret = SQLGetData(this->stmt, 23, SQL_C_BIT, &bit_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(bit_val, false); + + ret = SQLGetData(this->stmt, 24, SQL_C_BIT, &bit_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(bit_val, true); + + // Characters + + // Char + SQLCHAR char_val[2]; + buf_len = sizeof(SQLCHAR) * 2; + + ret = SQLGetData(this->stmt, 25, SQL_C_CHAR, &char_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(char_val[0], 'Z'); + + // WChar + SQLWCHAR wchar_val[2]; + constexpr size_t wchar_size = driver::odbcabstraction::GetSqlWCharSize(); + buf_len = wchar_size * 2; + + ret = SQLGetData(this->stmt, 26, SQL_C_WCHAR, &wchar_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(wchar_val[0], L'你'); + + // WVarchar + SQLWCHAR wvarchar_val[3]; + buf_len = wchar_size * 3; + + ret = SQLGetData(this->stmt, 27, SQL_C_WCHAR, &wvarchar_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(wvarchar_val[0], L'你'); + EXPECT_EQ(wvarchar_val[1], L'好'); + + // varchar + SQLCHAR varchar_val[4]; + buf_len = sizeof(SQLCHAR) * 4; + + ret = SQLGetData(this->stmt, 28, SQL_C_CHAR, &varchar_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(varchar_val[0], 'X'); + EXPECT_EQ(varchar_val[1], 'Y'); + EXPECT_EQ(varchar_val[2], 'Z'); + + // Date and Timestamp + + // Date + SQL_DATE_STRUCT date_var{}; + buf_len = sizeof(date_var); + + ret = SQLGetData(this->stmt, 29, SQL_C_TYPE_DATE, &date_var, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check min values for date. Min valid year is 1400. + EXPECT_EQ(date_var.day, 1); + EXPECT_EQ(date_var.month, 1); + EXPECT_EQ(date_var.year, 1400); + + ret = SQLGetData(this->stmt, 30, SQL_C_TYPE_DATE, &date_var, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check max values for date. Max valid year is 9999. + EXPECT_EQ(date_var.day, 31); + EXPECT_EQ(date_var.month, 12); + EXPECT_EQ(date_var.year, 9999); + + // Timestamp + SQL_TIMESTAMP_STRUCT timestamp_var{}; + buf_len = sizeof(timestamp_var); + + ret = SQLGetData(this->stmt, 31, SQL_C_TYPE_TIMESTAMP, ×tamp_var, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check min values for date. Min valid year is 1400. + EXPECT_EQ(timestamp_var.day, 1); + EXPECT_EQ(timestamp_var.month, 1); + EXPECT_EQ(timestamp_var.year, 1400); + EXPECT_EQ(timestamp_var.hour, 0); + EXPECT_EQ(timestamp_var.minute, 0); + EXPECT_EQ(timestamp_var.second, 0); + EXPECT_EQ(timestamp_var.fraction, 0); + + ret = SQLGetData(this->stmt, 32, SQL_C_TYPE_TIMESTAMP, ×tamp_var, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check max values for date. Max valid year is 9999. + EXPECT_EQ(timestamp_var.day, 31); + EXPECT_EQ(timestamp_var.month, 12); + EXPECT_EQ(timestamp_var.year, 9999); + EXPECT_EQ(timestamp_var.hour, 23); + EXPECT_EQ(timestamp_var.minute, 59); + EXPECT_EQ(timestamp_var.second, 59); + EXPECT_EQ(timestamp_var.fraction, 0); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLExecDirectTimeQuery) { + // Mock server test is skipped due to limitation on the mock server. + // Time type from mock server does not include the fraction + this->connect(); + + std::wstring wsql = + LR"( + SELECT CAST(TIME '00:00:00' AS TIME) AS time_min, + CAST(TIME '23:59:59' AS TIME) AS time_max; + )"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + SQL_TIME_STRUCT time_var{}; + SQLLEN buf_len = sizeof(time_var); + SQLLEN ind; + + ret = SQLGetData(this->stmt, 1, SQL_C_TYPE_TIME, &time_var, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check min values for time. + EXPECT_EQ(time_var.hour, 0); + EXPECT_EQ(time_var.minute, 0); + EXPECT_EQ(time_var.second, 0); + + ret = SQLGetData(this->stmt, 2, SQL_C_TYPE_TIME, &time_var, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check max values for time. + EXPECT_EQ(time_var.hour, 23); + EXPECT_EQ(time_var.minute, 59); + EXPECT_EQ(time_var.second, 59); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLExecDirectVarbinaryQuery) { + // Have binary test on mock test base as remote test servers tend to have different + // formats for binary data + this->connect(); + + std::wstring wsql = L"SELECT X'ABCDEF' AS c_varbinary;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // varbinary + std::vector varbinary_val(3); + SQLLEN buf_len = varbinary_val.size(); + SQLLEN ind; + ret = SQLGetData(this->stmt, 1, SQL_C_BINARY, &varbinary_val[0], buf_len, &ind); + EXPECT_EQ(varbinary_val[0], '\xAB'); + EXPECT_EQ(varbinary_val[1], '\xCD'); + EXPECT_EQ(varbinary_val[2], '\xEF'); + + this->disconnect(); +} + +// Tests with SQL_C_DEFAULT as the target type + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLExecDirectDataQueryDefaultType) { + // Test with default types. Only testing target types supported by server. + this->connect(); + + std::wstring wsql = this->getQueryAllDataTypes(); + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Numeric Types + // Signed Integer + SQLINTEGER slong_val; + SQLLEN buf_len = sizeof(slong_val); + SQLLEN ind; + + ret = SQLGetData(this->stmt, 9, SQL_C_DEFAULT, &slong_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(slong_val, std::numeric_limits::min()); + + ret = SQLGetData(this->stmt, 10, SQL_C_DEFAULT, &slong_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(slong_val, std::numeric_limits::max()); + + // Signed Big Int + SQLBIGINT sbig_int_val; + buf_len = sizeof(sbig_int_val); + + ret = SQLGetData(this->stmt, 13, SQL_C_DEFAULT, &sbig_int_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(sbig_int_val, std::numeric_limits::min()); + + ret = SQLGetData(this->stmt, 14, SQL_C_DEFAULT, &sbig_int_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(sbig_int_val, std::numeric_limits::max()); + + // Decimal + SQL_NUMERIC_STRUCT decimal_val; + memset(&decimal_val, 0, sizeof(decimal_val)); + buf_len = sizeof(SQL_NUMERIC_STRUCT); + + ret = SQLGetData(this->stmt, 17, SQL_C_DEFAULT, &decimal_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check for negative decimal_val value + EXPECT_EQ(decimal_val.sign, 0); + EXPECT_EQ(decimal_val.scale, 0); + EXPECT_EQ(decimal_val.precision, 38); + EXPECT_THAT(decimal_val.val, ::testing::ElementsAre(0xFF, 0xC9, 0x9A, 0x3B, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0)); + + memset(&decimal_val, 0, sizeof(decimal_val)); + ret = SQLGetData(this->stmt, 18, SQL_C_DEFAULT, &decimal_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check for positive decimal_val value + EXPECT_EQ(decimal_val.sign, 1); + EXPECT_EQ(decimal_val.scale, 0); + EXPECT_EQ(decimal_val.precision, 38); + EXPECT_THAT(decimal_val.val, ::testing::ElementsAre(0xFF, 0xC9, 0x9A, 0x3B, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0)); + + // Float + float float_val; + buf_len = sizeof(float_val); + + ret = SQLGetData(this->stmt, 19, SQL_C_DEFAULT, &float_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Get minimum negative float value + EXPECT_EQ(float_val, -std::numeric_limits::max()); + + ret = SQLGetData(this->stmt, 20, SQL_C_DEFAULT, &float_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(float_val, std::numeric_limits::max()); + + // Double + SQLDOUBLE double_val; + buf_len = sizeof(double_val); + + ret = SQLGetData(this->stmt, 21, SQL_C_DEFAULT, &double_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Get minimum negative double value + EXPECT_EQ(double_val, -std::numeric_limits::max()); + + ret = SQLGetData(this->stmt, 22, SQL_C_DEFAULT, &double_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(double_val, std::numeric_limits::max()); + + // Bit + bool bit_val; + buf_len = sizeof(bit_val); + + ret = SQLGetData(this->stmt, 23, SQL_C_DEFAULT, &bit_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(bit_val, false); + + ret = SQLGetData(this->stmt, 24, SQL_C_DEFAULT, &bit_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(bit_val, true); + + // Characters + + // Char will be fetched as wchar by default + SQLWCHAR wchar_val[2]; + constexpr size_t wchar_size = driver::odbcabstraction::GetSqlWCharSize(); + buf_len = wchar_size * 2; + + ret = SQLGetData(this->stmt, 25, SQL_C_DEFAULT, &wchar_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(wchar_val[0], L'Z'); + + // WChar + SQLWCHAR wchar_val2[2]; + buf_len = wchar_size * 2; + ret = SQLGetData(this->stmt, 26, SQL_C_DEFAULT, &wchar_val2, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(wchar_val2[0], L'你'); + + // WVarchar + SQLWCHAR wvarchar_val[3]; + buf_len = wchar_size * 3; + + ret = SQLGetData(this->stmt, 27, SQL_C_DEFAULT, &wvarchar_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(wvarchar_val[0], L'你'); + EXPECT_EQ(wvarchar_val[1], L'好'); + + // Varchar will be fetched as WVarchar by default + SQLWCHAR wvarchar_val2[4]; + buf_len = wchar_size * 4; + + ret = SQLGetData(this->stmt, 28, SQL_C_DEFAULT, &wvarchar_val2, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(wvarchar_val2[0], L'X'); + EXPECT_EQ(wvarchar_val2[1], L'Y'); + EXPECT_EQ(wvarchar_val2[2], L'Z'); + + // Date and Timestamp + + // Date + SQL_DATE_STRUCT date_var{}; + buf_len = sizeof(date_var); + + ret = SQLGetData(this->stmt, 29, SQL_C_DEFAULT, &date_var, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check min values for date. Min valid year is 1400. + EXPECT_EQ(date_var.day, 1); + EXPECT_EQ(date_var.month, 1); + EXPECT_EQ(date_var.year, 1400); + + ret = SQLGetData(this->stmt, 30, SQL_C_DEFAULT, &date_var, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check max values for date. Max valid year is 9999. + EXPECT_EQ(date_var.day, 31); + EXPECT_EQ(date_var.month, 12); + EXPECT_EQ(date_var.year, 9999); + + // Timestamp + SQL_TIMESTAMP_STRUCT timestamp_var{}; + buf_len = sizeof(timestamp_var); + + ret = SQLGetData(this->stmt, 31, SQL_C_DEFAULT, ×tamp_var, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check min values for date. Min valid year is 1400. + EXPECT_EQ(timestamp_var.day, 1); + EXPECT_EQ(timestamp_var.month, 1); + EXPECT_EQ(timestamp_var.year, 1400); + EXPECT_EQ(timestamp_var.hour, 0); + EXPECT_EQ(timestamp_var.minute, 0); + EXPECT_EQ(timestamp_var.second, 0); + EXPECT_EQ(timestamp_var.fraction, 0); + + ret = SQLGetData(this->stmt, 32, SQL_C_DEFAULT, ×tamp_var, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check max values for date. Max valid year is 9999. + EXPECT_EQ(timestamp_var.day, 31); + EXPECT_EQ(timestamp_var.month, 12); + EXPECT_EQ(timestamp_var.year, 9999); + EXPECT_EQ(timestamp_var.hour, 23); + EXPECT_EQ(timestamp_var.minute, 59); + EXPECT_EQ(timestamp_var.second, 59); + EXPECT_EQ(timestamp_var.fraction, 0); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLExecDirectTimeQueryDefaultType) { + // Mock server test is skipped due to limitation on the mock server. + // Time type from mock server does not include the fraction + this->connect(); + + std::wstring wsql = + LR"( + SELECT CAST(TIME '00:00:00' AS TIME) AS time_min, + CAST(TIME '23:59:59' AS TIME) AS time_max; + )"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + SQL_TIME_STRUCT time_var{}; + SQLLEN buf_len = sizeof(time_var); + SQLLEN ind; + + ret = SQLGetData(this->stmt, 1, SQL_C_DEFAULT, &time_var, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check min values for time. + EXPECT_EQ(time_var.hour, 0); + EXPECT_EQ(time_var.minute, 0); + EXPECT_EQ(time_var.second, 0); + + ret = SQLGetData(this->stmt, 2, SQL_C_DEFAULT, &time_var, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check max values for time. + EXPECT_EQ(time_var.hour, 23); + EXPECT_EQ(time_var.minute, 59); + EXPECT_EQ(time_var.second, 59); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLExecDirectVarbinaryQueryDefaultType) { + // Limitation on mock test server prevents SQL_C_DEFAULT from working properly. + // Mock server has type `DENSE_UNION` for varbinary. + // Note that not all remote servers support "from_hex" function + this->connect(); + + std::wstring wsql = L"SELECT from_hex('ABCDEF') AS c_varbinary;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // varbinary + std::vector varbinary_val(3); + SQLLEN buf_len = varbinary_val.size(); + SQLLEN ind; + ret = SQLGetData(this->stmt, 1, SQL_C_DEFAULT, &varbinary_val[0], buf_len, &ind); + EXPECT_EQ(varbinary_val[0], '\xAB'); + EXPECT_EQ(varbinary_val[1], '\xCD'); + EXPECT_EQ(varbinary_val[2], '\xEF'); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLExecDirectGuidQueryUnsupported) { + this->connect(); + + // Query GUID as string as SQLite does not support GUID + std::wstring wsql = L"SELECT 'C77313CF-4E08-47CE-B6DF-94DD2FCF3541' AS guid;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + SQLGUID guid_var; + SQLLEN buf_len = sizeof(guid_var); + SQLLEN ind; + + ret = SQLGetData(this->stmt, 1, SQL_C_GUID, &guid_var, buf_len, &ind); + + EXPECT_EQ(ret, SQL_ERROR); + // GUID is not supported by ODBC + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_HY000); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLExecDirectRowFetching) { + this->connect(); + + std::wstring wsql = + LR"( + SELECT 1 AS small_table + UNION ALL + SELECT 2 + UNION ALL + SELECT 3; + )"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Fetch row 1 + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + SQLINTEGER val; + SQLLEN buf_len = sizeof(val); + SQLLEN ind; + + ret = SQLGetData(this->stmt, 1, SQL_C_LONG, &val, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Verify 1 is returned + EXPECT_EQ(val, 1); + + // Fetch row 2 + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLGetData(this->stmt, 1, SQL_C_LONG, &val, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Verify 2 is returned + EXPECT_EQ(val, 2); + + // Fetch row 3 + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLGetData(this->stmt, 1, SQL_C_LONG, &val, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Verify 3 is returned + EXPECT_EQ(val, 3); + + // Verify result set has no more data beyond row 3 + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + ret = SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, &ind); + EXPECT_EQ(ret, SQL_ERROR); + + // Invalid cursor state + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_24000); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLFetchScrollRowFetching) { + this->connect(); + + SQLLEN rows_fetched; + SQLRETURN ret = SQLSetStmtAttr(this->stmt, SQL_ATTR_ROWS_FETCHED_PTR, &rows_fetched, 0); + + std::wstring wsql = + LR"( + SELECT 1 AS small_table + UNION ALL + SELECT 2 + UNION ALL + SELECT 3; + )"; + std::vector sql0(wsql.begin(), wsql.end()); + + ret = SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Fetch row 1 + ret = SQLFetchScroll(this->stmt, SQL_FETCH_NEXT, 0); + EXPECT_EQ(ret, SQL_SUCCESS); + + SQLINTEGER val; + SQLLEN buf_len = sizeof(val); + SQLLEN ind; + + ret = SQLGetData(this->stmt, 1, SQL_C_LONG, &val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Verify 1 is returned + EXPECT_EQ(val, 1); + // Verify 1 row is fetched + EXPECT_EQ(rows_fetched, 1); + + // Fetch row 2 + ret = SQLFetchScroll(this->stmt, SQL_FETCH_NEXT, 0); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLGetData(this->stmt, 1, SQL_C_LONG, &val, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Verify 2 is returned + EXPECT_EQ(val, 2); + // Verify 1 row is fetched in the last SQLFetchScroll call + EXPECT_EQ(rows_fetched, 1); + + // Fetch row 3 + ret = SQLFetchScroll(this->stmt, SQL_FETCH_NEXT, 0); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLGetData(this->stmt, 1, SQL_C_LONG, &val, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Verify 3 is returned + EXPECT_EQ(val, 3); + // Verify 1 row is fetched in the last SQLFetchScroll call + EXPECT_EQ(rows_fetched, 1); + + // Verify result set has no more data beyond row 3 + ret = SQLFetchScroll(this->stmt, SQL_FETCH_NEXT, 0); + EXPECT_EQ(ret, SQL_NO_DATA); + + ret = SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, &ind); + + EXPECT_EQ(ret, SQL_ERROR); + // Invalid cursor state + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_24000); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLFetchScrollUnsupportedOrientation) { + // SQL_FETCH_PRIOR is the only supported fetch orientation. + this->connect(); + + std::wstring wsql = L"SELECT 1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetchScroll(this->stmt, SQL_FETCH_PRIOR, 0); + EXPECT_EQ(ret, SQL_ERROR); + + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_HYC00); + + SQLLEN fetch_offset = 1; + + ret = SQLFetchScroll(this->stmt, SQL_FETCH_RELATIVE, fetch_offset); + EXPECT_EQ(ret, SQL_ERROR); + + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_HYC00); + + ret = SQLFetchScroll(this->stmt, SQL_FETCH_ABSOLUTE, fetch_offset); + EXPECT_EQ(ret, SQL_ERROR); + + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_HYC00); + + ret = SQLFetchScroll(this->stmt, SQL_FETCH_FIRST, 0); + EXPECT_EQ(ret, SQL_ERROR); + + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_HYC00); + + ret = SQLFetchScroll(this->stmt, SQL_FETCH_LAST, 0); + EXPECT_EQ(ret, SQL_ERROR); + + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_HYC00); + + ret = SQLFetchScroll(this->stmt, SQL_FETCH_BOOKMARK, fetch_offset); + EXPECT_EQ(ret, SQL_ERROR); + + // DM returns state HY106 for SQL_FETCH_BOOKMARK + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_HY106); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLExecDirectVarcharTruncation) { + this->connect(); + + std::wstring wsql = L"SELECT 'VERY LONG STRING here' AS string_col;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + const int len = 17; + SQLCHAR char_val[len]; + SQLLEN buf_len = sizeof(SQLCHAR) * len; + SQLLEN ind; + + ret = SQLGetData(this->stmt, 1, SQL_C_CHAR, &char_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS_WITH_INFO); + // Verify string truncation is reported + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_01004); + + EXPECT_EQ(ODBC::SqlStringToString(char_val), std::string("VERY LONG STRING")); + EXPECT_EQ(ind, 21); + + // Fetch same column 2nd time + const int len2 = 2; + SQLCHAR char_val2[len2]; + buf_len = sizeof(SQLCHAR) * len2; + + ret = SQLGetData(this->stmt, 1, SQL_C_CHAR, &char_val2, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS_WITH_INFO); + // Verify string truncation is reported + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_01004); + + EXPECT_EQ(ODBC::SqlStringToString(char_val2), std::string(" ")); + EXPECT_EQ(ind, 5); + + // Fetch same column 3rd time + const int len3 = 5; + SQLCHAR char_val3[len3]; + buf_len = sizeof(SQLCHAR) * len3; + + ret = SQLGetData(this->stmt, 1, SQL_C_CHAR, &char_val3, buf_len, &ind); + + // Verify that there is no more truncation reports. The full string has been fetched. + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(ODBC::SqlStringToString(char_val3), std::string("here")); + EXPECT_EQ(ind, 4); + + // Attempt to fetch data 4th time + SQLCHAR char_val4[len]; + ret = SQLGetData(this->stmt, 1, SQL_C_CHAR, &char_val4, 0, &ind); + // Verify SQL_NO_DATA is returned + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLExecDirectWVarcharTruncation) { + this->connect(); + + std::wstring wsql = L"SELECT 'VERY LONG Unicode STRING 句子 here' AS wstring_col;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + const int len = 28; + SQLWCHAR wchar_val[len]; + constexpr size_t wchar_size = driver::odbcabstraction::GetSqlWCharSize(); + SQLLEN buf_len = wchar_size * len; + SQLLEN ind; + + ret = SQLGetData(this->stmt, 1, SQL_C_WCHAR, &wchar_val, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS_WITH_INFO); + // Verify string truncation is reported + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_01004); + + EXPECT_EQ(std::wstring(wchar_val), std::wstring(L"VERY LONG Unicode STRING 句子")); + EXPECT_EQ(ind, 32 * wchar_size); + + // Fetch same column 2nd time + const int len2 = 2; + SQLWCHAR wchar_val2[len2]; + buf_len = wchar_size * len2; + + ret = SQLGetData(this->stmt, 1, SQL_C_WCHAR, &wchar_val2, buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS_WITH_INFO); + // Verify string truncation is reported + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_01004); + + EXPECT_EQ(std::wstring(wchar_val2), std::wstring(L" ")); + EXPECT_EQ(ind, 5 * wchar_size); + + // Fetch same column 3rd time + const int len3 = 5; + SQLWCHAR wchar_val3[len3]; + buf_len = wchar_size * len3; + + ret = SQLGetData(this->stmt, 1, SQL_C_WCHAR, &wchar_val3, buf_len, &ind); + + // Verify that there is no more truncation reports. The full string has been fetched. + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(std::wstring(wchar_val3), std::wstring(L"here")); + EXPECT_EQ(ind, 4 * wchar_size); + + // Attempt to fetch data 4th time + SQLWCHAR wchar_val4[len]; + ret = SQLGetData(this->stmt, 1, SQL_C_WCHAR, &wchar_val4, 0, &ind); + // Verify SQL_NO_DATA is returned + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLExecDirectVarbinaryTruncation) { + // Have binary test on mock test base as remote test servers tend to have different + // formats for binary data + this->connect(); + + std::wstring wsql = L"SELECT X'ABCDEFAB' AS c_varbinary;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // varbinary + std::vector varbinary_val(3); + SQLLEN buf_len = varbinary_val.size(); + SQLLEN ind; + ret = SQLGetData(this->stmt, 1, SQL_C_BINARY, &varbinary_val[0], buf_len, &ind); + // Verify binary truncation is reported + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_01004); + EXPECT_EQ(varbinary_val[0], '\xAB'); + EXPECT_EQ(varbinary_val[1], '\xCD'); + EXPECT_EQ(varbinary_val[2], '\xEF'); + EXPECT_EQ(ind, 4); + + // Fetch same column 2nd time + std::vector varbinary_val2(1); + buf_len = varbinary_val2.size(); + + ret = SQLGetData(this->stmt, 1, SQL_C_BINARY, &varbinary_val2[0], buf_len, &ind); + + // Verify that there is no more truncation reports. The full binary has been fetched. + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(varbinary_val[0], '\xAB'); + EXPECT_EQ(ind, 1); + + // Attempt to fetch data 3rd time + std::vector varbinary_val3(1); + buf_len = varbinary_val3.size(); + ret = SQLGetData(this->stmt, 1, SQL_C_BINARY, &varbinary_val3[0], buf_len, &ind); + // Verify SQL_NO_DATA is returned + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLExecDirectFloatTruncation) { + // Test is disabled until float truncation is supported. + // GH-46985: return warning message instead of error on float truncation case + GTEST_SKIP(); + this->connect(); + + std::wstring wsql; + if constexpr (std::is_same_v) { + wsql = std::wstring(L"SELECT CAST(1.234 AS REAL) AS float_val"); + } else if constexpr (std::is_same_v) { + wsql = std::wstring(L"SELECT CAST(1.234 AS FLOAT) AS float_val"); + } + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + int16_t ssmall_int_val; + + ret = SQLGetData(this->stmt, 1, SQL_C_SSHORT, &ssmall_int_val, 0, 0); + EXPECT_EQ(ret, SQL_SUCCESS_WITH_INFO); + // Verify float truncation is reported + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_01S07); + + EXPECT_EQ(ssmall_int_val, 1); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLExecDirectNullQuery) { + // Limitation on mock test server prevents null from working properly, so use remote + // server instead. Mock server has type `DENSE_UNION` for null column data. + this->connect(); + + std::wstring wsql = L"SELECT null as null_col;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + SQLINTEGER val; + SQLLEN ind; + + ret = SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Verify SQL_NULL_DATA is returned for indicator + EXPECT_EQ(ind, SQL_NULL_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLExecDirectTruncationQueryNullIndicator) { + // Driver should not error out when indicator is null if the cell is non-null + // Have binary test on mock test base as remote test servers tend to have different + // formats for binary data + this->connect(); + + std::wstring wsql = + LR"( + SELECT 1, + 'VERY LONG STRING here' AS string_col, + 'VERY LONG Unicode STRING 句子 here' AS wstring_col, + X'ABCDEFAB' AS c_varbinary; + )"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + SQLINTEGER val; + ret = SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Verify 1 is returned for non-truncation case. + EXPECT_EQ(val, 1); + + // Char + const int len = 17; + SQLCHAR char_val[len]; + SQLLEN buf_len = sizeof(SQLCHAR) * len; + + ret = SQLGetData(this->stmt, 2, SQL_C_CHAR, &char_val, buf_len, 0); + + EXPECT_EQ(ret, SQL_SUCCESS_WITH_INFO); + // Verify string truncation is reported + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_01004); + + // WChar + const int len2 = 28; + SQLWCHAR wchar_val[len2]; + constexpr size_t wchar_size = driver::odbcabstraction::GetSqlWCharSize(); + buf_len = wchar_size * len2; + + ret = SQLGetData(this->stmt, 3, SQL_C_WCHAR, &wchar_val, buf_len, 0); + + EXPECT_EQ(ret, SQL_SUCCESS_WITH_INFO); + // Verify string truncation is reported + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_01004); + + // varbinary + std::vector varbinary_val(3); + buf_len = varbinary_val.size(); + ret = SQLGetData(this->stmt, 4, SQL_C_BINARY, &varbinary_val[0], buf_len, 0); + // Verify binary truncation is reported + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_01004); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLExecDirectNullQueryNullIndicator) { + // Limitation on mock test server prevents null from working properly, so use remote + // server instead. Mock server has type `DENSE_UNION` for null column data. + this->connect(); + + std::wstring wsql = L"SELECT null as null_col;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + SQLINTEGER val; + + ret = SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0); + + EXPECT_EQ(ret, SQL_ERROR); + // Verify invalid null indicator is reported, as it is required + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_22002); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLExecDirectIgnoreInvalidBufLen) { + // Verify the driver ignores invalid buffer length for fixed data types + this->connect(); + + std::wstring wsql = this->getQueryAllDataTypes(); + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Numeric Types + + // Signed Tiny Int + int8_t stiny_int_val; + SQLLEN invalid_buf_len = -1; + SQLLEN ind; + + ret = SQLGetData(this->stmt, 1, SQL_C_STINYINT, &stiny_int_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(stiny_int_val, std::numeric_limits::min()); + + ret = SQLGetData(this->stmt, 2, SQL_C_STINYINT, &stiny_int_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(stiny_int_val, std::numeric_limits::max()); + + // Unsigned Tiny Int + uint8_t utiny_int_val; + + ret = SQLGetData(this->stmt, 3, SQL_C_UTINYINT, &utiny_int_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(utiny_int_val, std::numeric_limits::min()); + + ret = SQLGetData(this->stmt, 4, SQL_C_UTINYINT, &utiny_int_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(utiny_int_val, std::numeric_limits::max()); + + // Signed Small Int + int16_t ssmall_int_val; + + ret = SQLGetData(this->stmt, 5, SQL_C_SSHORT, &ssmall_int_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(ssmall_int_val, std::numeric_limits::min()); + + ret = SQLGetData(this->stmt, 6, SQL_C_SSHORT, &ssmall_int_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(ssmall_int_val, std::numeric_limits::max()); + + // Unsigned Small Int + uint16_t usmall_int_val; + + ret = SQLGetData(this->stmt, 7, SQL_C_USHORT, &usmall_int_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(usmall_int_val, std::numeric_limits::min()); + + ret = SQLGetData(this->stmt, 8, SQL_C_USHORT, &usmall_int_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(usmall_int_val, std::numeric_limits::max()); + + // Signed Integer + SQLINTEGER slong_val; + + ret = SQLGetData(this->stmt, 9, SQL_C_SLONG, &slong_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(slong_val, std::numeric_limits::min()); + + ret = SQLGetData(this->stmt, 10, SQL_C_SLONG, &slong_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(slong_val, std::numeric_limits::max()); + + // Unsigned Integer + SQLUINTEGER ulong_val; + + ret = SQLGetData(this->stmt, 11, SQL_C_ULONG, &ulong_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(ulong_val, std::numeric_limits::min()); + + ret = SQLGetData(this->stmt, 12, SQL_C_ULONG, &ulong_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(ulong_val, std::numeric_limits::max()); + + // Signed Big Int + SQLBIGINT sbig_int_val; + + ret = SQLGetData(this->stmt, 13, SQL_C_SBIGINT, &sbig_int_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(sbig_int_val, std::numeric_limits::min()); + + ret = SQLGetData(this->stmt, 14, SQL_C_SBIGINT, &sbig_int_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(sbig_int_val, std::numeric_limits::max()); + + // Unsigned Big Int + SQLUBIGINT ubig_int_val; + + ret = SQLGetData(this->stmt, 15, SQL_C_UBIGINT, &ubig_int_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(ubig_int_val, std::numeric_limits::min()); + + ret = SQLGetData(this->stmt, 16, SQL_C_UBIGINT, &ubig_int_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(ubig_int_val, std::numeric_limits::max()); + + // Decimal + SQL_NUMERIC_STRUCT decimal_val; + memset(&decimal_val, 0, sizeof(decimal_val)); + + ret = SQLGetData(this->stmt, 17, SQL_C_NUMERIC, &decimal_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check for negative decimal_val value + EXPECT_EQ(decimal_val.sign, 0); + EXPECT_EQ(decimal_val.scale, 0); + EXPECT_EQ(decimal_val.precision, 38); + EXPECT_THAT(decimal_val.val, ::testing::ElementsAre(0xFF, 0xC9, 0x9A, 0x3B, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0)); + + memset(&decimal_val, 0, sizeof(decimal_val)); + ret = SQLGetData(this->stmt, 18, SQL_C_NUMERIC, &decimal_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check for positive decimal_val value + EXPECT_EQ(decimal_val.sign, 1); + EXPECT_EQ(decimal_val.scale, 0); + EXPECT_EQ(decimal_val.precision, 38); + EXPECT_THAT(decimal_val.val, ::testing::ElementsAre(0xFF, 0xC9, 0x9A, 0x3B, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0)); + + // Float + float float_val; + + ret = SQLGetData(this->stmt, 19, SQL_C_FLOAT, &float_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Get minimum negative float value + EXPECT_EQ(float_val, -std::numeric_limits::max()); + + ret = SQLGetData(this->stmt, 20, SQL_C_FLOAT, &float_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(float_val, std::numeric_limits::max()); + + // Double + SQLDOUBLE double_val; + + ret = SQLGetData(this->stmt, 21, SQL_C_DOUBLE, &double_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Get minimum negative double value + EXPECT_EQ(double_val, -std::numeric_limits::max()); + + ret = SQLGetData(this->stmt, 22, SQL_C_DOUBLE, &double_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(double_val, std::numeric_limits::max()); + + // Bit + bool bit_val; + + ret = SQLGetData(this->stmt, 23, SQL_C_BIT, &bit_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(bit_val, false); + + ret = SQLGetData(this->stmt, 24, SQL_C_BIT, &bit_val, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(bit_val, true); + + // Date and Timestamp + + // Date + SQL_DATE_STRUCT date_var{}; + + ret = SQLGetData(this->stmt, 29, SQL_C_TYPE_DATE, &date_var, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check min values for date. Min valid year is 1400. + EXPECT_EQ(date_var.day, 1); + EXPECT_EQ(date_var.month, 1); + EXPECT_EQ(date_var.year, 1400); + + ret = SQLGetData(this->stmt, 30, SQL_C_TYPE_DATE, &date_var, invalid_buf_len, &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check max values for date. Max valid year is 9999. + EXPECT_EQ(date_var.day, 31); + EXPECT_EQ(date_var.month, 12); + EXPECT_EQ(date_var.year, 9999); + + // Timestamp + SQL_TIMESTAMP_STRUCT timestamp_var{}; + + ret = SQLGetData(this->stmt, 31, SQL_C_TYPE_TIMESTAMP, ×tamp_var, invalid_buf_len, + &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check min values for date. Min valid year is 1400. + EXPECT_EQ(timestamp_var.day, 1); + EXPECT_EQ(timestamp_var.month, 1); + EXPECT_EQ(timestamp_var.year, 1400); + EXPECT_EQ(timestamp_var.hour, 0); + EXPECT_EQ(timestamp_var.minute, 0); + EXPECT_EQ(timestamp_var.second, 0); + EXPECT_EQ(timestamp_var.fraction, 0); + + ret = SQLGetData(this->stmt, 32, SQL_C_TYPE_TIMESTAMP, ×tamp_var, invalid_buf_len, + &ind); + + EXPECT_EQ(ret, SQL_SUCCESS); + // Check max values for date. Max valid year is 9999. + EXPECT_EQ(timestamp_var.day, 31); + EXPECT_EQ(timestamp_var.month, 12); + EXPECT_EQ(timestamp_var.year, 9999); + EXPECT_EQ(timestamp_var.hour, 23); + EXPECT_EQ(timestamp_var.minute, 59); + EXPECT_EQ(timestamp_var.second, 59); + EXPECT_EQ(timestamp_var.fraction, 0); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLBindColDataQuery) { + this->connect(); + + // Numeric Types + + // Signed Tiny Int + int8_t stiny_int_val_min; + int8_t stiny_int_val_max; + SQLLEN buf_len = 0; + SQLLEN ind; + + SQLRETURN ret = + SQLBindCol(this->stmt, 1, SQL_C_STINYINT, &stiny_int_val_min, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLBindCol(this->stmt, 2, SQL_C_STINYINT, &stiny_int_val_max, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Unsigned Tiny Int + uint8_t utiny_int_val_min; + uint8_t utiny_int_val_max; + + ret = SQLBindCol(this->stmt, 3, SQL_C_UTINYINT, &utiny_int_val_min, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLBindCol(this->stmt, 4, SQL_C_UTINYINT, &utiny_int_val_max, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Signed Small Int + int16_t ssmall_int_val_min; + int16_t ssmall_int_val_max; + + ret = SQLBindCol(this->stmt, 5, SQL_C_SSHORT, &ssmall_int_val_min, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLBindCol(this->stmt, 6, SQL_C_SSHORT, &ssmall_int_val_max, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Unsigned Small Int + uint16_t usmall_int_val_min; + uint16_t usmall_int_val_max; + + ret = SQLBindCol(this->stmt, 7, SQL_C_USHORT, &usmall_int_val_min, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLBindCol(this->stmt, 8, SQL_C_USHORT, &usmall_int_val_max, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Signed Integer + SQLINTEGER slong_val_min; + SQLINTEGER slong_val_max; + + ret = SQLBindCol(this->stmt, 9, SQL_C_SLONG, &slong_val_min, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLBindCol(this->stmt, 10, SQL_C_SLONG, &slong_val_max, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Unsigned Integer + SQLUINTEGER ulong_val_min; + SQLUINTEGER ulong_val_max; + + ret = SQLBindCol(this->stmt, 11, SQL_C_ULONG, &ulong_val_min, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLBindCol(this->stmt, 12, SQL_C_ULONG, &ulong_val_max, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Signed Big Int + SQLBIGINT sbig_int_val_min; + SQLBIGINT sbig_int_val_max; + + ret = SQLBindCol(this->stmt, 13, SQL_C_SBIGINT, &sbig_int_val_min, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLBindCol(this->stmt, 14, SQL_C_SBIGINT, &sbig_int_val_max, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Unsigned Big Int + SQLUBIGINT ubig_int_val_min; + SQLUBIGINT ubig_int_val_max; + + ret = SQLBindCol(this->stmt, 15, SQL_C_UBIGINT, &ubig_int_val_min, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLBindCol(this->stmt, 16, SQL_C_UBIGINT, &ubig_int_val_max, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Decimal + SQL_NUMERIC_STRUCT decimal_val_neg; + SQL_NUMERIC_STRUCT decimal_val_pos; + memset(&decimal_val_neg, 0, sizeof(decimal_val_neg)); + memset(&decimal_val_pos, 0, sizeof(decimal_val_pos)); + + ret = SQLBindCol(this->stmt, 17, SQL_C_NUMERIC, &decimal_val_neg, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLBindCol(this->stmt, 18, SQL_C_NUMERIC, &decimal_val_pos, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Float + float float_val_min; + float float_val_max; + + ret = SQLBindCol(this->stmt, 19, SQL_C_FLOAT, &float_val_min, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLBindCol(this->stmt, 20, SQL_C_FLOAT, &float_val_max, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Double + SQLDOUBLE double_val_min; + SQLDOUBLE double_val_max; + + ret = SQLBindCol(this->stmt, 21, SQL_C_DOUBLE, &double_val_min, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLBindCol(this->stmt, 22, SQL_C_DOUBLE, &double_val_max, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Bit + bool bit_val_false; + bool bit_val_true; + + ret = SQLBindCol(this->stmt, 23, SQL_C_BIT, &bit_val_false, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLBindCol(this->stmt, 24, SQL_C_BIT, &bit_val_true, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Characters + SQLCHAR char_val[2]; + buf_len = sizeof(SQLCHAR) * 2; + + ret = SQLBindCol(this->stmt, 25, SQL_C_CHAR, &char_val, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + SQLWCHAR wchar_val[2]; + constexpr size_t wchar_size = driver::odbcabstraction::GetSqlWCharSize(); + buf_len = wchar_size * 2; + + ret = SQLBindCol(this->stmt, 26, SQL_C_WCHAR, &wchar_val, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + SQLWCHAR wvarchar_val[3]; + buf_len = wchar_size * 3; + + ret = SQLBindCol(this->stmt, 27, SQL_C_WCHAR, &wvarchar_val, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + SQLCHAR varchar_val[4]; + buf_len = sizeof(SQLCHAR) * 4; + + ret = SQLBindCol(this->stmt, 28, SQL_C_CHAR, &varchar_val, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Date and Timestamp + SQL_DATE_STRUCT date_val_min{}, date_val_max{}; + buf_len = 0; + + ret = SQLBindCol(this->stmt, 29, SQL_C_TYPE_DATE, &date_val_min, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLBindCol(this->stmt, 30, SQL_C_TYPE_DATE, &date_val_max, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + SQL_TIMESTAMP_STRUCT timestamp_val_min{}, timestamp_val_max{}; + + ret = + SQLBindCol(this->stmt, 31, SQL_C_TYPE_TIMESTAMP, ×tamp_val_min, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = + SQLBindCol(this->stmt, 32, SQL_C_TYPE_TIMESTAMP, ×tamp_val_max, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Execute query and fetch data once since there is only 1 row. + std::wstring wsql = this->getQueryAllDataTypes(); + std::vector sql0(wsql.begin(), wsql.end()); + + ret = SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Data verification + + // Signed Tiny Int + EXPECT_EQ(stiny_int_val_min, std::numeric_limits::min()); + EXPECT_EQ(stiny_int_val_max, std::numeric_limits::max()); + + // Unsigned Tiny Int + EXPECT_EQ(utiny_int_val_min, std::numeric_limits::min()); + EXPECT_EQ(utiny_int_val_max, std::numeric_limits::max()); + + // Signed Small Int + EXPECT_EQ(ssmall_int_val_min, std::numeric_limits::min()); + EXPECT_EQ(ssmall_int_val_max, std::numeric_limits::max()); + + // Unsigned Small Int + EXPECT_EQ(usmall_int_val_min, std::numeric_limits::min()); + EXPECT_EQ(usmall_int_val_max, std::numeric_limits::max()); + + // Signed Long + EXPECT_EQ(slong_val_min, std::numeric_limits::min()); + EXPECT_EQ(slong_val_max, std::numeric_limits::max()); + + // Unsigned Long + EXPECT_EQ(ulong_val_min, std::numeric_limits::min()); + EXPECT_EQ(ulong_val_max, std::numeric_limits::max()); + + // Signed Big Int + EXPECT_EQ(sbig_int_val_min, std::numeric_limits::min()); + EXPECT_EQ(sbig_int_val_max, std::numeric_limits::max()); + + // Unsigned Big Int + EXPECT_EQ(ubig_int_val_min, std::numeric_limits::min()); + EXPECT_EQ(ubig_int_val_max, std::numeric_limits::max()); + + // Decimal + EXPECT_EQ(decimal_val_neg.sign, 0); + EXPECT_EQ(decimal_val_neg.scale, 0); + EXPECT_EQ(decimal_val_neg.precision, 38); + EXPECT_THAT(decimal_val_neg.val, ::testing::ElementsAre(0xFF, 0xC9, 0x9A, 0x3B, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0)); + + EXPECT_EQ(decimal_val_pos.sign, 1); + EXPECT_EQ(decimal_val_pos.scale, 0); + EXPECT_EQ(decimal_val_pos.precision, 38); + EXPECT_THAT(decimal_val_pos.val, ::testing::ElementsAre(0xFF, 0xC9, 0x9A, 0x3B, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0)); + + // Float + EXPECT_EQ(float_val_min, -std::numeric_limits::max()); + EXPECT_EQ(float_val_max, std::numeric_limits::max()); + + // Double + EXPECT_EQ(double_val_min, -std::numeric_limits::max()); + EXPECT_EQ(double_val_max, std::numeric_limits::max()); + + // Bit + EXPECT_EQ(bit_val_false, false); + EXPECT_EQ(bit_val_true, true); + + // Characters + EXPECT_EQ(char_val[0], 'Z'); + EXPECT_EQ(wchar_val[0], L'你'); + EXPECT_EQ(wvarchar_val[0], L'你'); + EXPECT_EQ(wvarchar_val[1], L'好'); + + EXPECT_EQ(varchar_val[0], 'X'); + EXPECT_EQ(varchar_val[1], 'Y'); + EXPECT_EQ(varchar_val[2], 'Z'); + + // Date + EXPECT_EQ(date_val_min.day, 1); + EXPECT_EQ(date_val_min.month, 1); + EXPECT_EQ(date_val_min.year, 1400); + + EXPECT_EQ(date_val_max.day, 31); + EXPECT_EQ(date_val_max.month, 12); + EXPECT_EQ(date_val_max.year, 9999); + + // Timestamp + EXPECT_EQ(timestamp_val_min.day, 1); + EXPECT_EQ(timestamp_val_min.month, 1); + EXPECT_EQ(timestamp_val_min.year, 1400); + EXPECT_EQ(timestamp_val_min.hour, 0); + EXPECT_EQ(timestamp_val_min.minute, 0); + EXPECT_EQ(timestamp_val_min.second, 0); + EXPECT_EQ(timestamp_val_min.fraction, 0); + + EXPECT_EQ(timestamp_val_max.day, 31); + EXPECT_EQ(timestamp_val_max.month, 12); + EXPECT_EQ(timestamp_val_max.year, 9999); + EXPECT_EQ(timestamp_val_max.hour, 23); + EXPECT_EQ(timestamp_val_max.minute, 59); + EXPECT_EQ(timestamp_val_max.second, 59); + EXPECT_EQ(timestamp_val_max.fraction, 0); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLBindColTimeQuery) { + // Mock server test is skipped due to limitation on the mock server. + // Time type from mock server does not include the fraction + this->connect(); + + SQL_TIME_STRUCT time_var_min{}; + SQL_TIME_STRUCT time_var_max{}; + SQLLEN buf_len = sizeof(time_var_min); + SQLLEN ind; + + SQLRETURN ret = + SQLBindCol(this->stmt, 1, SQL_C_TYPE_TIME, &time_var_min, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLBindCol(this->stmt, 2, SQL_C_TYPE_TIME, &time_var_max, buf_len, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + std::wstring wsql = + LR"( + SELECT CAST(TIME '00:00:00' AS TIME) AS time_min, + CAST(TIME '23:59:59' AS TIME) AS time_max; + )"; + std::vector sql0(wsql.begin(), wsql.end()); + + ret = SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check min values for time. + EXPECT_EQ(time_var_min.hour, 0); + EXPECT_EQ(time_var_min.minute, 0); + EXPECT_EQ(time_var_min.second, 0); + + // Check max values for time. + EXPECT_EQ(time_var_max.hour, 23); + EXPECT_EQ(time_var_max.minute, 59); + EXPECT_EQ(time_var_max.second, 59); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLBindColVarbinaryQuery) { + // Have binary test on mock test base as remote test servers tend to have different + // formats for binary data + this->connect(); + + // varbinary + std::vector varbinary_val(3); + SQLLEN buf_len = varbinary_val.size(); + SQLLEN ind; + SQLRETURN ret = + SQLBindCol(this->stmt, 1, SQL_C_BINARY, &varbinary_val[0], buf_len, &ind); + + std::wstring wsql = L"SELECT X'ABCDEF' AS c_varbinary;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ret = SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check varbinary values + EXPECT_EQ(varbinary_val[0], '\xAB'); + EXPECT_EQ(varbinary_val[1], '\xCD'); + EXPECT_EQ(varbinary_val[2], '\xEF'); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLBindColNullQuery) { + // Limitation on mock test server prevents null from working properly, so use remote + // server instead. Mock server has type `DENSE_UNION` for null column data. + this->connect(); + + SQLINTEGER val; + SQLLEN ind; + + SQLRETURN ret = SQLBindCol(this->stmt, 1, SQL_C_LONG, &val, 0, &ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + std::wstring wsql = L"SELECT null as null_col;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ret = SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Verify SQL_NULL_DATA is returned for indicator + EXPECT_EQ(ind, SQL_NULL_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLBindColNullQueryNullIndicator) { + // Limitation on mock test server prevents null from working properly, so use remote + // server instead. Mock server has type `DENSE_UNION` for null column data. + this->connect(); + + SQLINTEGER val; + + SQLRETURN ret = SQLBindCol(this->stmt, 1, SQL_C_LONG, &val, 0, 0); + + std::wstring wsql = L"SELECT null as null_col;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ret = SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_ERROR); + // Verify invalid null indicator is reported, as it is required + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_22002); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLBindColRowFetching) { + this->connect(); + + SQLINTEGER val; + SQLLEN buf_len = sizeof(val); + SQLLEN ind; + + // Same variable will be used for column 1, the value of `val` + // should be updated after every SQLFetch call. + SQLRETURN ret = SQLBindCol(this->stmt, 1, SQL_C_LONG, &val, buf_len, &ind); + + std::wstring wsql = + LR"( + SELECT 1 AS small_table + UNION ALL + SELECT 2 + UNION ALL + SELECT 3; + )"; + std::vector sql0(wsql.begin(), wsql.end()); + + ret = SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Fetch row 1 + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Verify 1 is returned + EXPECT_EQ(val, 1); + + // Fetch row 2 + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Verify 2 is returned + EXPECT_EQ(val, 2); + + // Fetch row 3 + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Verify 3 is returned + EXPECT_EQ(val, 3); + + // Verify result set has no more data beyond row 3 + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLBindColRowArraySize) { + // Set SQL_ATTR_ROW_ARRAY_SIZE to fetch 3 rows at once + this->connect(); + + constexpr SQLULEN rows = 3; + SQLINTEGER val[rows]; + SQLLEN buf_len = sizeof(val); + SQLLEN ind[rows]; + + // Same variable will be used for column 1, the value of `val` + // should be updated after every SQLFetch call. + SQLRETURN ret = SQLBindCol(this->stmt, 1, SQL_C_LONG, val, buf_len, ind); + + SQLLEN rows_fetched; + ret = SQLSetStmtAttr(this->stmt, SQL_ATTR_ROWS_FETCHED_PTR, &rows_fetched, 0); + + std::wstring wsql = + LR"( + SELECT 1 AS small_table + UNION ALL + SELECT 2 + UNION ALL + SELECT 3; + )"; + std::vector sql0(wsql.begin(), wsql.end()); + + ret = SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLSetStmtAttr(this->stmt, SQL_ATTR_ROW_ARRAY_SIZE, + reinterpret_cast(rows), 0); + + // Fetch 3 rows at once + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Verify 3 rows are fetched + EXPECT_EQ(rows_fetched, 3); + + // Verify 1 is returned + EXPECT_EQ(val[0], 1); + // Verify 2 is returned + EXPECT_EQ(val[1], 2); + // Verify 3 is returned + EXPECT_EQ(val[2], 3); + + // Verify result set has no more data beyond row 3 + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLBindColIndicatorOnly) { + // GH-47021: implement driver to return indicator value when data pointer is null + GTEST_SKIP(); + // Verify driver supports null data pointer with valid indicator pointer + this->connect(); + + // Numeric Types + + // Signed Tiny Int + SQLLEN stiny_int_ind; + + SQLRETURN ret = SQLBindCol(this->stmt, 1, SQL_C_STINYINT, 0, 0, &stiny_int_ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Characters + SQLLEN buf_len = sizeof(SQLCHAR) * 2; + SQLLEN char_val_ind; + + ret = SQLBindCol(this->stmt, 25, SQL_C_CHAR, 0, buf_len, &char_val_ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Execute query and fetch data once since there is only 1 row. + std::wstring wsql = this->getQueryAllDataTypes(); + std::vector sql0(wsql.begin(), wsql.end()); + + ret = SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Verify values for indicator pointer + // Signed Tiny Int + EXPECT_EQ(stiny_int_ind, 1); + + // Char array + EXPECT_EQ(char_val_ind, 1); + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLBindColIndicatorOnlySQLUnbind) { + // Verify driver supports valid indicator pointer after unbinding all columns + this->connect(); + + // Numeric Types + + // Signed Tiny Int + int8_t stiny_int_val; + SQLLEN stiny_int_ind; + + SQLRETURN ret = + SQLBindCol(this->stmt, 1, SQL_C_STINYINT, &stiny_int_val, 0, &stiny_int_ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Characters + SQLCHAR char_val[2]; + SQLLEN buf_len = sizeof(SQLCHAR) * 2; + SQLLEN char_val_ind; + + ret = SQLBindCol(this->stmt, 25, SQL_C_CHAR, &char_val, buf_len, &char_val_ind); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Driver should still be able to execute queries after unbinding columns + ret = SQLFreeStmt(this->stmt, SQL_UNBIND); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Execute query and fetch data once since there is only 1 row. + std::wstring wsql = this->getQueryAllDataTypes(); + std::vector sql0(wsql.begin(), wsql.end()); + + ret = SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // GH-47021: implement driver to return indicator value when data pointer is null and + // uncomment the checks Verify values for indicator pointer Signed Tiny Int + // EXPECT_EQ(stiny_int_ind, 1); + + // Char array + // EXPECT_EQ(char_val_ind, 1); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLExtendedFetchRowFetching) { + // Set SQL_ROWSET_SIZE to fetch 3 rows at once + this->connect(); + + constexpr SQLULEN rows = 3; + SQLINTEGER val[rows]; + SQLLEN buf_len = sizeof(val); + SQLLEN ind[rows]; + + // Same variable will be used for column 1, the value of `val` + // should be updated after every SQLFetch call. + SQLRETURN ret = SQLBindCol(this->stmt, 1, SQL_C_LONG, val, buf_len, ind); + + ret = + SQLSetStmtAttr(this->stmt, SQL_ROWSET_SIZE, reinterpret_cast(rows), 0); + + std::wstring wsql = + LR"( + SELECT 1 AS small_table + UNION ALL + SELECT 2 + UNION ALL + SELECT 3; + )"; + std::vector sql0(wsql.begin(), wsql.end()); + + ret = SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Fetch row 1-3. + SQLULEN row_count; + SQLUSMALLINT row_status[rows]; + + ret = SQLExtendedFetch(this->stmt, SQL_FETCH_NEXT, 0, &row_count, row_status); + EXPECT_EQ(ret, SQL_SUCCESS); + EXPECT_EQ(row_count, 3); + + for (int i = 0; i < rows; i++) { + EXPECT_EQ(row_status[i], SQL_SUCCESS); + } + + // Verify 1 is returned for row 1 + EXPECT_EQ(val[0], 1); + // Verify 2 is returned for row 2 + EXPECT_EQ(val[1], 2); + // Verify 3 is returned for row 3 + EXPECT_EQ(val[2], 3); + + // Verify result set has no more data beyond row 3 + SQLULEN row_count2; + SQLUSMALLINT row_status2[rows]; + ret = SQLExtendedFetch(this->stmt, SQL_FETCH_NEXT, 0, &row_count2, row_status2); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLExtendedFetchQueryNullIndicator) { + // GH-47110: SQLExtendedFetch should return SQL_SUCCESS_WITH_INFO for 22002 + // Limitation on mock test server prevents null from working properly, so use remote + // server instead. Mock server has type `DENSE_UNION` for null column data. + GTEST_SKIP(); + this->connect(); + + SQLINTEGER val; + + SQLRETURN ret = SQLBindCol(this->stmt, 1, SQL_C_LONG, &val, 0, 0); + + std::wstring wsql = L"SELECT null as null_col;"; + std::vector sql0(wsql.begin(), wsql.end()); + + ret = SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + SQLULEN row_count1; + SQLUSMALLINT row_status1[1]; + + // SQLExtendedFetch should return SQL_SUCCESS_WITH_INFO for 22002 state + ret = SQLExtendedFetch(this->stmt, SQL_FETCH_NEXT, 0, &row_count1, row_status1); + EXPECT_EQ(ret, SQL_SUCCESS_WITH_INFO); + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_22002); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLMoreResultsNoData) { + // Verify SQLMoreResults is stubbed to return SQL_NO_DATA + this->connect(); + + std::wstring wsql = L"SELECT 1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLMoreResults(this->stmt); + + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLMoreResultsInvalidFunctionSequence) { + this->connect(); + + SQLRETURN ret = SQLMoreResults(this->stmt); + + // Verify function sequence error state is reported when SQLMoreResults is called + // without executing any queries + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_HY010); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLNativeSqlReturnsInputString) { + this->connect(); + + SQLWCHAR buf[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(buf) / ODBC::GetSqlWCharSize(); + SQLWCHAR inputStr[] = L"SELECT * FROM mytable WHERE id == 1"; + SQLINTEGER inputCharLen = static_cast(wcslen(inputStr)); + SQLINTEGER outputCharLen = 0; + std::wstring expectedString = std::wstring(inputStr); + + SQLRETURN ret = + SQLNativeSql(this->conn, inputStr, inputCharLen, buf, bufCharLen, &outputCharLen); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(outputCharLen, inputCharLen); + + // returned length is in characters + std::wstring returnedString(buf, buf + outputCharLen); + + EXPECT_EQ(returnedString, expectedString); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLNativeSqlReturnsNTSInputString) { + this->connect(); + + SQLWCHAR buf[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(buf) / ODBC::GetSqlWCharSize(); + SQLWCHAR inputStr[] = L"SELECT * FROM mytable WHERE id == 1"; + SQLINTEGER inputCharLen = static_cast(wcslen(inputStr)); + SQLINTEGER outputCharLen = 0; + std::wstring expectedString = std::wstring(inputStr); + + SQLRETURN ret = + SQLNativeSql(this->conn, inputStr, SQL_NTS, buf, bufCharLen, &outputCharLen); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(outputCharLen, inputCharLen); + + // returned length is in characters + std::wstring returnedString(buf, buf + outputCharLen); + + EXPECT_EQ(returnedString, expectedString); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLNativeSqlReturnsInputStringLength) { + this->connect(); + + SQLWCHAR inputStr[] = L"SELECT * FROM mytable WHERE id == 1"; + SQLINTEGER inputCharLen = static_cast(wcslen(inputStr)); + SQLINTEGER outputCharLen = 0; + std::wstring expectedString = std::wstring(inputStr); + + SQLRETURN ret = + SQLNativeSql(this->conn, inputStr, inputCharLen, nullptr, 0, &outputCharLen); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(outputCharLen, inputCharLen); + + ret = SQLNativeSql(this->conn, inputStr, SQL_NTS, nullptr, 0, &outputCharLen); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(outputCharLen, inputCharLen); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLNativeSqlReturnsTruncatedString) { + this->connect(); + + const SQLINTEGER smallBufSizeInChar = 11; + SQLWCHAR smallBuf[smallBufSizeInChar]; + constexpr SQLINTEGER smallBufCharLen = sizeof(smallBuf) / ODBC::GetSqlWCharSize(); + SQLWCHAR inputStr[] = L"SELECT * FROM mytable WHERE id == 1"; + SQLINTEGER inputCharLen = static_cast(wcslen(inputStr)); + SQLINTEGER outputCharLen = 0; + + // Create expected return string based on buf size + SQLWCHAR expectedStringBuf[smallBufSizeInChar]; + wcsncpy(expectedStringBuf, inputStr, 10); + expectedStringBuf[10] = L'\0'; + std::wstring expectedString(expectedStringBuf, expectedStringBuf + smallBufSizeInChar); + + SQLRETURN ret = SQLNativeSql(this->conn, inputStr, inputCharLen, smallBuf, + smallBufCharLen, &outputCharLen); + + EXPECT_EQ(ret, SQL_SUCCESS_WITH_INFO); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_01004); + + // Returned text length represents full string char length regardless of truncation + EXPECT_EQ(outputCharLen, inputCharLen); + + std::wstring returnedString(smallBuf, smallBuf + smallBufCharLen); + + EXPECT_EQ(returnedString, expectedString); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLNativeSqlReturnsErrorOnBadInputs) { + this->connect(); + + SQLWCHAR buf[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(buf) / ODBC::GetSqlWCharSize(); + SQLWCHAR inputStr[] = L"SELECT * FROM mytable WHERE id == 1"; + SQLINTEGER inputCharLen = static_cast(wcslen(inputStr)); + SQLINTEGER outputCharLen = 0; + + SQLRETURN ret = + SQLNativeSql(this->conn, nullptr, inputCharLen, buf, bufCharLen, &outputCharLen); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_HY009); + + ret = SQLNativeSql(this->conn, nullptr, SQL_NTS, buf, bufCharLen, &outputCharLen); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_HY009); + + ret = SQLNativeSql(this->conn, inputStr, -100, buf, bufCharLen, &outputCharLen); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, error_state_HY090); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, SQLNumResultColsReturnsColumnsOnSelect) { + this->connect(); + + SQLSMALLINT columnCount = 0; + SQLSMALLINT expectedValue = 3; + SQLWCHAR sqlQuery[] = L"SELECT 1 AS col1, 'One' AS col2, 3 AS col3"; + SQLINTEGER queryLength = static_cast(wcslen(sqlQuery)); + + SQLRETURN ret = SQLExecDirect(this->stmt, sqlQuery, queryLength); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_SUCCESS); + + CheckIntColumn(this->stmt, 1, 1); + CheckStringColumnW(this->stmt, 2, L"One"); + CheckIntColumn(this->stmt, 3, 3); + + ret = SQLNumResultCols(this->stmt, &columnCount); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(columnCount, expectedValue); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, SQLNumResultColsReturnsSuccessOnNullptr) { + this->connect(); + + SQLWCHAR sqlQuery[] = L"SELECT 1 AS col1, 'One' AS col2, 3 AS col3"; + SQLINTEGER queryLength = static_cast(wcslen(sqlQuery)); + + SQLRETURN ret = SQLExecDirect(this->stmt, sqlQuery, queryLength); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_SUCCESS); + + CheckIntColumn(this->stmt, 1, 1); + CheckStringColumnW(this->stmt, 2, L"One"); + CheckIntColumn(this->stmt, 3, 3); + + ret = SQLNumResultCols(this->stmt, nullptr); + + EXPECT_EQ(ret, SQL_SUCCESS); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, SQLNumResultColsFunctionSequenceErrorOnNoQuery) { + this->connect(); + + SQLSMALLINT columnCount = 0; + SQLSMALLINT expectedValue = 0; + + SQLRETURN ret = SQLNumResultCols(this->stmt, &columnCount); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_HY010); + + EXPECT_EQ(columnCount, expectedValue); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, SQLRowCountReturnsNegativeOneOnSelect) { + this->connect(); + + SQLLEN rowCount = 0; + SQLLEN expectedValue = -1; + SQLWCHAR sqlQuery[] = L"SELECT 1 AS col1, 'One' AS col2, 3 AS col3"; + SQLINTEGER queryLength = static_cast(wcslen(sqlQuery)); + + SQLRETURN ret = SQLExecDirect(this->stmt, sqlQuery, queryLength); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_SUCCESS); + + CheckIntColumn(this->stmt, 1, 1); + CheckStringColumnW(this->stmt, 2, L"One"); + CheckIntColumn(this->stmt, 3, 3); + + ret = SQLRowCount(this->stmt, &rowCount); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(rowCount, expectedValue); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, SQLRowCountReturnsSuccessOnNullptr) { + this->connect(); + + SQLWCHAR sqlQuery[] = L"SELECT 1 AS col1, 'One' AS col2, 3 AS col3"; + SQLINTEGER queryLength = static_cast(wcslen(sqlQuery)); + + SQLRETURN ret = SQLExecDirect(this->stmt, sqlQuery, queryLength); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_SUCCESS); + + CheckIntColumn(this->stmt, 1, 1); + CheckStringColumnW(this->stmt, 2, L"One"); + CheckIntColumn(this->stmt, 3, 3); + + ret = SQLRowCount(this->stmt, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, SQLRowCountFunctionSequenceErrorOnNoQuery) { + this->connect(); + + SQLLEN rowCount = 0; + SQLLEN expectedValue = 0; + + SQLRETURN ret = SQLRowCount(this->stmt, &rowCount); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_HY010); + + EXPECT_EQ(rowCount, expectedValue); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLFreeStmtSQLClose) { + this->connect(); + + std::wstring wsql = L"SELECT 1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFreeStmt(this->stmt, SQL_CLOSE); + + EXPECT_EQ(ret, SQL_SUCCESS); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLCloseCursor) { + this->connect(); + + std::wstring wsql = L"SELECT 1;"; + std::vector sql0(wsql.begin(), wsql.end()); + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLCloseCursor(this->stmt); + + EXPECT_EQ(ret, SQL_SUCCESS); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLFreeStmtSQLCloseWithoutCursor) { + // SQLFreeStmt(SQL_CLOSE) does not throw error with invalid cursor + this->connect(); + + SQLRETURN ret = SQLFreeStmt(this->stmt, SQL_CLOSE); + + EXPECT_EQ(ret, SQL_SUCCESS); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLCloseCursorWithoutCursor) { + this->connect(); + + SQLRETURN ret = SQLCloseCursor(this->stmt); + + EXPECT_EQ(ret, SQL_ERROR); + + // Verify invalid cursor error state is returned + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_24000); + + this->disconnect(); +} + +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/tables_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/tables_test.cc new file mode 100644 index 00000000000..a6cbd38f881 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/tables_test.cc @@ -0,0 +1,679 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#ifdef _WIN32 +# include +#endif + +#include +#include +#include + +#include "gtest/gtest.h" + +namespace arrow::flight::sql::odbc { + +// Helper Functions + +std::wstring GetStringColumnW(SQLHSTMT stmt, int colId) { + SQLWCHAR buf[1024]; + SQLLEN lenIndicator = 0; + + SQLRETURN ret = SQLGetData(stmt, colId, SQL_C_WCHAR, buf, sizeof(buf), &lenIndicator); + + EXPECT_EQ(ret, SQL_SUCCESS); + + if (lenIndicator == SQL_NULL_DATA) { + return L""; + } + + // indicator is in bytes, so convert to character count + size_t charCount = static_cast(lenIndicator) / ODBC::GetSqlWCharSize(); + return std::wstring(buf, buf + charCount); +} + +// Test Cases + +TYPED_TEST(FlightSQLODBCTestBase, SQLTablesTestInputData) { + this->connect(); + + SQLWCHAR catalogName[] = L""; + SQLWCHAR schemaName[] = L""; + SQLWCHAR tableName[] = L""; + SQLWCHAR tableType[] = L""; + + // All values populated + SQLRETURN ret = SQLTables(this->stmt, catalogName, sizeof(catalogName), schemaName, + sizeof(schemaName), tableName, sizeof(tableName), tableType, + sizeof(tableType)); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ValidateFetch(this->stmt, SQL_NO_DATA); + + // Sizes are nulls + ret = SQLTables(this->stmt, catalogName, 0, schemaName, 0, tableName, 0, tableType, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ValidateFetch(this->stmt, SQL_NO_DATA); + + // Values are nulls + ret = SQLTables(this->stmt, 0, sizeof(catalogName), 0, sizeof(schemaName), 0, + sizeof(tableName), 0, sizeof(tableType)); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ValidateFetch(this->stmt, SQL_SUCCESS); + // Close statement cursor to avoid leaving in an invalid state + SQLFreeStmt(this->stmt, SQL_CLOSE); + + // All values and sizes are nulls + ret = SQLTables(this->stmt, 0, 0, 0, 0, 0, 0, 0, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ValidateFetch(this->stmt, SQL_SUCCESS); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, SQLTablesTestGetMetadataForAllCatalogs) { + this->connect(); + + SQLWCHAR empty[] = L""; + SQLWCHAR SQL_ALL_CATALOGS_W[] = L"%"; + std::wstring expectedCatalogName = std::wstring(L"main"); + + // Get Catalog metadata + SQLRETURN ret = SQLTables(this->stmt, SQL_ALL_CATALOGS_W, SQL_NTS, empty, SQL_NTS, + empty, SQL_NTS, empty, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ValidateFetch(this->stmt, SQL_SUCCESS); + + CheckStringColumnW(this->stmt, 1, expectedCatalogName); + CheckNullColumnW(this->stmt, 2); + CheckNullColumnW(this->stmt, 3); + CheckNullColumnW(this->stmt, 4); + CheckNullColumnW(this->stmt, 5); + + ValidateFetch(this->stmt, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, SQLTablesTestGetMetadataForNamedCatalog) { + this->connect(); + this->CreateTestTables(); + + SQLWCHAR catalogName[] = L"main"; + SQLWCHAR* tableNames[] = {(SQLWCHAR*)L"TestTable", (SQLWCHAR*)L"foreignTable", + (SQLWCHAR*)L"intTable", (SQLWCHAR*)L"sqlite_sequence"}; + std::wstring expectedCatalogName = std::wstring(catalogName); + std::wstring expectedTableType = std::wstring(L"table"); + + // Get named Catalog metadata - Mock server returns the system table sqlite_sequence as + // type "table" + SQLRETURN ret = SQLTables(this->stmt, catalogName, SQL_NTS, nullptr, SQL_NTS, nullptr, + SQL_NTS, nullptr, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(tableNames) / sizeof(*tableNames); ++i) { + ValidateFetch(this->stmt, SQL_SUCCESS); + + CheckStringColumnW(this->stmt, 1, expectedCatalogName); + // Mock server does not support table schema + CheckNullColumnW(this->stmt, 2); + CheckStringColumnW(this->stmt, 3, tableNames[i]); + CheckStringColumnW(this->stmt, 4, expectedTableType); + CheckNullColumnW(this->stmt, 5); + } + + ValidateFetch(this->stmt, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, SQLTablesTestGetSchemaHasNoData) { + this->connect(); + + SQLWCHAR SQL_ALL_SCHEMAS_W[] = L"%"; + + // Validate that no schema data is available for Mock server + SQLRETURN ret = SQLTables(this->stmt, nullptr, SQL_NTS, SQL_ALL_SCHEMAS_W, SQL_NTS, + nullptr, SQL_NTS, nullptr, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ValidateFetch(this->stmt, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, SQLTablesTestGetMetadataForAllSchemas) { + this->connect(); + + SQLWCHAR empty[] = L""; + SQLWCHAR SQL_ALL_SCHEMAS_W[] = L"%"; + std::set actualSchemas; + std::set expectedSchemas = {L"$scratch", L"INFORMATION_SCHEMA", L"sys", + L"sys.cache"}; + + // Return is unordered and contains user specific schemas, so collect schema names for + // comparison with a known list + SQLRETURN ret = SQLTables(this->stmt, empty, SQL_NTS, SQL_ALL_SCHEMAS_W, SQL_NTS, empty, + SQL_NTS, empty, SQL_NTS); + + ASSERT_EQ(ret, SQL_SUCCESS); + + while (true) { + ret = SQLFetch(this->stmt); + if (ret == SQL_NO_DATA) break; + ASSERT_EQ(ret, SQL_SUCCESS); + + CheckNullColumnW(this->stmt, 1); + std::wstring schema = GetStringColumnW(this->stmt, 2); + CheckNullColumnW(this->stmt, 3); + CheckNullColumnW(this->stmt, 4); + CheckNullColumnW(this->stmt, 5); + + // Skip user-specific schemas like "@UserName" + if (!schema.empty() && schema[0] != L'@') { + actualSchemas.insert(schema); + } + } + + EXPECT_EQ(actualSchemas, expectedSchemas); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, SQLTablesTestFilterByAllSchema) { + // Requires creation of user table named ODBCTest using schema $scratch in remote server + this->connect(); + + SQLWCHAR SQL_ALL_SCHEMAS_W[] = L"%"; + SQLWCHAR* schemaNames[] = {(SQLWCHAR*)L"INFORMATION_SCHEMA", + (SQLWCHAR*)L"INFORMATION_SCHEMA", + (SQLWCHAR*)L"INFORMATION_SCHEMA", + (SQLWCHAR*)L"INFORMATION_SCHEMA", + (SQLWCHAR*)L"INFORMATION_SCHEMA", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys", + (SQLWCHAR*)L"sys.cache", + (SQLWCHAR*)L"sys.cache", + (SQLWCHAR*)L"sys.cache", + (SQLWCHAR*)L"sys.cache", + (SQLWCHAR*)L"$scratch"}; + std::wstring expectedSystemTableType = std::wstring(L"SYSTEM_TABLE"); + std::wstring expectedUserTableType = std::wstring(L"TABLE"); + + SQLRETURN ret = SQLTables(this->stmt, nullptr, SQL_NTS, SQL_ALL_SCHEMAS_W, SQL_NTS, + nullptr, SQL_NTS, nullptr, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(schemaNames) / sizeof(*schemaNames); ++i) { + ValidateFetch(this->stmt, SQL_SUCCESS); + + const std::wstring& expectedTableType = + (std::wstring(schemaNames[i]).rfind(L"sys", 0) == 0 || + std::wstring(schemaNames[i]) == L"INFORMATION_SCHEMA") + ? expectedSystemTableType + : expectedUserTableType; + + CheckNullColumnW(this->stmt, 1); + CheckStringColumnW(this->stmt, 2, schemaNames[i]); + // Ignore table name + CheckStringColumnW(this->stmt, 4, expectedTableType); + CheckNullColumnW(this->stmt, 5); + } + + ValidateFetch(this->stmt, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, SQLTablesGetMetadataForNamedSchema) { + // Requires creation of user table named ODBCTest using schema $scratch in remote server + this->connect(); + + SQLWCHAR schemaName[] = L"$scratch"; + std::wstring expectedSchemaName = std::wstring(schemaName); + std::wstring expectedTableName = std::wstring(L"ODBCTest"); + std::wstring expectedTableType = std::wstring(L"TABLE"); + + SQLRETURN ret = SQLTables(this->stmt, nullptr, SQL_NTS, schemaName, SQL_NTS, nullptr, + SQL_NTS, nullptr, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ValidateFetch(this->stmt, SQL_SUCCESS); + + CheckNullColumnW(this->stmt, 1); + CheckStringColumnW(this->stmt, 2, expectedSchemaName); + // Ignore table name + CheckStringColumnW(this->stmt, 4, expectedTableType); + CheckNullColumnW(this->stmt, 5); + + ValidateFetch(this->stmt, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, SQLTablesTestGetMetadataForAllTables) { + this->connect(); + this->CreateTestTables(); + + SQLWCHAR SQL_ALL_TABLES_W[] = L"%"; + SQLWCHAR* tableNames[] = {(SQLWCHAR*)L"TestTable", (SQLWCHAR*)L"foreignTable", + (SQLWCHAR*)L"intTable", (SQLWCHAR*)L"sqlite_sequence"}; + std::wstring expectedCatalogName = std::wstring(L"main"); + std::wstring expectedTableType = std::wstring(L"table"); + + // Get all Table metadata - Mock server returns the system table sqlite_sequence as type + // "table" + SQLRETURN ret = SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + SQL_ALL_TABLES_W, SQL_NTS, nullptr, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(tableNames) / sizeof(*tableNames); ++i) { + ValidateFetch(this->stmt, SQL_SUCCESS); + + CheckStringColumnW(this->stmt, 1, expectedCatalogName); + // Mock server does not support table schema + CheckNullColumnW(this->stmt, 2); + CheckStringColumnW(this->stmt, 3, tableNames[i]); + CheckStringColumnW(this->stmt, 4, expectedTableType); + CheckNullColumnW(this->stmt, 5); + } + + ValidateFetch(this->stmt, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, SQLTablesTestGetMetadataForTableName) { + this->connect(); + this->CreateTestTables(); + + SQLWCHAR* tableNames[] = {(SQLWCHAR*)L"TestTable", (SQLWCHAR*)L"foreignTable", + (SQLWCHAR*)L"intTable", (SQLWCHAR*)L"sqlite_sequence"}; + std::wstring expectedCatalogName = std::wstring(L"main"); + std::wstring expectedTableType = std::wstring(L"table"); + + for (size_t i = 0; i < sizeof(tableNames) / sizeof(*tableNames); ++i) { + // Get specific Table metadata + SQLRETURN ret = SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + tableNames[i], SQL_NTS, nullptr, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ValidateFetch(this->stmt, SQL_SUCCESS); + + CheckStringColumnW(this->stmt, 1, expectedCatalogName); + // Mock server does not support table schema + CheckNullColumnW(this->stmt, 2); + CheckStringColumnW(this->stmt, 3, tableNames[i]); + CheckStringColumnW(this->stmt, 4, expectedTableType); + CheckNullColumnW(this->stmt, 5); + + ValidateFetch(this->stmt, SQL_NO_DATA); + } + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, SQLTablesTestGetMetadataForUnicodeTableByTableName) { + this->connect(); + this->CreateUnicodeTable(); + + SQLWCHAR unicodeTableName[] = L"数据"; + std::wstring expectedCatalogName = std::wstring(L"main"); + std::wstring expectedTableName = std::wstring(unicodeTableName); + std::wstring expectedTableType = std::wstring(L"table"); + + // Get specific Table metadata + SQLRETURN ret = SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + unicodeTableName, SQL_NTS, nullptr, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ValidateFetch(this->stmt, SQL_SUCCESS); + + CheckStringColumnW(this->stmt, 1, expectedCatalogName); + // Mock server does not support table schema + CheckNullColumnW(this->stmt, 2); + CheckStringColumnW(this->stmt, 3, expectedTableName); + CheckStringColumnW(this->stmt, 4, expectedTableType); + CheckNullColumnW(this->stmt, 5); + + ValidateFetch(this->stmt, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, SQLTablesTestGetMetadataForInvalidTableNameNoData) { + this->connect(); + this->CreateTestTables(); + + SQLWCHAR invalidTableName[] = L"NonExistantTableName"; + + // Try to get metadata for a non-existant table name + SQLRETURN ret = SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, + invalidTableName, SQL_NTS, nullptr, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ValidateFetch(this->stmt, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, SQLTablesGetMetadataForTableType) { + // Mock server only supports table type "table" in lowercase + this->connect(); + this->CreateTestTables(); + + SQLWCHAR tableTypeTableLowercase[] = L"table"; + SQLWCHAR tableTypeTableUppercase[] = L"TABLE"; + SQLWCHAR tableTypeView[] = L"VIEW"; + SQLWCHAR tableTypeTableView[] = L"TABLE,VIEW"; + SQLWCHAR* tableNames[] = {(SQLWCHAR*)L"TestTable", (SQLWCHAR*)L"foreignTable", + (SQLWCHAR*)L"intTable", (SQLWCHAR*)L"sqlite_sequence"}; + std::wstring expectedCatalogName = std::wstring(L"main"); + std::wstring expectedTableName = std::wstring(L"TestTable"); + std::wstring expectedTableType = std::wstring(tableTypeTableLowercase); + SQLRETURN ret = SQL_SUCCESS; + + ret = SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, nullptr, SQL_NTS, + tableTypeTableUppercase, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ValidateFetch(this->stmt, SQL_NO_DATA); + + ret = SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, nullptr, SQL_NTS, + tableTypeView, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ValidateFetch(this->stmt, SQL_NO_DATA); + + ret = SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, nullptr, SQL_NTS, + tableTypeTableView, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ValidateFetch(this->stmt, SQL_NO_DATA); + + // Returns user table as well as system tables, even though only type table requested + ret = SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, nullptr, SQL_NTS, + tableTypeTableLowercase, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(tableNames) / sizeof(*tableNames); ++i) { + ValidateFetch(this->stmt, SQL_SUCCESS); + + CheckStringColumnW(this->stmt, 1, expectedCatalogName); + // Mock server does not support table schema + CheckNullColumnW(this->stmt, 2); + CheckStringColumnW(this->stmt, 3, tableNames[i]); + CheckStringColumnW(this->stmt, 4, expectedTableType); + CheckNullColumnW(this->stmt, 5); + } + + ValidateFetch(this->stmt, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, SQLTablesGetMetadataForTableTypeTable) { + // Requires creation of user table named ODBCTest using schema $scratch in remote server + this->connect(); + + SQLWCHAR* typeList[] = {(SQLWCHAR*)L"TABLE", (SQLWCHAR*)L"TABLE,VIEW"}; + std::wstring expectedSchemaName = std::wstring(L"$scratch"); + std::wstring expectedTableName = std::wstring(L"ODBCTest"); + std::wstring expectedTableType = std::wstring(L"TABLE"); + SQLRETURN ret = SQL_SUCCESS; + + for (size_t i = 0; i < sizeof(typeList) / sizeof(*typeList); ++i) { + ret = SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, nullptr, SQL_NTS, + typeList[i], SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ValidateFetch(this->stmt, SQL_SUCCESS); + + CheckNullColumnW(this->stmt, 1); + CheckStringColumnW(this->stmt, 2, expectedSchemaName); + CheckStringColumnW(this->stmt, 3, expectedTableName); + CheckStringColumnW(this->stmt, 4, expectedTableType); + CheckNullColumnW(this->stmt, 5); + + ValidateFetch(this->stmt, SQL_NO_DATA); + } + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, SQLTablesGetMetadataForTableTypeViewHasNoData) { + this->connect(); + + SQLWCHAR empty[] = L""; + SQLWCHAR typeView[] = L"VIEW"; + + SQLRETURN ret = SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, empty, + SQL_NTS, typeView, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ValidateFetch(this->stmt, SQL_NO_DATA); + + ret = SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, nullptr, SQL_NTS, + typeView, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ValidateFetch(this->stmt, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, SQLTablesGetSupportedTableTypes) { + this->connect(); + + SQLWCHAR empty[] = L""; + SQLWCHAR SQL_ALL_TABLE_TYPES_W[] = L"%"; + std::wstring expectedTableType = std::wstring(L"table"); + + // Mock server returns lower case for supported type of "table" + SQLRETURN ret = SQLTables(this->stmt, empty, SQL_NTS, empty, SQL_NTS, empty, SQL_NTS, + SQL_ALL_TABLE_TYPES_W, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ValidateFetch(this->stmt, SQL_SUCCESS); + + CheckNullColumnW(this->stmt, 1); + CheckNullColumnW(this->stmt, 2); + CheckNullColumnW(this->stmt, 3); + CheckStringColumnW(this->stmt, 4, expectedTableType); + CheckNullColumnW(this->stmt, 5); + + ValidateFetch(this->stmt, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, SQLTablesGetSupportedTableTypes) { + this->connect(); + + SQLWCHAR empty[] = L""; + SQLWCHAR SQL_ALL_TABLE_TYPES_W[] = L"%"; + SQLWCHAR* typeLists[] = {(SQLWCHAR*)L"TABLE", (SQLWCHAR*)L"SYSTEM_TABLE", + (SQLWCHAR*)L"VIEW"}; + + SQLRETURN ret = SQLTables(this->stmt, empty, SQL_NTS, empty, SQL_NTS, empty, SQL_NTS, + SQL_ALL_TABLE_TYPES_W, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(typeLists) / sizeof(*typeLists); ++i) { + ValidateFetch(this->stmt, SQL_SUCCESS); + + CheckNullColumnW(this->stmt, 1); + CheckNullColumnW(this->stmt, 2); + CheckNullColumnW(this->stmt, 3); + CheckStringColumnW(this->stmt, 4, typeLists[i]); + CheckNullColumnW(this->stmt, 5); + } + + ValidateFetch(this->stmt, SQL_NO_DATA); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, SQLTablesGetMetadataBySQLDescribeCol) { + this->connect(); + + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT columnDataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + size_t columnIndex = 0; + + SQLWCHAR* columnNames[] = {(SQLWCHAR*)L"TABLE_CAT", (SQLWCHAR*)L"TABLE_SCHEM", + (SQLWCHAR*)L"TABLE_NAME", (SQLWCHAR*)L"TABLE_TYPE", + (SQLWCHAR*)L"REMARKS"}; + SQLSMALLINT columnDataTypes[] = {SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR}; + SQLULEN columnSizes[] = {1024, 1024, 1024, 1024, 1024}; + + SQLRETURN ret = SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, nullptr, + SQL_NTS, nullptr, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(columnNames) / sizeof(*columnNames); ++i) { + columnIndex = i + 1; + + ret = SQLDescribeCol(this->stmt, columnIndex, columnName, bufCharLen, &nameLength, + &columnDataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(nameLength, wcslen(columnNames[i])); + + std::wstring returned(columnName, columnName + nameLength); + EXPECT_EQ(returned, columnNames[i]); + EXPECT_EQ(columnDataType, columnDataTypes[i]); + EXPECT_EQ(columnSize, columnSizes[i]); + EXPECT_EQ(decimalDigits, 0); + EXPECT_EQ(nullable, SQL_NULLABLE); + + nameLength = 0; + columnDataType = 0; + columnSize = 0; + decimalDigits = 0; + nullable = 0; + } + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, SQLTablesGetMetadataBySQLDescribeColODBC2) { + this->connect(SQL_OV_ODBC2); + + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT columnDataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + size_t columnIndex = 0; + + SQLWCHAR* columnNames[] = {(SQLWCHAR*)L"TABLE_QUALIFIER", (SQLWCHAR*)L"TABLE_OWNER", + (SQLWCHAR*)L"TABLE_NAME", (SQLWCHAR*)L"TABLE_TYPE", + (SQLWCHAR*)L"REMARKS"}; + SQLSMALLINT columnDataTypes[] = {SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR}; + SQLULEN columnSizes[] = {1024, 1024, 1024, 1024, 1024}; + + SQLRETURN ret = SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, nullptr, + SQL_NTS, nullptr, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(columnNames) / sizeof(*columnNames); ++i) { + columnIndex = i + 1; + + ret = SQLDescribeCol(this->stmt, columnIndex, columnName, bufCharLen, &nameLength, + &columnDataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(nameLength, wcslen(columnNames[i])); + + std::wstring returned(columnName, columnName + nameLength); + EXPECT_EQ(returned, columnNames[i]); + EXPECT_EQ(columnDataType, columnDataTypes[i]); + EXPECT_EQ(columnSize, columnSizes[i]); + EXPECT_EQ(decimalDigits, 0); + EXPECT_EQ(nullable, SQL_NULLABLE); + + nameLength = 0; + columnDataType = 0; + columnSize = 0; + decimalDigits = 0; + nullable = 0; + } + + this->disconnect(); +} +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/type_info_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/type_info_test.cc new file mode 100644 index 00000000000..29d737c38d6 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/type_info_test.cc @@ -0,0 +1,2100 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#ifdef _WIN32 +# include +#endif + +#include +#include +#include + +#include "gtest/gtest.h" + +namespace arrow::flight::sql::odbc { + +using std::optional; + +void checkSQLDescribeCol(SQLHSTMT stmt, const SQLUSMALLINT columnIndex, + const std::wstring& expectedName, + const SQLSMALLINT& expectedDataType, + const SQLULEN& expectedColumnSize, + const SQLSMALLINT& expectedDecimalDigits, + const SQLSMALLINT& expectedNullable) { + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT columnDataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + + SQLRETURN ret = SQLDescribeCol(stmt, columnIndex, columnName, bufCharLen, &nameLength, + &columnDataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_EQ(nameLength, expectedName.size()); + + std::wstring returned(columnName, columnName + nameLength); + EXPECT_EQ(returned, expectedName); + EXPECT_EQ(columnDataType, expectedDataType); + EXPECT_EQ(columnSize, expectedColumnSize); + EXPECT_EQ(decimalDigits, expectedDecimalDigits); + EXPECT_EQ(nullable, expectedNullable); +} + +void checkSQLDescribeColODBC2(SQLHSTMT stmt) { + SQLWCHAR* columnNames[] = {(SQLWCHAR*)L"TYPE_NAME", + (SQLWCHAR*)L"DATA_TYPE", + (SQLWCHAR*)L"PRECISION", + (SQLWCHAR*)L"LITERAL_PREFIX", + (SQLWCHAR*)L"LITERAL_SUFFIX", + (SQLWCHAR*)L"CREATE_PARAMS", + (SQLWCHAR*)L"NULLABLE", + (SQLWCHAR*)L"CASE_SENSITIVE", + (SQLWCHAR*)L"SEARCHABLE", + (SQLWCHAR*)L"UNSIGNED_ATTRIBUTE", + (SQLWCHAR*)L"MONEY", + (SQLWCHAR*)L"AUTO_INCREMENT", + (SQLWCHAR*)L"LOCAL_TYPE_NAME", + (SQLWCHAR*)L"MINIMUM_SCALE", + (SQLWCHAR*)L"MAXIMUM_SCALE", + (SQLWCHAR*)L"SQL_DATA_TYPE", + (SQLWCHAR*)L"SQL_DATETIME_SUB", + (SQLWCHAR*)L"NUM_PREC_RADIX", + (SQLWCHAR*)L"INTERVAL_PRECISION"}; + SQLSMALLINT columnDataTypes[] = {SQL_WVARCHAR, SQL_SMALLINT, SQL_INTEGER, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR, SQL_SMALLINT, SQL_SMALLINT, + SQL_SMALLINT, SQL_SMALLINT, SQL_SMALLINT, SQL_SMALLINT, + SQL_WVARCHAR, SQL_SMALLINT, SQL_SMALLINT, SQL_SMALLINT, + SQL_SMALLINT, SQL_INTEGER, SQL_SMALLINT}; + SQLULEN columnSizes[] = {1024, 2, 4, 1024, 1024, 1024, 2, 2, 2, 2, + 2, 2, 1024, 2, 2, 2, 2, 4, 2}; + SQLSMALLINT columnDecimalDigits[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + SQLSMALLINT columnNullable[] = {SQL_NO_NULLS, SQL_NO_NULLS, SQL_NULLABLE, SQL_NULLABLE, + SQL_NULLABLE, SQL_NULLABLE, SQL_NO_NULLS, SQL_NO_NULLS, + SQL_NO_NULLS, SQL_NULLABLE, SQL_NO_NULLS, SQL_NULLABLE, + SQL_NULLABLE, SQL_NULLABLE, SQL_NULLABLE, SQL_NO_NULLS, + SQL_NULLABLE, SQL_NULLABLE, SQL_NULLABLE}; + + for (size_t i = 0; i < sizeof(columnNames) / sizeof(*columnNames); ++i) { + SQLUSMALLINT columnIndex = i + 1; + checkSQLDescribeCol(stmt, columnIndex, columnNames[i], columnDataTypes[i], + columnSizes[i], columnDecimalDigits[i], columnNullable[i]); + } +} + +void checkSQLDescribeColODBC3(SQLHSTMT stmt) { + SQLWCHAR* columnNames[] = { + (SQLWCHAR*)L"TYPE_NAME", (SQLWCHAR*)L"DATA_TYPE", + (SQLWCHAR*)L"COLUMN_SIZE", (SQLWCHAR*)L"LITERAL_PREFIX", + (SQLWCHAR*)L"LITERAL_SUFFIX", (SQLWCHAR*)L"CREATE_PARAMS", + (SQLWCHAR*)L"NULLABLE", (SQLWCHAR*)L"CASE_SENSITIVE", + (SQLWCHAR*)L"SEARCHABLE", (SQLWCHAR*)L"UNSIGNED_ATTRIBUTE", + (SQLWCHAR*)L"FIXED_PREC_SCALE", (SQLWCHAR*)L"AUTO_UNIQUE_VALUE", + (SQLWCHAR*)L"LOCAL_TYPE_NAME", (SQLWCHAR*)L"MINIMUM_SCALE", + (SQLWCHAR*)L"MAXIMUM_SCALE", (SQLWCHAR*)L"SQL_DATA_TYPE", + (SQLWCHAR*)L"SQL_DATETIME_SUB", (SQLWCHAR*)L"NUM_PREC_RADIX", + (SQLWCHAR*)L"INTERVAL_PRECISION"}; + SQLSMALLINT columnDataTypes[] = {SQL_WVARCHAR, SQL_SMALLINT, SQL_INTEGER, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR, SQL_SMALLINT, SQL_SMALLINT, + SQL_SMALLINT, SQL_SMALLINT, SQL_SMALLINT, SQL_SMALLINT, + SQL_WVARCHAR, SQL_SMALLINT, SQL_SMALLINT, SQL_SMALLINT, + SQL_SMALLINT, SQL_INTEGER, SQL_SMALLINT}; + SQLULEN columnSizes[] = {1024, 2, 4, 1024, 1024, 1024, 2, 2, 2, 2, + 2, 2, 1024, 2, 2, 2, 2, 4, 2}; + SQLSMALLINT columnDecimalDigits[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + SQLSMALLINT columnNullable[] = {SQL_NO_NULLS, SQL_NO_NULLS, SQL_NULLABLE, SQL_NULLABLE, + SQL_NULLABLE, SQL_NULLABLE, SQL_NO_NULLS, SQL_NO_NULLS, + SQL_NO_NULLS, SQL_NULLABLE, SQL_NO_NULLS, SQL_NULLABLE, + SQL_NULLABLE, SQL_NULLABLE, SQL_NULLABLE, SQL_NO_NULLS, + SQL_NULLABLE, SQL_NULLABLE, SQL_NULLABLE}; + + for (size_t i = 0; i < sizeof(columnNames) / sizeof(*columnNames); ++i) { + SQLUSMALLINT columnIndex = i + 1; + checkSQLDescribeCol(stmt, columnIndex, columnNames[i], columnDataTypes[i], + columnSizes[i], columnDecimalDigits[i], columnNullable[i]); + } +} + +void checkSQLGetTypeInfo( + SQLHSTMT stmt, const std::wstring& expectedTypeName, + const SQLSMALLINT& expectedDataType, const SQLINTEGER& expectedColumnSize, + const optional& expectedLiteralPrefix, + const optional& expectedLiteralSuffix, + const optional& expectedCreateParams, + const SQLSMALLINT& expectedNullable, const SQLSMALLINT& expectedCaseSensitive, + const SQLSMALLINT& expectedSearchable, const SQLSMALLINT& expectedUnsignedAttr, + const SQLSMALLINT& expectedFixedPrecScale, const SQLSMALLINT& expectedAutoUniqueValue, + const std::wstring& expectedLocalTypeName, const SQLSMALLINT& expectedMinScale, + const SQLSMALLINT& expectedMaxScale, const SQLSMALLINT& expectedSqlDataType, + const SQLSMALLINT& expectedSqlDatetimeSub, const SQLINTEGER& expectedNumPrecRadix, + const SQLINTEGER& expectedIntervalPrec) { + CheckStringColumnW(stmt, 1, expectedTypeName); // type name + CheckSmallIntColumn(stmt, 2, expectedDataType); // data type + CheckIntColumn(stmt, 3, expectedColumnSize); // column size + + if (expectedLiteralPrefix) { // literal prefix + CheckStringColumnW(stmt, 4, *expectedLiteralPrefix); + } else { + CheckNullColumnW(stmt, 4); + } + + if (expectedLiteralSuffix) { // literal suffix + CheckStringColumnW(stmt, 5, *expectedLiteralSuffix); + } else { + CheckNullColumnW(stmt, 5); + } + + if (expectedCreateParams) { // create params + CheckStringColumnW(stmt, 6, *expectedCreateParams); + } else { + CheckNullColumnW(stmt, 6); + } + + CheckSmallIntColumn(stmt, 7, expectedNullable); // nullable + CheckSmallIntColumn(stmt, 8, expectedCaseSensitive); // case sensitive + CheckSmallIntColumn(stmt, 9, expectedSearchable); // searchable + CheckSmallIntColumn(stmt, 10, expectedUnsignedAttr); // unsigned attr + CheckSmallIntColumn(stmt, 11, expectedFixedPrecScale); // fixed prec scale + CheckSmallIntColumn(stmt, 12, expectedAutoUniqueValue); // auto unique value + CheckStringColumnW(stmt, 13, expectedLocalTypeName); // local type name + CheckSmallIntColumn(stmt, 14, expectedMinScale); // min scale + CheckSmallIntColumn(stmt, 15, expectedMaxScale); // max scale + CheckSmallIntColumn(stmt, 16, expectedSqlDataType); // sql data type + CheckSmallIntColumn(stmt, 17, expectedSqlDatetimeSub); // sql datetime sub + CheckIntColumn(stmt, 18, expectedNumPrecRadix); // num prec radix + CheckIntColumn(stmt, 19, expectedIntervalPrec); // interval prec +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoAllTypes) { + this->connect(); + + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_ALL_TYPES); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check bit data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"bit"), // expectedTypeName + SQL_BIT, // expectedDataType + 1, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + NULL, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"bit"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_BIT, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // Check tinyint data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"tinyint"), // expectedTypeName + SQL_TINYINT, // expectedDataType + 3, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"tinyint"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_TINYINT, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // Check bigint data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"bigint"), // expectedTypeName + SQL_BIGINT, // expectedDataType + 19, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"bigint"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_BIGINT, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // Check longvarbinary data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"longvarbinary"), // expectedTypeName + SQL_LONGVARBINARY, // expectedDataType + 65536, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + NULL, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"longvarbinary"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_LONGVARBINARY, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // Check varbinary data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"varbinary"), // expectedTypeName + SQL_VARBINARY, // expectedDataType + 255, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + NULL, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"varbinary"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_VARBINARY, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // Check text data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Driver returns SQL_WLONGVARCHAR since unicode is enabled + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"text"), // expectedTypeName + SQL_WLONGVARCHAR, // expectedDataType + 65536, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::wstring(L"length"), // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + NULL, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"text"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_WLONGVARCHAR, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // Check longvarchar data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"longvarchar"), // expectedTypeName + SQL_WLONGVARCHAR, // expectedDataType + 65536, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::wstring(L"length"), // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + NULL, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"longvarchar"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_WLONGVARCHAR, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // Check char data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Driver returns SQL_WCHAR since unicode is enabled + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"char"), // expectedTypeName + SQL_WCHAR, // expectedDataType + 255, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::wstring(L"length"), // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + NULL, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"char"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_WCHAR, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // Check integer data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"integer"), // expectedTypeName + SQL_INTEGER, // expectedDataType + 9, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"integer"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_INTEGER, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // Check smallint data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"smallint"), // expectedTypeName + SQL_SMALLINT, // expectedDataType + 5, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"smallint"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_SMALLINT, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // Check float data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"float"), // expectedTypeName + SQL_FLOAT, // expectedDataType + 7, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"float"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_FLOAT, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // Check double data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"double"), // expectedTypeName + SQL_DOUBLE, // expectedDataType + 15, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"double"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_DOUBLE, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // Check numeric data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Mock server treats numeric data type as a double type + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"numeric"), // expectedTypeName + SQL_DOUBLE, // expectedDataType + 15, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"numeric"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_DOUBLE, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // Check varchar data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Driver returns SQL_WVARCHAR since unicode is enabled + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"varchar"), // expectedTypeName + SQL_WVARCHAR, // expectedDataType + 255, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::wstring(L"length"), // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"varchar"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_WVARCHAR, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // Check date data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"date"), // expectedTypeName + SQL_TYPE_DATE, // expectedDataType + 10, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"date"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_DATETIME, // expectedSqlDataType + SQL_CODE_DATE, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // Check time data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"time"), // expectedTypeName + SQL_TYPE_TIME, // expectedDataType + 8, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"time"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_DATETIME, // expectedSqlDataType + SQL_CODE_TIME, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // Check timestamp data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"timestamp"), // expectedTypeName + SQL_TYPE_TIMESTAMP, // expectedDataType + 32, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"timestamp"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_DATETIME, // expectedSqlDataType + SQL_CODE_TIMESTAMP, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoAllTypesODBCVer2) { + this->connect(SQL_OV_ODBC2); + + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_ALL_TYPES); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check bit data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"bit"), // expectedTypeName + SQL_BIT, // expectedDataType + 1, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + NULL, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"bit"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_BIT, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC2(this->stmt); + + // Check tinyint data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"tinyint"), // expectedTypeName + SQL_TINYINT, // expectedDataType + 3, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"tinyint"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_TINYINT, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC2(this->stmt); + + // Check bigint data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"bigint"), // expectedTypeName + SQL_BIGINT, // expectedDataType + 19, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"bigint"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_BIGINT, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC2(this->stmt); + + // Check longvarbinary data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"longvarbinary"), // expectedTypeName + SQL_LONGVARBINARY, // expectedDataType + 65536, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + NULL, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"longvarbinary"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_LONGVARBINARY, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC2(this->stmt); + + // Check varbinary data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"varbinary"), // expectedTypeName + SQL_VARBINARY, // expectedDataType + 255, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + NULL, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"varbinary"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_VARBINARY, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC2(this->stmt); + + // Check text data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Driver returns SQL_WLONGVARCHAR since unicode is enabled + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"text"), // expectedTypeName + SQL_WLONGVARCHAR, // expectedDataType + 65536, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::wstring(L"length"), // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + NULL, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"text"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_WLONGVARCHAR, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC2(this->stmt); + + // Check longvarchar data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"longvarchar"), // expectedTypeName + SQL_WLONGVARCHAR, // expectedDataType + 65536, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::wstring(L"length"), // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + NULL, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"longvarchar"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_WLONGVARCHAR, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC2(this->stmt); + + // Check char data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Driver returns SQL_WCHAR since unicode is enabled + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"char"), // expectedTypeName + SQL_WCHAR, // expectedDataType + 255, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::wstring(L"length"), // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + NULL, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"char"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_WCHAR, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC2(this->stmt); + + // Check integer data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"integer"), // expectedTypeName + SQL_INTEGER, // expectedDataType + 9, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"integer"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_INTEGER, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC2(this->stmt); + + // Check smallint data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"smallint"), // expectedTypeName + SQL_SMALLINT, // expectedDataType + 5, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"smallint"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_SMALLINT, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC2(this->stmt); + + // Check float data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"float"), // expectedTypeName + SQL_FLOAT, // expectedDataType + 7, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"float"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_FLOAT, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC2(this->stmt); + + // Check double data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"double"), // expectedTypeName + SQL_DOUBLE, // expectedDataType + 15, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"double"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_DOUBLE, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC2(this->stmt); + + // Check numeric data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Mock server treats numeric data type as a double type + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"numeric"), // expectedTypeName + SQL_DOUBLE, // expectedDataType + 15, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"numeric"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_DOUBLE, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC2(this->stmt); + + // Check varchar data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Driver returns SQL_WVARCHAR since unicode is enabled + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"varchar"), // expectedTypeName + SQL_WVARCHAR, // expectedDataType + 255, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::wstring(L"length"), // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"varchar"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_WVARCHAR, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC2(this->stmt); + + // Check date data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"date"), // expectedTypeName + SQL_DATE, // expectedDataType + 10, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"date"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_DATETIME, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub, driver returns NULL for Ver2 + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC2(this->stmt); + + // Check time data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"time"), // expectedTypeName + SQL_TIME, // expectedDataType + 8, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"time"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_DATETIME, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub, driver returns NULL for Ver2 + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC2(this->stmt); + + // Check timestamp data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"timestamp"), // expectedTypeName + SQL_TIMESTAMP, // expectedDataType + 32, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"timestamp"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_DATETIME, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub, driver returns NULL for Ver2 + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC2(this->stmt); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoBit) { + this->connect(); + + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_BIT); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check bit data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"bit"), // expectedTypeName + SQL_BIT, // expectedDataType + 1, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + NULL, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"bit"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_BIT, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // No more data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoTinyInt) { + this->connect(); + + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_TINYINT); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check tinyint data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"tinyint"), // expectedTypeName + SQL_TINYINT, // expectedDataType + 3, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"tinyint"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_TINYINT, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // No more data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoBigInt) { + this->connect(); + + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_BIGINT); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check bigint data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"bigint"), // expectedTypeName + SQL_BIGINT, // expectedDataType + 19, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"bigint"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_BIGINT, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // No more data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoLongVarbinary) { + this->connect(); + + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_LONGVARBINARY); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check longvarbinary data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"longvarbinary"), // expectedTypeName + SQL_LONGVARBINARY, // expectedDataType + 65536, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + NULL, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"longvarbinary"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_LONGVARBINARY, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // No more data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoVarbinary) { + this->connect(); + + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_VARBINARY); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check varbinary data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"varbinary"), // expectedTypeName + SQL_VARBINARY, // expectedDataType + 255, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + NULL, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"varbinary"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_VARBINARY, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + // No more data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoLongVarchar) { + this->connect(); + + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_WLONGVARCHAR); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check text data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Driver returns SQL_WLONGVARCHAR since unicode is enabled + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"text"), // expectedTypeName + SQL_WLONGVARCHAR, // expectedDataType + 65536, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::wstring(L"length"), // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + NULL, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"text"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_WLONGVARCHAR, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // Check longvarchar data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"longvarchar"), // expectedTypeName + SQL_WLONGVARCHAR, // expectedDataType + 65536, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::wstring(L"length"), // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + NULL, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"longvarchar"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_WLONGVARCHAR, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // No more data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoChar) { + this->connect(); + + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_WCHAR); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check char data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Driver returns SQL_WCHAR since unicode is enabled + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"char"), // expectedTypeName + SQL_WCHAR, // expectedDataType + 255, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::wstring(L"length"), // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + NULL, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"char"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_WCHAR, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // No more data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoInteger) { + this->connect(); + + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_INTEGER); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check integer data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"integer"), // expectedTypeName + SQL_INTEGER, // expectedDataType + 9, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"integer"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_INTEGER, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // No more data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoSmallInt) { + this->connect(); + + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_SMALLINT); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check smallint data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"smallint"), // expectedTypeName + SQL_SMALLINT, // expectedDataType + 5, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"smallint"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_SMALLINT, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // No more data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoFloat) { + this->connect(); + + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_FLOAT); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check float data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"float"), // expectedTypeName + SQL_FLOAT, // expectedDataType + 7, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"float"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_FLOAT, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // No more data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoDouble) { + this->connect(); + + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_DOUBLE); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check double data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"double"), // expectedTypeName + SQL_DOUBLE, // expectedDataType + 15, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"double"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_DOUBLE, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // Check numeric data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Mock server treats numeric data type as a double type + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"numeric"), // expectedTypeName + SQL_DOUBLE, // expectedDataType + 15, // expectedColumnSize + std::nullopt, // expectedLiteralPrefix + std::nullopt, // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"numeric"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_DOUBLE, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // No more data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoVarchar) { + this->connect(); + + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_WVARCHAR); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check varchar data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Driver returns SQL_WVARCHAR since unicode is enabled + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"varchar"), // expectedTypeName + SQL_WVARCHAR, // expectedDataType + 255, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::wstring(L"length"), // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"varchar"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_WVARCHAR, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // No more data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoSQLTypeDate) { + this->connect(); + + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_TYPE_DATE); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check date data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"date"), // expectedTypeName + SQL_TYPE_DATE, // expectedDataType + 10, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"date"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_DATETIME, // expectedSqlDataType + SQL_CODE_DATE, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // No more data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoSQLDate) { + this->connect(); + + // Pass ODBC Ver 2 data type + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_DATE); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check date data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"date"), // expectedTypeName + SQL_TYPE_DATE, // expectedDataType + 10, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"date"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_DATETIME, // expectedSqlDataType + SQL_CODE_DATE, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // No more data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoDateODBCVer2) { + this->connect(SQL_OV_ODBC2); + + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_DATE); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check date data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"date"), // expectedTypeName + SQL_DATE, // expectedDataType + 10, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"date"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_DATETIME, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub, driver returns NULL for Ver2 + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC2(this->stmt); + + // No more data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoSQLTypeDateODBCVer2) { + this->connect(SQL_OV_ODBC2); + + // Pass ODBC Ver 3 data type + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_TYPE_DATE); + + EXPECT_EQ(ret, SQL_ERROR); + + // Driver manager returns SQL data type out of range error state + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_S1004); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoSQLTypeTime) { + this->connect(); + + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_TYPE_TIME); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check time data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"time"), // expectedTypeName + SQL_TYPE_TIME, // expectedDataType + 8, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"time"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_DATETIME, // expectedSqlDataType + SQL_CODE_TIME, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // No more data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoSQLTime) { + this->connect(); + + // Pass ODBC Ver 2 data type + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_TIME); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check time data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"time"), // expectedTypeName + SQL_TYPE_TIME, // expectedDataType + 8, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"time"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_DATETIME, // expectedSqlDataType + SQL_CODE_TIME, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // No more data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoTimeODBCVer2) { + this->connect(SQL_OV_ODBC2); + + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_TIME); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check time data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"time"), // expectedTypeName + SQL_TIME, // expectedDataType + 8, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"time"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_DATETIME, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub, driver returns NULL for Ver2 + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC2(this->stmt); + + // No more data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoSQLTypeTimeODBCVer2) { + this->connect(SQL_OV_ODBC2); + + // Pass ODBC Ver 3 data type + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_TYPE_TIME); + + EXPECT_EQ(ret, SQL_ERROR); + + // Driver manager returns SQL data type out of range error state + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_S1004); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoSQLTypeTimestamp) { + this->connect(); + + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_TYPE_TIMESTAMP); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check timestamp data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"timestamp"), // expectedTypeName + SQL_TYPE_TIMESTAMP, // expectedDataType + 32, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"timestamp"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_DATETIME, // expectedSqlDataType + SQL_CODE_TIMESTAMP, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // No more data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoSQLTimestamp) { + this->connect(); + + // Pass ODBC Ver 2 data type + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_TIMESTAMP); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check timestamp data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"timestamp"), // expectedTypeName + SQL_TYPE_TIMESTAMP, // expectedDataType + 32, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"timestamp"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_DATETIME, // expectedSqlDataType + SQL_CODE_TIMESTAMP, // expectedSqlDatetimeSub + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC3(this->stmt); + + // No more data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoSQLTimestampODBCVer2) { + this->connect(SQL_OV_ODBC2); + + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_TIMESTAMP); + EXPECT_EQ(ret, SQL_SUCCESS); + + // Check timestamp data type + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_SUCCESS); + + checkSQLGetTypeInfo(this->stmt, + std::wstring(L"timestamp"), // expectedTypeName + SQL_TIMESTAMP, // expectedDataType + 32, // expectedColumnSize + std::wstring(L"'"), // expectedLiteralPrefix + std::wstring(L"'"), // expectedLiteralSuffix + std::nullopt, // expectedCreateParams + SQL_NULLABLE, // expectedNullable + SQL_FALSE, // expectedCaseSensitive + SQL_SEARCHABLE, // expectedSearchable + SQL_FALSE, // expectedUnsignedAttr + SQL_FALSE, // expectedFixedPrecScale + NULL, // expectedAutoUniqueValue + std::wstring(L"timestamp"), // expectedLocalTypeName + NULL, // expectedMinScale + NULL, // expectedMaxScale + SQL_DATETIME, // expectedSqlDataType + NULL, // expectedSqlDatetimeSub, driver returns NULL for Ver2 + NULL, // expectedNumPrecRadix + NULL); // expectedIntervalPrec + + checkSQLDescribeColODBC2(this->stmt); + + // No more data + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoSQLTypeTimestampODBCVer2) { + this->connect(SQL_OV_ODBC2); + + // Pass ODBC Ver 3 data type + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_TYPE_TIMESTAMP); + + EXPECT_EQ(ret, SQL_ERROR); + + // Driver manager returns SQL data type out of range error state + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_S1004); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, TestSQLGetTypeInfoInvalidDataType) { + this->connect(); + + SQLSMALLINT invalidDataType = -114; + SQLRETURN ret = SQLGetTypeInfo(this->stmt, invalidDataType); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_HY004); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetTypeInfoUnsupportedDataType) { + // Assumes mock and remote server don't support GUID data type + this->connect(); + + SQLRETURN ret = SQLGetTypeInfo(this->stmt, SQL_GUID); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Result set is empty with valid data type that is unsupported by the server + ret = SQLFetch(this->stmt); + EXPECT_EQ(ret, SQL_NO_DATA); + + this->disconnect(); +} + +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/visibility.h b/cpp/src/arrow/flight/sql/odbc/visibility.h new file mode 100644 index 00000000000..416dfecc864 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/visibility.h @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#if defined(_WIN32) || defined(__CYGWIN__) +# if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable : 4251) +# else +# pragma GCC diagnostic ignored "-Wattributes" +# endif + +# ifdef ARROW_FLIGHT_SQL_ODBC_STATIC +# define ARROW_FLIGHT_SQL_ODBC_EXPORT +# elif defined(ARROW_FLIGHT_SQL_ODBC_EXPORTING) +# define ARROW_FLIGHT_SQL_ODBC_EXPORT __declspec(dllexport) +# else +# define ARROW_FLIGHT_SQL_ODBC_EXPORT __declspec(dllimport) +# endif + +# define ARROW_FLIGHT_SQL_ODBC_NO_EXPORT +#else // Not Windows +# ifndef ARROW_FLIGHT_SQL_ODBC_EXPORT +# define ARROW_FLIGHT_SQL_ODBC_EXPORT __attribute__((visibility("default"))) +# endif +# ifndef ARROW_FLIGHT_SQL_ODBC_NO_EXPORT +# define ARROW_FLIGHT_SQL_ODBC_NO_EXPORT __attribute__((visibility("hidden"))) +# endif +#endif // Non-Windows + +#if defined(_MSC_VER) +# pragma warning(pop) +#endif diff --git a/testing b/testing index fbf6b703dc9..d2a13712303 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit fbf6b703dc93d17d75fa3664c5aa2c7873ebaf06 +Subproject commit d2a13712303498963395318a4eb42872e66aead7