diff --git a/.clang-format b/.clang-format index 525f5600..5846e195 100644 --- a/.clang-format +++ b/.clang-format @@ -28,7 +28,7 @@ SpacesInAngles: 'false' SpacesInContainerLiterals: 'false' SpacesInParentheses: 'false' SpacesInSquareBrackets: 'false' -Standard: c++17 +Standard: c++20 UseTab: Never SortIncludes: true ColumnLimit: 100 diff --git a/.drone.jsonnet b/.drone.jsonnet index 5bf25d51..e45eab33 100644 --- a/.drone.jsonnet +++ b/.drone.jsonnet @@ -10,7 +10,7 @@ local submodules = { local apt_get_quiet = 'apt-get -o=Dpkg::Use-Pty=0 -q'; -local libngtcp2_deps = ['libgnutls28-dev', 'libprotobuf-dev']; +local libngtcp2_deps = ['libgnutls28-dev', 'libprotobuf-dev', 'libngtcp2-dev', 'libngtcp2-crypto-gnutls-dev']; local default_deps_nocxx = [ 'nlohmann-json3-dev', @@ -375,12 +375,14 @@ local static_build(name, 'libsession-util-windows-x64-TAG.zip', deps=['g++-mingw-w64-x86-64-posix'], cmake_extra='-DCMAKE_CXX_FLAGS=-fdiagnostics-color=always -DCMAKE_TOOLCHAIN_FILE=../cmake/mingw-x86-64-toolchain.cmake'), + /* currently broken: static_build('Static Windows x86', docker_base + 'debian-win32-cross', 'libsession-util-windows-x86-TAG.zip', deps=['g++-mingw-w64-i686-posix'], allow_fail=true, cmake_extra='-DCMAKE_CXX_FLAGS=-fdiagnostics-color=always -DCMAKE_TOOLCHAIN_FILE=../cmake/mingw-i686-toolchain.cmake'), + */ debian_pipeline( 'Static Android', docker_base + 'android', diff --git a/.gitignore b/.gitignore index 1004c092..1fe18bd5 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,4 @@ /compile_commands.json /.cache/ /.vscode/ -*.DS_Store \ No newline at end of file +.DS_STORE diff --git a/.gitmodules b/.gitmodules index 22049b33..d772e270 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,3 @@ -[submodule "external/oxen-encoding"] - path = external/oxen-encoding - url = https://github.com/session-foundation/oxen-encoding.git [submodule "external/libsodium-internal"] path = external/libsodium-internal url = https://github.com/session-foundation/libsodium-internal.git @@ -16,6 +13,9 @@ [submodule "external/nlohmann-json"] path = external/nlohmann-json url = https://github.com/nlohmann/json.git +[submodule "external/oxen-libquic"] + path = external/oxen-libquic + url = https://github.com/oxen-io/oxen-libquic.git [submodule "external/protobuf"] path = external/protobuf url = https://github.com/protocolbuffers/protobuf.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 1026503a..b491704b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 3.14...3.23) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) # Has to be set before `project()`, and ignored on non-macos: -set(CMAKE_OSX_DEPLOYMENT_TARGET 10.13 CACHE STRING "macOS deployment target (Apple clang only)") +set(CMAKE_OSX_DEPLOYMENT_TARGET 10.15 CACHE STRING "macOS deployment target (Apple clang only)") set(LANGS C CXX) find_program(CCACHE_PROGRAM ccache) @@ -16,7 +16,6 @@ if(CCACHE_PROGRAM) endforeach() endif() - project(libsession-util VERSION 1.2.0 DESCRIPTION "Session client utility library" @@ -41,7 +40,7 @@ else() endif() -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) @@ -81,7 +80,7 @@ option(STATIC_LIBSTD "Statically link libstdc++/libgcc" ${default_static_libstd} option(USE_LTO "Use Link-Time Optimization" ${use_lto_default}) -# Provide this as an option for now because GMP and iOS are sometimes unhappy with each other. 
+# Provide this as an option for now because GMP and Desktop are sometimes unhappy with each other. option(ENABLE_ONIONREQ "Build with onion request functionality" ON) if(USE_LTO) @@ -118,6 +117,21 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON) add_subdirectory(external) + +if(ENABLE_ONIONREQ) + if(NOT TARGET nettle::nettle) + if(BUILD_STATIC_DEPS) + message(FATAL_ERROR "Internal error: nettle::nettle target (expected via libquic BUILD_STATIC_DEPS) not found") + else() + find_package(PkgConfig REQUIRED) + pkg_check_modules(NETTLE REQUIRED IMPORTED_TARGET nettle) + add_library(nettle INTERFACE) + target_link_libraries(nettle INTERFACE PkgConfig::NETTLE) + add_library(nettle::nettle ALIAS nettle) + endif() + endif() +endif() + add_subdirectory(src) add_subdirectory(proto) diff --git a/cmake/AddStaticBundleLib.cmake b/cmake/AddStaticBundleLib.cmake index 8c482c47..2e37c0b6 100644 --- a/cmake/AddStaticBundleLib.cmake +++ b/cmake/AddStaticBundleLib.cmake @@ -1,6 +1,11 @@ set(LIBSESSION_STATIC_BUNDLE_LIBS "" CACHE INTERNAL "list of libs to go into the static bundle lib") +function(_libsession_static_bundle_append tgt) + list(APPEND LIBSESSION_STATIC_BUNDLE_LIBS "${tgt}") + set(LIBSESSION_STATIC_BUNDLE_LIBS "${LIBSESSION_STATIC_BUNDLE_LIBS}" CACHE INTERNAL "") +endfunction() + # Call as: # # libsession_static_bundle(target [target2 ...]) @@ -8,7 +13,24 @@ set(LIBSESSION_STATIC_BUNDLE_LIBS "" CACHE INTERNAL "list of libs to go into the # to append the given target(s) to the list of libraries that will be combined to make the static # bundled libsession-util.a. function(libsession_static_bundle) - list(APPEND LIBSESSION_STATIC_BUNDLE_LIBS "${ARGN}") - list(REMOVE_DUPLICATES LIBSESSION_STATIC_BUNDLE_LIBS) - set(LIBSESSION_STATIC_BUNDLE_LIBS "${LIBSESSION_STATIC_BUNDLE_LIBS}" CACHE INTERNAL "") + foreach(tgt IN LISTS ARGN) + if(TARGET "${tgt}" AND NOT "${tgt}" IN_LIST LIBSESSION_STATIC_BUNDLE_LIBS) + get_target_property(tgt_type ${tgt} TYPE) + + if(tgt_type STREQUAL STATIC_LIBRARY) + message(STATUS "Adding ${tgt} to libsession-util bundled library list") + _libsession_static_bundle_append("${tgt}") + endif() + + if(tgt_type STREQUAL INTERFACE_LIBRARY) + get_target_property(tgt_link_deps ${tgt} INTERFACE_LINK_LIBRARIES) + else() + get_target_property(tgt_link_deps ${tgt} LINK_LIBRARIES) + endif() + + if(tgt_link_deps) + libsession_static_bundle(${tgt_link_deps}) + endif() + endif() + endforeach() endfunction() diff --git a/cmake/GenVersion.cmake b/cmake/GenVersion.cmake index 20c4a81a..3ac32ff5 100644 --- a/cmake/GenVersion.cmake +++ b/cmake/GenVersion.cmake @@ -41,7 +41,7 @@ else() OUTPUT_VARIABLE git_tag OUTPUT_STRIP_TRAILING_WHITESPACE) - if(git_tag) + if(git_tag AND git_tag MATCHES "^v[0-9]+\\.[0-9]+\\.[0-9]+$") message(STATUS "${git_commit} is tagged (${git_tag}); tagging version as 'release'") set(vfull "v${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}.${PROJECT_VERSION_PATCH}") set(PROJECT_VERSION_TAG "release") @@ -49,6 +49,9 @@ else() if (NOT git_tag STREQUAL "${vfull}") message(FATAL_ERROR "This commit is tagged, but the tag (${git_tag}) does not match the project version (${vfull})!") endif() + elseif(git_tag) + message(WARNING "Did not recognize git tag (${git_tag}) for ${git_commit}; tagging with commit hash") + set(PROJECT_VERSION_TAG "${git_commit}") else() message(STATUS "Did not find a git tag for ${git_commit}; tagging version with the commit hash") set(PROJECT_VERSION_TAG "${git_commit}") diff --git a/cmake/StaticBuild.cmake b/cmake/StaticBuild.cmake index 6435e5f0..49863f8c 100644 
--- a/cmake/StaticBuild.cmake +++ b/cmake/StaticBuild.cmake @@ -5,21 +5,6 @@ set(LOCAL_MIRROR "" CACHE STRING "local mirror path/URL for lib downloads") -set(GMP_VERSION 6.3.0 CACHE STRING "gmp version") -set(GMP_MIRROR ${LOCAL_MIRROR} https://gmplib.org/download/gmp - CACHE STRING "gmp mirror(s)") -set(GMP_SOURCE gmp-${GMP_VERSION}.tar.xz) -set(GMP_HASH SHA512=e85a0dab5195889948a3462189f0e0598d331d3457612e2d3350799dba2e244316d256f8161df5219538eb003e4b5343f989aaa00f96321559063ed8c8f29fd2 - CACHE STRING "gmp source hash") - -set(NETTLE_VERSION 3.9.1 CACHE STRING "nettle version") -set(NETTLE_MIRROR ${LOCAL_MIRROR} https://ftp.gnu.org/gnu/nettle - CACHE STRING "nettle mirror(s)") -set(NETTLE_SOURCE nettle-${NETTLE_VERSION}.tar.gz) -set(NETTLE_HASH SHA512=5939c4b43cf9ff6c6272245b85f123c81f8f4e37089fa4f39a00a570016d837f6e706a33226e4bbfc531b02a55b2756ff312461225ed88de338a73069e031ced - CACHE STRING "nettle source hash") - - include(ExternalProject) set(DEPS_DESTDIR ${CMAKE_BINARY_DIR}/static-deps) @@ -230,32 +215,6 @@ elseif(gmp_build_host STREQUAL "") set(gmp_build_host "--build=${CMAKE_LIBRARY_ARCHITECTURE}") endif() -if(ENABLE_ONIONREQ) - build_external(gmp - CONFIGURE_COMMAND ./configure ${gmp_build_host} --disable-shared --prefix=${DEPS_DESTDIR} --with-pic - "CC=${deps_cc}" "CXX=${deps_cxx}" "CFLAGS=${deps_CFLAGS}${apple_cflags_arch}" "CXXFLAGS=${deps_CXXFLAGS}${apple_cxxflags_arch}" - "LDFLAGS=${apple_ldflags_arch}" ${cross_rc} CC_FOR_BUILD=cc CPP_FOR_BUILD=cpp - ) - add_static_target(gmp gmp_external libgmp.a) - - build_external(nettle - CONFIGURE_COMMAND ./configure ${gmp_build_host} --disable-shared --prefix=${DEPS_DESTDIR} --libdir=${DEPS_DESTDIR}/lib - --with-pic --disable-openssl - "CC=${deps_cc}" "CXX=${deps_cxx}" - "CFLAGS=${deps_CFLAGS}${apple_cflags_arch}" "CXXFLAGS=${deps_CXXFLAGS}${apple_cxxflags_arch}" - "CPPFLAGS=-I${DEPS_DESTDIR}/include" - "LDFLAGS=-L${DEPS_DESTDIR}/lib${apple_ldflags_arch}" - - DEPENDS gmp_external - BUILD_BYPRODUCTS - ${DEPS_DESTDIR}/lib/libnettle.a - ${DEPS_DESTDIR}/lib/libhogweed.a - ${DEPS_DESTDIR}/include/nettle/version.h - ) - add_static_target(nettle nettle_external libnettle.a gmp) - add_static_target(hogweed nettle_external libhogweed.a nettle) -endif() - link_libraries(-static-libstdc++) if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") link_libraries(-static-libgcc) diff --git a/external/CMakeLists.txt b/external/CMakeLists.txt index 0327de9d..c9c8b160 100644 --- a/external/CMakeLists.txt +++ b/external/CMakeLists.txt @@ -26,7 +26,7 @@ if(SUBMODULE_CHECK) message(STATUS "Checking submodules") check_submodule(ios-cmake) check_submodule(libsodium-internal) - check_submodule(oxen-encoding) + check_submodule(oxen-libquic external/oxen-logging external/oxen-encoding) check_submodule(nlohmann-json) check_submodule(zstd) check_submodule(protobuf) @@ -37,7 +37,7 @@ if(NOT BUILD_STATIC_DEPS AND NOT FORCE_ALL_SUBMODULES) find_package(PkgConfig REQUIRED) endif() -macro(system_or_submodule BIGNAME smallname pkgconf subdir) +macro(libsession_system_or_submodule BIGNAME smallname pkgconf subdir) option(FORCE_${BIGNAME}_SUBMODULE "force using ${smallname} submodule" OFF) if(NOT BUILD_STATIC_DEPS AND NOT FORCE_${BIGNAME}_SUBMODULE AND NOT FORCE_ALL_SUBMODULES) pkg_check_modules(${BIGNAME} ${pkgconf} IMPORTED_TARGET GLOBAL) @@ -57,6 +57,9 @@ macro(system_or_submodule BIGNAME smallname pkgconf subdir) if(TARGET ${smallname} AND NOT TARGET ${smallname}::${smallname}) add_library(${smallname}::${smallname} ALIAS ${smallname}) endif() + if(BUILD_STATIC_DEPS AND 
STATIC_BUNDLE) + libsession_static_bundle(${smallname}::${smallname}) + endif() endmacro() @@ -65,11 +68,6 @@ set(cross_host "") set(cross_rc "") if(CMAKE_CROSSCOMPILING) if(APPLE_TARGET_TRIPLE) - if(PLATFORM MATCHES "OS64" OR PLATFORM MATCHES "SIMULATORARM64") - set(APPLE_TARGET_TRIPLE aarch64-apple-ios) - elseif(PLATFORM MATCHES "SIMULATOR64") - set(APPLE_TARGET_TRIPLE x86_64-apple-ios) - endif() set(cross_host "--host=${APPLE_TARGET_TRIPLE}") elseif(ANDROID) if(CMAKE_ANDROID_ARCH_ABI MATCHES x86_64) @@ -102,8 +100,31 @@ if(CMAKE_CROSSCOMPILING) endif() endif() +set(LIBQUIC_BUILD_TESTS OFF CACHE BOOL "") +if(ENABLE_ONIONREQ) + libsession_system_or_submodule(OXENQUIC quic liboxenquic>=1.1.0 oxen-libquic) +endif() + +if(NOT TARGET oxenc::oxenc) + # The oxenc target will already exist if we load libquic above via submodule + set(OXENC_BUILD_TESTS OFF CACHE BOOL "") + set(OXENC_BUILD_DOCS OFF CACHE BOOL "") + libsession_system_or_submodule(OXENC oxenc liboxenc>=1.1.0 oxen-libquic/external/oxen-encoding) +endif() -system_or_submodule(OXENC oxenc liboxenc>=1.0.10 oxen-encoding) +if(NOT TARGET oxen::logging) + add_subdirectory(oxen-libquic/external/oxen-logging) +endif() + +oxen_logging_add_source_dir("${PROJECT_SOURCE_DIR}") + +# Apple xcode 15 has a completely broken std::source_location; we can't fix it, but at least we can +# hack up the source locations to hide the path that it uses (which is the useless path to +# oxen/log.hpp where the info/critical/etc. bodies are). +if(APPLE AND CMAKE_CXX_COMPILER_ID STREQUAL AppleClang AND NOT CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 16) + message(WARNING "${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION} is broken: filenames in logging statements will not display properly") + oxen_logging_add_source_dir("${CMAKE_CURRENT_SOURCE_DIR}/oxen-libquic/external/oxen-logging/include/oxen") +endif() if(CMAKE_C_COMPILER_LAUNCHER) @@ -143,7 +164,7 @@ set(protobuf_BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) set(protobuf_ABSL_PROVIDER "module" CACHE STRING "" FORCE) set(protobuf_BUILD_PROTOC_BINARIES OFF CACHE BOOL "") set(protobuf_BUILD_PROTOBUF_BINARIES ON CACHE BOOL "" FORCE) -system_or_submodule(PROTOBUF_LITE protobuf_lite protobuf-lite>=3.21 protobuf) +libsession_system_or_submodule(PROTOBUF_LITE protobuf_lite protobuf-lite>=3.21 protobuf) if(TARGET PkgConfig::PROTOBUF_LITE AND NOT TARGET protobuf::libprotobuf-lite) add_library(protobuf::libprotobuf-lite ALIAS PkgConfig::PROTOBUF_LITE) endif() @@ -172,4 +193,4 @@ libsession_static_bundle(libzstd_static) set(JSON_BuildTests OFF CACHE INTERNAL "") set(JSON_Install ON CACHE INTERNAL "") # Required to export targets that we use -system_or_submodule(NLOHMANN nlohmann_json nlohmann_json>=3.7.0 nlohmann-json) +libsession_system_or_submodule(NLOHMANN nlohmann_json nlohmann_json>=3.7.0 nlohmann-json) diff --git a/external/oxen-encoding b/external/oxen-encoding deleted file mode 160000 index a7de6375..00000000 --- a/external/oxen-encoding +++ /dev/null @@ -1 +0,0 @@ -Subproject commit a7de63756dcc5c31cb899a4b810e6434b1a7c01c diff --git a/external/oxen-libquic b/external/oxen-libquic new file mode 160000 index 00000000..79d3f89b --- /dev/null +++ b/external/oxen-libquic @@ -0,0 +1 @@ +Subproject commit 79d3f89b9260880904c831bf8bebafc916ca7f7d diff --git a/include/session/config.hpp b/include/session/config.hpp index 1c6cddb4..eaf418e5 100644 --- a/include/session/config.hpp +++ b/include/session/config.hpp @@ -95,10 +95,10 @@ class ConfigMessage { std::optional> verified_signature_; // This will be set 
during construction from configs based on the merge result: - // -1 means we had to merge one or more configs together into a new merged config - // >= 0 indicates the index of the config we used if we did not merge (i.e. there was only one + // nullopt means we had to merge one or more configs together into a new merged config + // If set to a value then the value is the index of the config we used (i.e. there was only one // config, or there were multiple but one of them referenced all the others). - int unmerged_ = -1; + std::optional unmerged_; public: constexpr static int DEFAULT_DIFF_LAGS = 5; @@ -203,13 +203,13 @@ class ConfigMessage { /// After loading multiple config files this flag indicates whether or not we had to produce a /// new, merged configuration message (true) or did not need to merge (false). (For config /// messages that were not loaded from serialized data this is always true). - bool merged() const { return unmerged_ == -1; } + bool merged() const { return !unmerged_; } /// After loading multiple config files this field contains the index of the single config we /// used if we didn't need to merge (that is: there was only one config or one config that /// superceded all the others). If we had to merge (or this wasn't loaded from serialized - /// data), this will return -1. - int unmerged_index() const { return unmerged_; } + /// data), this will return std::nullopt. + const std::optional& unmerged_index() const { return unmerged_; } /// Read-only access to the optional verified signature if this message contained a valid, /// verified signature when it was parsed. Returns nullopt otherwise (e.g. not loaded from @@ -364,8 +364,6 @@ class MutableConfigMessage : public ConfigMessage { /// - `dict` -- a `bt_dict_consumer` positioned at or before the "~" key where the signature is /// expected. (If the bt_dict_consumer has already consumed the "~" key then this call will fail /// as if the signature was missing). -/// - `config_msg` -- the full config message; this must be a view of the same data in memory that -/// `dict` is parsing (i.e. it cannot be a copy). /// - `verifier` -- a callback to invoke to verify the signature of the message. If the callback is /// empty then the signature will be ignored (it is neither required nor verified). /// - `verified_signature` is a pointer to a std::optional array of signature data; if this is @@ -381,7 +379,6 @@ class MutableConfigMessage : public ConfigMessage { /// - throws on failure void verify_config_sig( oxenc::bt_dict_consumer dict, - ustring_view config_msg, const ConfigMessage::verify_callable& verifier, std::optional>* verified_signature = nullptr, bool trust_signature = false); diff --git a/include/session/config/base.h b/include/session/config/base.h index dbb9801f..0fa722a3 100644 --- a/include/session/config/base.h +++ b/include/session/config/base.h @@ -42,44 +42,6 @@ typedef struct config_object { /// - `conf` -- [in] Pointer to config_object object LIBSESSION_EXPORT void config_free(config_object* conf); -typedef enum config_log_level { - LOG_LEVEL_DEBUG = 0, - LOG_LEVEL_INFO, - LOG_LEVEL_WARNING, - LOG_LEVEL_ERROR -} config_log_level; - -/// API: base/config_set_logger -/// -/// Sets a logging function; takes the log function pointer and a context pointer (which can be NULL -/// if not needed). 
The given function pointer will be invoked with one of the above values, a -/// null-terminated c string containing the log message, and the void* context object given when -/// setting the logger (this is for caller-specific state data and won't be touched). -/// -/// The logging function must have signature: -/// -/// void log(config_log_level lvl, const char* msg, void* ctx); -/// -/// Can be called with callback set to NULL to clear an existing logger. -/// -/// The config object itself has no log level: the caller should filter by level as needed. -/// -/// Declaration: -/// ```cpp -/// VOID config_set_logger( -/// [in, out] config_object* conf, -/// [in] void(*)(config_log_level, const char*, void*) callback, -/// [in] void* ctx -/// ); -/// ``` -/// -/// Inputs: -/// - `conf` -- [in] Pointer to config_object object -/// - `callback` -- [in] Callback function -/// - `ctx` --- [in, optional] Pointer to an optional context. Set to NULL if unused -LIBSESSION_EXPORT void config_set_logger( - config_object* conf, void (*callback)(config_log_level, const char*, void*), void* ctx); - /// API: base/config_storage_namespace /// /// Returns the numeric namespace in which config messages of this type should be stored. @@ -242,7 +204,10 @@ LIBSESSION_EXPORT void config_confirm_pushed( /// - `conf` -- [in] Pointer to config_object object /// - `out` -- [out] Pointer to the output location /// - `outlen` -- [out] Length of output -LIBSESSION_EXPORT void config_dump(config_object* conf, unsigned char** out, size_t* outlen); +/// +/// Output: +/// - `bool` -- Returns true if the call succeeds, false if an error occurs. +LIBSESSION_EXPORT bool config_dump(config_object* conf, unsigned char** out, size_t* outlen); /// API: base/config_needs_dump /// @@ -288,6 +253,33 @@ LIBSESSION_EXPORT bool config_needs_dump(const config_object* conf); LIBSESSION_EXPORT config_string_list* config_current_hashes(const config_object* conf) LIBSESSION_WARN_UNUSED; +/// API: base/config_old_hashes +/// +/// Obtains the known old hashes. Note that this will be empty if there are no old hashes or +/// the config is in a dirty state (in which case these should be retrieved via the `push` +/// function). Calling this function, or the `push` function, will clear the stored old_hashes. +/// +/// The returned pointer belongs to the caller and must be freed via `free()` when done with it. +/// +/// Declaration: +/// ```cpp +/// CONFIG_STRING_LIST* config_old_hashes( +/// [in] const config_object* conf +/// ); +/// +/// ``` +/// +/// Inputs: +/// - `conf` -- [in] Pointer to config_object object +/// +/// Outputs: +/// - `config_string_list*` -- pointer to the list of hashes; the pointer belongs to the caller +LIBSESSION_EXPORT config_string_list* config_old_hashes(config_object* conf) +#ifdef __GNUC__ + __attribute__((warn_unused_result)) +#endif + ; + /// API: base/config_get_keys /// /// Obtains the current group decryption keys. 
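The `bool` return on `config_dump` and the new `config_old_hashes` getter are meant to be used together in client persistence code. A minimal sketch (hypothetical helper; assumes the usual `value`/`len` members of `config_string_list` and `malloc`-allocated outputs, as with the other list-returning calls):

```cpp
// Hypothetical persistence step after a merge/push cycle: dump the updated state and
// delete any server messages whose hashes are now obsolete.  The config_string_list
// value/len fields are assumed from the existing list type used by config_current_hashes.
#include <cstddef>
#include <cstdlib>
#include "session/config/base.h"

bool persist_and_prune(config_object* conf) {
    unsigned char* dump = nullptr;
    size_t dumplen = 0;
    if (!config_dump(conf, &dump, &dumplen))  // now returns false on failure
        return false;                         // conf->last_error describes the problem
    // ... write [dump, dump + dumplen) to the client database ...
    std::free(dump);

    config_string_list* old = config_old_hashes(conf);  // caller owns; clears stored hashes
    if (old) {
        for (size_t i = 0; i < old->len; i++) {
            // ... delete message old->value[i] from the swarm ...
        }
        std::free(old);
    }
    return true;
}
```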
@@ -322,7 +314,7 @@ LIBSESSION_EXPORT unsigned char* config_get_keys(const config_object* conf, size /// /// Declaration: /// ```cpp -/// VOID config_add_key( +/// BOOL config_add_key( /// [in, out] config_object* conf, /// [in] const unsigned char* key /// ); @@ -332,7 +324,10 @@ LIBSESSION_EXPORT unsigned char* config_get_keys(const config_object* conf, size /// Inputs: /// - `conf` -- [in, out] Pointer to config_object object /// - `key` -- [in] Pointer to the binary key object, must be 32 bytes -LIBSESSION_EXPORT void config_add_key(config_object* conf, const unsigned char* key); +/// +/// Output: +/// - `bool` -- Returns true if the call succeeds, false if an error occurs. +LIBSESSION_EXPORT bool config_add_key(config_object* conf, const unsigned char* key); /// API: base/config_add_key_low_prio /// @@ -341,7 +336,7 @@ LIBSESSION_EXPORT void config_add_key(config_object* conf, const unsigned char* /// /// Declaration: /// ```cpp -/// VOID config_add_key_low_prio( +/// BOOL config_add_key_low_prio( /// [in, out] config_object* conf, /// [in] const unsigned char* key /// ); @@ -351,7 +346,10 @@ LIBSESSION_EXPORT void config_add_key(config_object* conf, const unsigned char* /// Inputs: /// - `conf` -- [in, out] Pointer to config_object object /// - `key` -- [in] Pointer to the binary key object, must be 32 bytes -LIBSESSION_EXPORT void config_add_key_low_prio(config_object* conf, const unsigned char* key); +/// +/// Output: +/// - `bool` -- Returns true if the call succeeds, false if an error occurs. +LIBSESSION_EXPORT bool config_add_key_low_prio(config_object* conf, const unsigned char* key); /// API: base/config_clear_keys /// @@ -497,7 +495,10 @@ LIBSESSION_EXPORT const char* config_encryption_domain(const config_object* conf /// Inputs: /// - `secret` -- pointer to a 64-byte sodium-style Ed25519 "secret key" buffer (technically the /// seed+precomputed pubkey concatenated together) that sets both the secret key and public key. -LIBSESSION_EXPORT void config_set_sig_keys(config_object* conf, const unsigned char* secret); +/// +/// Output: +/// - `bool` -- Returns true if the call succeeds, false if an error occurs. +LIBSESSION_EXPORT bool config_set_sig_keys(config_object* conf, const unsigned char* secret); /// API: base/config_set_sig_pubkey /// @@ -507,7 +508,10 @@ LIBSESSION_EXPORT void config_set_sig_keys(config_object* conf, const unsigned c /// /// Inputs: /// - `pubkey` -- pointer to the 32-byte Ed25519 pubkey that must have signed incoming messages. -LIBSESSION_EXPORT void config_set_sig_pubkey(config_object* conf, const unsigned char* pubkey); +/// +/// Output: +/// - `bool` -- Returns true if the call succeeds, false if an error occurs. +LIBSESSION_EXPORT bool config_set_sig_pubkey(config_object* conf, const unsigned char* pubkey); /// API: base/config_get_sig_pubkey /// diff --git a/include/session/config/base.hpp b/include/session/config/base.hpp index 7601e3ca..d674201a 100644 --- a/include/session/config/base.hpp +++ b/include/session/config/base.hpp @@ -10,6 +10,7 @@ #include #include +#include "../logging.hpp" #include "../sodium_array.hpp" #include "base.h" #include "namespaces.hpp" @@ -33,9 +34,6 @@ template static constexpr bool is_dict_value = is_dict_subtype || is_one_of; -// Levels for the logging callback -enum class LogLevel { debug = 0, info, warning, error }; - /// Our current config state enum class ConfigState : int { /// Clean means the config is confirmed stored on the server and we haven't changed anything. 
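All of these `void` to `bool` conversions follow the same pattern: instead of swallowing a C++ exception, the wrapper reports failure through the return value and leaves the message reachable via `last_error` (see the `wrap_exceptions` helper later in this diff). A hedged sketch of the intended calling pattern, with illustrative names:

```cpp
// Error-handling pattern for the bool-returning C setters: on failure the message is
// available via conf->last_error.  Function and variable names here are illustrative.
#include <cstdio>
#include "session/config/base.h"

bool install_group_key(config_object* conf, const unsigned char* key /* 32 bytes */) {
    if (!config_add_key(conf, key)) {
        std::fprintf(stderr, "config_add_key failed: %s\n",
                     conf->last_error ? conf->last_error : "(unknown error)");
        return false;
    }
    return true;
}
```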
@@ -165,7 +163,7 @@ class ConfigBase : public ConfigSig { std::string _curr_hash; // Contains obsolete known message hashes that are obsoleted by the most recent merge or push; - // these are returned (and cleared) when `push` is called. + // these are returned (and cleared) when `push` or `old_hashes` is called. std::unordered_set<std::string> _old_hashes; protected: @@ -204,12 +202,6 @@ class ConfigBase : public ConfigSig { // deleted at the next push. void set_state(ConfigState s); - // Invokes the `logger` callback if set, does nothing if there is no logger. - void log(LogLevel lvl, std::string msg) { - if (logger) - logger(lvl, std::move(msg)); - } - // Returns a reference to the current MutableConfigMessage. If the current message is not // already dirty (i.e. Clean or Waiting) then calling this increments the seqno counter. MutableConfigMessage& dirty(); @@ -851,9 +843,6 @@ class ConfigBase : public ConfigSig { // Proxy class providing read and write access to the contained config data. const DictFieldRoot data{*this}; - // If set then we log things by calling this callback - std::function<void(LogLevel lvl, std::string msg)> logger; - /// API: base/ConfigBase::storage_namespace /// /// Accesses the storage namespace where this config type is to be stored/loaded from. See @@ -1018,6 +1007,18 @@ class ConfigBase : public ConfigSig { /// - `std::vector<std::string>` -- Returns current config hashes std::vector<std::string> current_hashes() const; + /// API: base/ConfigBase::old_hashes + /// + /// The old config hash(es); this can be empty if there are no old hashes or if the config is in + /// a dirty state (in which case these should be retrieved via the `push` function). Calling + /// this function or the `push` function will clear the stored old_hashes. + /// + /// Inputs: None + /// + /// Outputs: + /// - `std::vector<std::string>` -- Returns old config hashes + std::vector<std::string> old_hashes(); + /// API: base/ConfigBase::needs_push /// /// Returns true if this object contains updated data that has not yet been confirmed stored on @@ -1290,24 +1291,47 @@ inline const internals& unbox(const config_object* conf) { return *static_cast*>(conf->internals); } -// Sets an error message in the internals.error string and updates the last_error pointer in the -// outer (C) config_object struct to point at it. -void set_error(config_object* conf, std::string e); - -// Same as above, but gets the error string out of an exception and passed through a return value. -// Intended to simplify catch-and-return-error such as: -// try { -// whatever(); -// } catch (const std::exception& e) { -// return set_error(conf, LIB_SESSION_ERR_OHNOES, e); -// } -inline int set_error(config_object* conf, int errcode, const std::exception& e) { - set_error(conf, e.what()); - return errcode; +template <size_t N> +void copy_c_str(char (&dest)[N], std::string_view src) { + if (src.size() >= N) + src.remove_suffix(src.size() - (N - 1)); + std::memcpy(dest, src.data(), src.size()); + dest[src.size()] = 0; } -// Copies a value contained in a string into a new malloced char buffer, returning the buffer and -// size via the two pointer arguments. -void copy_out(ustring_view data, unsigned char** out, size_t* outlen); +// Wraps a lambda and, if an exception is thrown, sets an error message in the internals.error +// string and updates the last_error pointer in the outer (C) config_object struct to point at it.
+// +// No return value: accepts void and pointer returns; pointer returns will become nullptr on error +template <typename Call> +decltype(auto) wrap_exceptions(config_object* conf, Call&& f) { + using Ret = std::invoke_result_t<Call>; + + try { + conf->last_error = nullptr; + return std::invoke(std::forward<Call>(f)); + } catch (const std::exception& e) { + copy_c_str(conf->_error_buf, e.what()); + conf->last_error = conf->_error_buf; + } + if constexpr (std::is_pointer_v<Ret>) + return static_cast<Ret>(nullptr); + else + static_assert(std::is_void_v<Ret>, "Don't know how to return an error value!"); +} + +// Same as above but accepts callbacks with value returns on errors: returns `f()` on success, +// `error_return` on exception +template <typename Ret, typename Call> +Ret wrap_exceptions(config_object* conf, Call&& f, Ret error_return) { + try { + conf->last_error = nullptr; + return std::invoke(std::forward<Call>(f)); + } catch (const std::exception& e) { + copy_c_str(conf->_error_buf, e.what()); + conf->last_error = conf->_error_buf; + } + return error_return; +} } // namespace session::config diff --git a/include/session/config/contacts.h b/include/session/config/contacts.h index ee863074..e2752153 100644 --- a/include/session/config/contacts.h +++ b/include/session/config/contacts.h @@ -152,8 +152,8 @@ LIBSESSION_EXPORT bool contacts_get_or_construct( /// - `contact` -- [in] Pointer containing the contact info data /// /// Output: -/// - `void` -- Returns Nothing -LIBSESSION_EXPORT void contacts_set(config_object* conf, const contacts_contact* contact); +/// - `bool` -- Returns true if the call succeeds, false if an error occurs. +LIBSESSION_EXPORT bool contacts_set(config_object* conf, const contacts_contact* contact); // NB: wrappers for set_name, set_nickname, etc. C++ methods are deliberately omitted as they would // save very little in actual calling code. The procedure for updating a single field without them diff --git a/include/session/config/convo_info_volatile.h b/include/session/config/convo_info_volatile.h index eacecdb9..952b6ff7 100644 --- a/include/session/config/convo_info_volatile.h +++ b/include/session/config/convo_info_volatile.h @@ -360,7 +360,10 @@ LIBSESSION_EXPORT bool convo_info_volatile_get_or_construct_legacy_group( /// Inputs: /// - `conf` -- [in] Pointer to the config object /// - `convo` -- [in] Pointer to conversation info structure -LIBSESSION_EXPORT void convo_info_volatile_set_1to1( +/// +/// Output: +/// - `bool` -- Returns true if the call succeeds, false if an error occurs. +LIBSESSION_EXPORT bool convo_info_volatile_set_1to1( config_object* conf, const convo_info_volatile_1to1* convo); /// API: convo_info_volatile/convo_info_volatile_set_community @@ -378,7 +381,10 @@ LIBSESSION_EXPORT void convo_info_volatile_set_1to1( /// Inputs: /// - `conf` -- [in] Pointer to the config object /// - `convo` -- [in] Pointer to community info structure -LIBSESSION_EXPORT void convo_info_volatile_set_community( +/// +/// Output: +/// - `bool` -- Returns true if the call succeeds, false if an error occurs. +LIBSESSION_EXPORT bool convo_info_volatile_set_community( config_object* conf, const convo_info_volatile_community* convo); /// API: convo_info_volatile/convo_info_volatile_set_group @@ -396,7 +402,10 @@ LIBSESSION_EXPORT void convo_info_volatile_set_community( /// Inputs: /// - `conf` -- [in] Pointer to the config object /// - `convo` -- [in] Pointer to group info structure -LIBSESSION_EXPORT void convo_info_volatile_set_group( +/// +/// Output: +/// - `bool` -- Returns true if the call succeeds, false if an error occurs.
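As a sketch of how the `bool`-returning C wrappers above can sit on top of `wrap_exceptions` (illustrative only; the real implementations live in the `.cpp` files, which are not part of this hunk):

```cpp
// Illustrative only: the shape of a C entry point built on wrap_exceptions().  The lambda
// body stands in for the real C++ call; any exception it throws becomes a `false` return
// with conf->last_error pointing at the copied message.
#include <stdexcept>
#include "session/config/base.hpp"

bool example_setter(config_object* conf, int value) {
    return session::config::wrap_exceptions(
            conf,
            [&] {
                if (value < 0)
                    throw std::invalid_argument{"value must be non-negative"};
                // ... mutate the wrapped C++ config object here ...
                return true;
            },
            false);  // error_return: what the caller receives when an exception was thrown
}
```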
+LIBSESSION_EXPORT bool convo_info_volatile_set_group( config_object* conf, const convo_info_volatile_group* convo); /// API: convo_info_volatile/convo_info_volatile_set_legacy_group @@ -414,7 +423,10 @@ LIBSESSION_EXPORT void convo_info_volatile_set_group( /// Inputs: /// - `conf` -- [in] Pointer to the config object /// - `convo` -- [in] Pointer to legacy group info structure -LIBSESSION_EXPORT void convo_info_volatile_set_legacy_group( +/// +/// Output: +/// - `bool` -- Returns true if the call succeeds, false if an error occurs. +LIBSESSION_EXPORT bool convo_info_volatile_set_legacy_group( config_object* conf, const convo_info_volatile_legacy_group* convo); /// API: convo_info_volatile/convo_info_volatile_erase_1to1 diff --git a/include/session/config/groups/info.h b/include/session/config/groups/info.h index 04456e52..4e6a85af 100644 --- a/include/session/config/groups/info.h +++ b/include/session/config/groups/info.h @@ -6,7 +6,6 @@ extern "C" { #include "../base.h" #include "../profile_pic.h" -#include "../util.h" LIBSESSION_EXPORT extern const size_t GROUP_INFO_NAME_MAX_LENGTH; LIBSESSION_EXPORT extern const size_t GROUP_INFO_DESCRIPTION_MAX_LENGTH; diff --git a/include/session/config/groups/info.hpp b/include/session/config/groups/info.hpp index fbd44629..e55b9377 100644 --- a/include/session/config/groups/info.hpp +++ b/include/session/config/groups/info.hpp @@ -32,7 +32,7 @@ class Info : public ConfigBase { /// Limits for the name & description strings, in bytes. If longer, we truncate to these /// lengths: static constexpr size_t NAME_MAX_LENGTH = 100; // same as base_group_info::NAME_MAX_LENGTH - static constexpr size_t DESCRIPTION_MAX_LENGTH = 2000; + static constexpr size_t DESCRIPTION_MAX_LENGTH = 600; // No default constructor Info() = delete; diff --git a/include/session/config/groups/keys.hpp b/include/session/config/groups/keys.hpp index f637616c..fbe5e773 100644 --- a/include/session/config/groups/keys.hpp +++ b/include/session/config/groups/keys.hpp @@ -1,13 +1,11 @@ #pragma once #include -#include #include #include "../../config.hpp" #include "../base.hpp" #include "../namespaces.hpp" -#include "../profile_pic.hpp" #include "members.hpp" namespace session::config::groups { @@ -140,12 +138,11 @@ class Keys : public ConfigSig { public: /// The multiple of members keys we include in the message; we add junk entries to the key list - /// to reach a multiple of this. 75 is chosen because it's a decently large human-round number - /// that should still fit within 4kiB page size on the storage server (allowing for some extra - /// row field storage). - static constexpr int MESSAGE_KEY_MULTIPLE = 75; + /// to reach a multiple of this. 45 is chosen because it's a decently large human-round number + /// that should still fit within 2.5kiB size limitation for push notifications. + static constexpr int MESSAGE_KEY_MULTIPLE = 45; - // 75 because: + // 45 because: // 2 // for the 'de' delimiters of the outer dict // + 3 + 2 + 12 // for the `1:g` and `iNNNNNNNNNNe` generation keypair // + 3 + 3 + 24 // for the `1:n`, `24:`, and 24 byte nonce @@ -155,7 +152,9 @@ class Keys : public ConfigSig { // + 3 + 3 + 64; // for the `1:~` and `64:` and 64 byte signature // = 177 + 48N // - // and N=75 puts us a little bit under 4kiB (which is sqlite's default page size). + // and N=45 puts us a little bit under 2.5kiB (which is the limit we have on the push + // notification server because after base64 encoding it gets close to the 4kiB limit for push + // notification content). 
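The new multiple can be checked directly against the byte counts in the comment above; a small compile-time verification using those constants (not part of the patch):

```cpp
// Compile-time sanity check of the size arithmetic above: 177 bytes of framing plus
// 48 bytes per key entry, with standard padded 4/3 base64 growth.
#include <cstddef>

constexpr std::size_t keys_msg_size(std::size_t n) { return 177 + 48 * n; }
constexpr std::size_t b64_size(std::size_t n) { return 4 * ((n + 2) / 3); }

static_assert(keys_msg_size(75) == 3777);            // old multiple: just under a 4 KiB page
static_assert(keys_msg_size(45) == 2337);            // new multiple: under the 2.5 KiB limit...
static_assert(b64_size(keys_msg_size(45)) == 3116);  // ...and under 4 KiB once base64-encoded
```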
/// A key expires when it has been surpassed by another key for at least this amount of time. /// We default this to double the 30 days that we strictly need to avoid race conditions with diff --git a/include/session/config/groups/members.h b/include/session/config/groups/members.h index a02062bf..d502fbe2 100644 --- a/include/session/config/groups/members.h +++ b/include/session/config/groups/members.h @@ -128,7 +128,10 @@ LIBSESSION_EXPORT bool groups_members_get_or_construct( /// Inputs: /// - `conf` -- [in, out] Pointer to the config object /// - `member` -- [in] Pointer containing the member info data -LIBSESSION_EXPORT void groups_members_set(config_object* conf, const config_group_member* member); +/// +/// Output: +/// - `bool` -- Returns true if the call succeeds, false if an error occurs. +LIBSESSION_EXPORT bool groups_members_set(config_object* conf, const config_group_member* member); /// API: groups/groups_members_get_status /// diff --git a/include/session/config/profile_pic.hpp b/include/session/config/profile_pic.hpp index 59601a14..93259c95 100644 --- a/include/session/config/profile_pic.hpp +++ b/include/session/config/profile_pic.hpp @@ -1,9 +1,8 @@ #pragma once +#include #include -#include "session/types.hpp" - namespace session::config { // Profile pic info. diff --git a/include/session/config/protos.hpp b/include/session/config/protos.hpp index d9eec29b..f7d1371d 100644 --- a/include/session/config/protos.hpp +++ b/include/session/config/protos.hpp @@ -1,7 +1,8 @@ #pragma once +#include + #include "namespaces.hpp" -#include "session/util.hpp" namespace session::config::protos { diff --git a/include/session/config/user_groups.h b/include/session/config/user_groups.h index 4f065e10..e314e075 100644 --- a/include/session/config/user_groups.h +++ b/include/session/config/user_groups.h @@ -327,7 +327,10 @@ LIBSESSION_EXPORT void user_groups_set_community( /// Inputs: /// - `conf` -- [in] Pointer to config_object object /// - `group` -- [in] Pointer to a group info object -LIBSESSION_EXPORT void user_groups_set_group(config_object* conf, const ugroups_group_info* group); +/// +/// Output: +/// - `bool` -- Returns true if the call succeeds, false if an error occurs. +LIBSESSION_EXPORT bool user_groups_set_group(config_object* conf, const ugroups_group_info* group); /// API: user_groups/user_groups_set_legacy_group /// @@ -347,7 +350,10 @@ LIBSESSION_EXPORT void user_groups_set_group(config_object* conf, const ugroups_ /// Inputs: /// - `conf` -- [in] Pointer to config_object object /// - `group` -- [in] Pointer to a legacy group info object -LIBSESSION_EXPORT void user_groups_set_legacy_group( +/// +/// Output: +/// - `bool` -- Returns true if the call succeeds, false if an error occurs. +LIBSESSION_EXPORT bool user_groups_set_legacy_group( config_object* conf, const ugroups_legacy_group_info* group); /// API: user_groups/user_groups_set_free_legacy_group @@ -367,7 +373,10 @@ LIBSESSION_EXPORT void user_groups_set_legacy_group( /// Inputs: /// - `conf` -- [in] Pointer to config_object object /// - `group` -- [in] Pointer to a legacy group info object -LIBSESSION_EXPORT void user_groups_set_free_legacy_group( +/// +/// Output: +/// - `bool` -- Returns true if the call succeeds, false if an error occurs. 
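A usage sketch contrasting the two legacy-group setters (assumption, based on the name and the non-const parameter, that the `_free_` variant also releases the struct, while the plain setter leaves ownership with the caller):

```cpp
// Illustrative one-shot update of a legacy group entry.  Assumption (from the name and
// the non-const parameter): the `_free_` variant also releases the heap-allocated struct,
// so `grp` must not be touched after the call.
#include "session/config/user_groups.h"

bool update_and_release(config_object* conf, ugroups_legacy_group_info* grp) {
    // ... mutate whichever fields of *grp need updating ...
    return user_groups_set_free_legacy_group(conf, grp);
}
```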
+LIBSESSION_EXPORT bool user_groups_set_free_legacy_group( config_object* conf, ugroups_legacy_group_info* group); /// API: user_groups/user_groups_erase_community diff --git a/include/session/config/user_groups.hpp b/include/session/config/user_groups.hpp index 3d905c7d..077130a5 100644 --- a/include/session/config/user_groups.hpp +++ b/include/session/config/user_groups.hpp @@ -254,7 +254,7 @@ struct community_info : base_group_info, community { void load(const dict& info_dict); friend class UserGroups; - friend class comm_iterator_helper; + friend struct comm_iterator_helper; }; using any_group_info = std::variant; diff --git a/include/session/curve25519.hpp b/include/session/curve25519.hpp index 34a437ad..21e8de47 100644 --- a/include/session/curve25519.hpp +++ b/include/session/curve25519.hpp @@ -7,7 +7,7 @@ namespace session::curve25519 { /// Generates a random curve25519 key pair -std::pair, std::array> curve25519_key_pair(); +std::pair, std::array> curve25519_key_pair(); /// API: curve25519/to_curve25519_pubkey /// diff --git a/include/session/file.hpp b/include/session/file.hpp new file mode 100644 index 00000000..a2bf747d --- /dev/null +++ b/include/session/file.hpp @@ -0,0 +1,28 @@ +#pragma once +#include +#include +#include +#include + +// Utility functions for working with files + +namespace session { + +namespace fs = std::filesystem; + +/// Opens a file for writing of binary data, setting up the returned ofstream with exceptions +/// enabled for any failures. This also throws if the file cannot be opened. If the file already +/// exists it will be truncated. +std::ofstream open_for_writing(const fs::path& filename); + +/// Opens a file for reading of binary data, setting up the returned ifstream with exceptions +/// enabled for any failures. This also throws if the file cannot be opened. +std::ifstream open_for_reading(const fs::path& filename); + +/// Reads a (binary) file from disk into the string `contents`. +std::string read_whole_file(const fs::path& filename); + +/// Dumps (binary) string contents to disk. The file is overwritten if it already exists. +void write_whole_file(const fs::path& filename, std::string_view contents = ""); + +} // namespace session diff --git a/include/session/log_level.h b/include/session/log_level.h new file mode 100644 index 00000000..bc941424 --- /dev/null +++ b/include/session/log_level.h @@ -0,0 +1,20 @@ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +// Note: These values must match the values in spdlog::level::level_enum +typedef enum LOG_LEVEL { + LOG_LEVEL_TRACE = 0, + LOG_LEVEL_DEBUG = 1, + LOG_LEVEL_INFO = 2, + LOG_LEVEL_WARN = 3, + LOG_LEVEL_ERROR = 4, + LOG_LEVEL_CRITICAL = 5, + LOG_LEVEL_OFF = 6, +} LOG_LEVEL; + +#ifdef __cplusplus +} +#endif diff --git a/include/session/logging.h b/include/session/logging.h new file mode 100644 index 00000000..98811934 --- /dev/null +++ b/include/session/logging.h @@ -0,0 +1,73 @@ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +#include + +#include "export.h" +#include "log_level.h" + +/// API: session/session_add_logger_simple +/// +/// Registers a callback that is invoked when a message is logged. This callback is invoked with +/// just the log message. +/// +/// Inputs: +/// - `callback` -- [in] callback to be called when a new message should be logged. 
+LIBSESSION_EXPORT void session_add_logger_simple(void (*callback)(const char* msg, size_t msglen)); + +/// API: session/session_add_logger_full +/// +/// Registers a callback that is invoked when a message is logged. The callback is invoked with the +/// log message, the category name of the log message, and the level of the message. +/// +/// Inputs: +/// - `callback` -- [in] callback to be called when a new message should be logged. +LIBSESSION_EXPORT void session_add_logger_full(void (*callback)( + const char* msg, size_t msglen, const char* cat, size_t cat_len, LOG_LEVEL level)); + +/// API: session/session_logger_reset_level +/// +/// Resets the log level of all existing category loggers, and sets a new default for any created +/// after this call. If this has not been called, the default log level of category loggers is +/// info. +LIBSESSION_EXPORT void session_logger_reset_level(LOG_LEVEL level); + +/// API: session/session_logger_set_level_default +/// +/// Sets the log level of new category loggers initialized after this call, but does not change the +/// log level of already-initialized category loggers. +LIBSESSION_EXPORT void session_logger_set_level_default(LOG_LEVEL level); + +/// API: session/session_logger_get_level_default +/// +/// Gets the default log level of new loggers (since the last reset_level or set_level_default +/// call). +LIBSESSION_EXPORT LOG_LEVEL session_logger_get_level_default(); + +/// API: session/session_logger_set_level +/// +/// Set the log level of a specific logger category +LIBSESSION_EXPORT void session_logger_set_level(const char* cat_name, LOG_LEVEL level); + +/// API: session/session_logger_get_level +/// +/// Gets the log level of a specific logger category +LIBSESSION_EXPORT LOG_LEVEL session_logger_get_level(const char* cat_name); + +/// API: session/session_manual_log +/// +/// Logs the provided value via oxen::log, can be used to test that the loggers are working +/// correctly +LIBSESSION_EXPORT void session_manual_log(const char* msg); + +/// API: session/session_clear_loggers +/// +/// Clears all currently set loggers +LIBSESSION_EXPORT void session_clear_loggers(); + +#ifdef __cplusplus +} +#endif diff --git a/include/session/logging.hpp b/include/session/logging.hpp new file mode 100644 index 00000000..2b949120 --- /dev/null +++ b/include/session/logging.hpp @@ -0,0 +1,112 @@ +#pragma once + +#include +#include + +#include "log_level.h" + +// forward declaration +namespace spdlog::level { +enum level_enum : int; +} + +namespace session { + +// This is working roughly like an enum class, but with some useful conversions and comparisons +// defined. We allow implicit conversion to this from a spdlog level_enum, and explicit conversion +// *to* a level_enum, as well as comparison operators (so that, for example, LogLevel::warn >= +// LogLevel::info). +struct LogLevel { + int level; + + LogLevel(spdlog::level::level_enum lvl); + explicit constexpr LogLevel(int lvl) : level{lvl} {} + + // Returns the log level as an spdlog enum (which is also a oxen::log::Level). 
+ spdlog::level::level_enum spdlog_level() const; + + std::string_view to_string() const; + + static const LogLevel trace; + static const LogLevel debug; + static const LogLevel info; + static const LogLevel warn; + static const LogLevel error; + static const LogLevel critical; + + auto operator<=>(const LogLevel& other) const { return level <=> other.level; } +}; + +inline const LogLevel LogLevel::trace{LOG_LEVEL_TRACE}; +inline const LogLevel LogLevel::debug{LOG_LEVEL_DEBUG}; +inline const LogLevel LogLevel::info{LOG_LEVEL_INFO}; +inline const LogLevel LogLevel::warn{LOG_LEVEL_WARN}; +inline const LogLevel LogLevel::error{LOG_LEVEL_ERROR}; +inline const LogLevel LogLevel::critical{LOG_LEVEL_CRITICAL}; + +/// API: add_logger +/// +/// Adds a logger callback for oxen-logging log messages (such as from the network object). +/// +/// Inputs: +/// - `callback` -- [in] callback to be called when a new message should be logged. This +/// callback must be callable as one of: +/// +/// callback(std::string_view msg) +/// callback(std::string_view msg, std::string_view log_cat, LogLevel level) +/// +void add_logger(std::function cb); +void add_logger( + std::function cb); + +/// API: session/logger_reset_level +/// +/// Resets the log level of all existing category loggers, and sets a new default for any created +/// after this call. If this has not been called, the default log level of category loggers is +/// info. +/// +/// This function is simply a wrapper around oxen::log::reset_level +void logger_reset_level(LogLevel level); + +/// API: session/logger_set_level_default +/// +/// Sets the log level of new category loggers initialized after this call, but does not change the +/// log level of already-initialized category loggers. +/// +/// This function is simply a wrapper around oxen::log::set_level_default +void logger_set_level_default(LogLevel level); + +/// API: session/logger_get_level_default +/// +/// Gets the default log level of new loggers (since the last reset_level or set_level_default +/// call). 
+/// +/// This function is simply a wrapper around oxen::log::get_level_default +LogLevel logger_get_level_default(); + +/// API: session/logger_set_level +/// +/// Set the log level of a specific logger category +/// +/// This function is simply a wrapper around oxen::log::set_level +void logger_set_level(std::string cat_name, LogLevel level); + +/// API: session/logger_get_level +/// +/// Gets the log level of a specific logger category +/// +/// This function is simply a wrapper around oxen::log::get_level +LogLevel logger_get_level(std::string cat_name); + +/// API: session/manual_log +/// +/// Logs the provided value via oxen::log, can be used to test that the loggers are working +/// correctly +void manual_log(std::string_view msg); + +/// API: session/clear_loggers +/// +/// Clears all currently set loggers +void clear_loggers(); + +} // namespace session diff --git a/include/session/network.h b/include/session/network.h new file mode 100644 index 00000000..3d9af0d7 --- /dev/null +++ b/include/session/network.h @@ -0,0 +1,351 @@ +#pragma once + +#ifdef __cplusplus extern "C" { +#endif + +#include +#include + +#include "export.h" +#include "log_level.h" +#include "onionreq/builder.h" +#include "platform.h" + +typedef enum CONNECTION_STATUS { + CONNECTION_STATUS_UNKNOWN = 0, + CONNECTION_STATUS_CONNECTING = 1, + CONNECTION_STATUS_CONNECTED = 2, + CONNECTION_STATUS_DISCONNECTED = 3, +} CONNECTION_STATUS; + +typedef struct network_object { + // Internal opaque object pointer; calling code should leave this alone. + void* internals; +} network_object; + +typedef struct network_service_node { + uint8_t ip[4]; + uint16_t quic_port; + char ed25519_pubkey_hex[65]; // The ed25519 pubkey in hex (64 characters) + null terminator. +} network_service_node; + +typedef struct network_server_destination { + const char* method; + const char* protocol; + const char* host; + const char* endpoint; + uint16_t port; + const char* x25519_pubkey; + const char** headers; + const char** header_values; + size_t headers_size; +} network_server_destination; + +typedef struct onion_request_path { + const network_service_node* nodes; + const size_t nodes_count; +} onion_request_path; + +/// API: network/network_init +/// +/// Constructs a new network object. +/// +/// When done with the object the `network_object` must be destroyed by passing the pointer to +/// network_free(). +/// +/// Inputs: +/// - `network` -- [out] Pointer to the network object +/// - `cache_path` -- [in] Path where the snode cache files should be stored. Should be +/// NULL-terminated. +/// - `use_testnet` -- [in] Flag indicating whether the network should connect to testnet or +/// mainnet. +/// - `single_path_mode` -- [in] Flag indicating whether the network should be in "single path mode" +/// (i.e. use a single path for everything - this is useful for iOS App Extensions which perform a +/// single action and then close so we don't waste time building other paths). +/// - `pre_build_paths` -- [in] Flag indicating whether the network should pre-build its paths. +/// - `error` -- [out] a pointer to a buffer in which we will write an error string if an error +/// occurs; error messages are discarded if this is given as NULL. If non-NULL this must be a +/// buffer of at least 256 bytes. +/// +/// Outputs: +/// - `bool` -- Returns true on success; returns false and writes the exception message as a C-string +/// into `error` (if not NULL) on failure.
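A sketch of the create/use/destroy cycle this describes (the declaration itself follows); the cache path is a placeholder:

```cpp
// Create a network object, report any initialization error, and hand it back to the caller.
#include <cstdio>
#include "session/network.h"

network_object* make_network() {
    network_object* net = nullptr;
    char error[256] = {0};
    if (!network_init(&net, "/tmp/session-snode-cache", /*use_testnet=*/false,
                      /*single_path_mode=*/false, /*pre_build_paths=*/true, error)) {
        std::fprintf(stderr, "network_init failed: %s\n", error);
        return nullptr;
    }
    return net;  // the caller eventually passes this to network_free()
}
```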
+LIBSESSION_EXPORT bool network_init( + network_object** network, + const char* cache_path, + bool use_testnet, + bool single_path_mode, + bool pre_build_paths, + char* error) __attribute__((warn_unused_result)); + +/// API: network/network_free +/// +/// Frees a network object. +/// +/// Inputs: +/// - `network` -- [in] Pointer to network_object object +LIBSESSION_EXPORT void network_free(network_object* network); + +/// API: network/network_suspend +/// +/// Suspends the network, preventing any further requests from creating new connections and paths. +/// This function also calls the `close_connections` function. +LIBSESSION_EXPORT void network_suspend(network_object* network); + +/// API: network/network_resume +/// +/// Resumes the network, allowing new requests to create new connections and paths. +LIBSESSION_EXPORT void network_resume(network_object* network); + +/// API: network/network_close_connections +/// +/// Closes any currently active connections. +LIBSESSION_EXPORT void network_close_connections(network_object* network); + +/// API: network/network_clear_cache +/// +/// Clears the cache from memory and from disk (if a cache path was provided during +/// initialization). +LIBSESSION_EXPORT void network_clear_cache(network_object* network); + +/// API: network/network_get_snode_cache_size +/// +/// Retrieves the current size of the snode cache from memory (if a cache doesn't exist or +/// hasn't been loaded then this will return 0). +LIBSESSION_EXPORT size_t network_get_snode_cache_size(network_object* network); + +/// API: network/network_set_status_changed_callback +/// +/// Registers a callback to be called whenever the network connection status changes. +/// +/// Inputs: +/// - `network` -- [in] Pointer to the network object +/// - `callback` -- [in] callback to be called when the network connection status changes. +/// - `ctx` -- [in, optional] Pointer to an optional context. Set to NULL if unused. +LIBSESSION_EXPORT void network_set_status_changed_callback( + network_object* network, void (*callback)(CONNECTION_STATUS status, void* ctx), void* ctx); + +/// API: network/network_set_paths_changed_callback +/// +/// Registers a callback to be called whenever the onion request paths are updated. +/// +/// The pointer provided to the callback belongs to the caller and must be freed via `free()` when +/// done with it. +/// +/// Inputs: +/// - `network` -- [in] Pointer to the network object +/// - `callback` -- [in] callback to be called when the onion request paths are updated. +/// - `ctx` -- [in, optional] Pointer to an optional context. Set to NULL if unused. +LIBSESSION_EXPORT void network_set_paths_changed_callback( + network_object* network, + void (*callback)(onion_request_path* paths, size_t paths_len, void* ctx), + void* ctx); + +/// API: network/network_get_swarm +/// +/// Retrieves the swarm for the given pubkey. If there is already an entry in the cache for the +/// swarm then that will be returned, otherwise a network request will be made to retrieve the +/// swarm and save it to the cache. +/// +/// Inputs: +/// - `network` -- [in] Pointer to the network object +/// - 'swarm_pubkey_hex' - [in] x25519 pubkey for the swarm in hex (64 characters). +/// - 'callback' - [in] callback to be called with the retrieved swarm (in the case of an error +/// the callback will be called with an empty list). +/// - `ctx` -- [in, optional] Pointer to an optional context. Set to NULL if unused.
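A callback-style lookup matching the documentation above (illustrative names; the printout is only for demonstration):

```cpp
// Swarm lookup with a plain-function callback; the account pubkey comes from the caller.
#include <cstdio>
#include "session/network.h"

static void on_swarm(network_service_node* nodes, size_t nodes_len, void* /*ctx*/) {
    // Per the documentation above, an empty list signals an error.
    std::printf("swarm has %zu nodes\n", nodes_len);
    for (size_t i = 0; i < nodes_len; i++)
        std::printf("  %s\n", nodes[i].ed25519_pubkey_hex);
}

void lookup_swarm(network_object* net, const char* account_x25519_hex /* 64 hex chars */) {
    network_get_swarm(net, account_x25519_hex, &on_swarm, nullptr);
}
```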
+LIBSESSION_EXPORT void network_get_swarm( + network_object* network, + const char* swarm_pubkey_hex, + void (*callback)(network_service_node* nodes, size_t nodes_len, void*), + void* ctx); + +/// API: network/network_get_random_nodes +/// +/// Retrieves a number of random nodes from the snode pool. If there are no nodes in the pool, a +/// new pool will be populated and the nodes will be retrieved from that. +/// +/// Inputs: +/// - `network` -- [in] Pointer to the network object +/// - 'count' - [in] the number of nodes to retrieve. +/// - 'callback' - [in] callback to be called with the retrieved nodes (in the case of an error +/// the callback will be called with an empty list). +/// - `ctx` -- [in, optional] Pointer to an optional context. Set to NULL if unused. +LIBSESSION_EXPORT void network_get_random_nodes( + network_object* network, + uint16_t count, + void (*callback)(network_service_node*, size_t, void*), + void* ctx); + +/// API: network/network_onion_response_callback_t +/// +/// Function pointer typedef for the callback function pointer given to +/// network_send_onion_request_to_snode_destination and +/// network_send_onion_request_to_server_destination. +/// +/// Fields: +/// - `success` -- true if the request was successful, false if it failed. +/// - `timeout` -- true if the request failed because of a timeout +/// - `status_code` -- the HTTP numeric status code of the request, e.g. 200 for OK +/// - `headers` -- the response headers, array of null-terminated C strings +/// - `header_values` -- the response header values, array of null-terminated C strings +/// - `headers_size` -- the number of `headers`/`header_values` +/// - `response` -- pointer to the beginning of the response body +/// - `response_size` -- length of the response body +/// - `ctx` -- the context pointer passed to the function that initiated the request. +typedef void (*network_onion_response_callback_t)( + bool success, + bool timeout, + int16_t status_code, + const char** headers, + const char** header_values, + size_t headers_size, + const char* response, + size_t response_size, + void* ctx); + +/// API: network/network_send_onion_request_to_snode_destination +/// +/// Sends a request via onion routing to the provided service node. +/// +/// Inputs: +/// - `network` -- [in] Pointer to the network object. +/// - `node` -- [in] address information about the service node the request should be sent to. +/// - `body` -- [in] data to send to the specified node. +/// - `body_size` -- [in] size of the `body`. +/// - `request_timeout_ms` -- [in] timeout in milliseconds to use for the request. This won't take +/// the path build into account so if the path build takes forever then this request will never +/// time out. +/// - `request_and_path_build_timeout_ms` -- [in] timeout in milliseconds to use for the request and +/// path build (if required). This value takes precedence over `request_timeout_ms` if provided; +/// the request itself will be given a timeout of this value minus however long it took to +/// build the path. A value of `0` will be ignored and `request_timeout_ms` will be used instead. +/// - `callback` -- [in] callback to be called with the result of the request. +/// - `ctx` -- [in, optional] Pointer to an optional context to pass through to the callback. Set to +/// NULL if unused.
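A handler matching `network_onion_response_callback_t`, usable with either the snode or server request functions documented here (purely illustrative):

```cpp
// A response handler matching network_onion_response_callback_t.
#include <cstdio>
#include "session/network.h"

static void on_onion_response(
        bool success,
        bool timeout,
        int16_t status_code,
        const char** /*headers*/,
        const char** /*header_values*/,
        size_t /*headers_size*/,
        const char* response,
        size_t response_size,
        void* /*ctx*/) {
    if (!success) {
        std::fprintf(stderr, "onion request failed (%s, status %d)\n",
                     timeout ? "timeout" : "error", status_code);
        return;
    }
    std::printf("got %zu byte response: %.*s\n", response_size, (int)response_size,
                response ? response : "");
}
```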
+LIBSESSION_EXPORT void network_send_onion_request_to_snode_destination( + network_object* network, + const network_service_node node, + const unsigned char* body, + size_t body_size, + const char* swarm_pubkey_hex, + int64_t request_timeout_ms, + int64_t request_and_path_build_timeout_ms, + network_onion_response_callback_t callback, + void* ctx); + +/// API: network/network_send_onion_request_to_server_destination +/// +/// Sends a request via onion routing to the provided server. +/// +/// Inputs: +/// - `network` -- [in] Pointer to the network object. +/// - `server` -- [in] struct containing information about the server the request should be sent to. +/// - `body` -- [in] data to send to the specified endpoint. +/// - `body_size` -- [in] size of the `body`. +/// - `request_timeout_ms` -- [in] timeout in milliseconds to use for the request. This won't take +/// the path build into account so if the path build takes forever then this request will never +/// time out. +/// - `request_and_path_build_timeout_ms` -- [in] timeout in milliseconds to use for the request and +/// path build (if required). This value takes precedence over `request_timeout_ms` if provided; +/// the request itself will be given a timeout of this value minus however long it took to +/// build the path. A value of `0` will be ignored and `request_timeout_ms` will be used instead. +/// - `callback` -- [in] callback to be called with the result of the request. +/// - `ctx` -- [in, optional] Pointer to an optional context to pass through to the callback. Set +/// to NULL if unused. +LIBSESSION_EXPORT void network_send_onion_request_to_server_destination( + network_object* network, + const network_server_destination server, + const unsigned char* body, + size_t body_size, + int64_t request_timeout_ms, + int64_t request_and_path_build_timeout_ms, + network_onion_response_callback_t callback, + void* ctx); + +/// API: network/network_upload_to_server +/// +/// Uploads a file to a server. +/// +/// Inputs: +/// - `network` -- [in] Pointer to the network object. +/// - `server` -- [in] struct containing information about the server the request should be sent to. +/// - `data` -- [in] data to upload to the file server. +/// - `data_len` -- [in] size of the `data`. +/// - `file_name` -- [in, optional] name of the file being uploaded. MUST be null terminated. +/// - `request_timeout_ms` -- [in] timeout in milliseconds to use for the request. This won't take +/// the path build into account so if the path build takes forever then this request will never +/// time out. +/// - `request_and_path_build_timeout_ms` -- [in] timeout in milliseconds to use for the request and +/// path build (if required). This value takes precedence over `request_timeout_ms` if provided; +/// the request itself will be given a timeout of this value minus however long it took to +/// build the path. A value of `0` will be ignored and `request_timeout_ms` will be used instead. +/// - `callback` -- [in] callback to be called with the result of the request. +/// - `ctx` -- [in, optional] Pointer to an optional context to pass through to the callback. Set +/// to NULL if unused.
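Putting the destination struct and the upload call together; every destination value below (host, endpoint, pubkey) is a placeholder rather than the real file server details:

```cpp
// Building a destination and uploading a blob through it.
#include "session/network.h"

void upload_attachment(network_object* net, const unsigned char* data, size_t data_len,
                       network_onion_response_callback_t cb) {
    network_server_destination server = {};
    server.method = "POST";
    server.protocol = "http";
    server.host = "files.example.org";        // placeholder host
    server.endpoint = "/file";                // placeholder endpoint
    server.port = 80;
    server.x25519_pubkey = "<64 hex chars>";  // placeholder pubkey
    server.headers = nullptr;
    server.header_values = nullptr;
    server.headers_size = 0;

    network_upload_to_server(net, server, data, data_len, "attachment.bin",
                             /*request_timeout_ms=*/10000,
                             /*request_and_path_build_timeout_ms=*/0, cb, nullptr);
}
```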
+/// API: network/network_upload_to_server
+///
+/// Uploads a file to a server.
+///
+/// Inputs:
+/// - `network` -- [in] Pointer to the network object.
+/// - `server` -- [in] struct containing information about the server the request should be sent to.
+/// - `data` -- [in] data to upload to the file server.
+/// - `data_len` -- [in] size of the `data`.
+/// - `file_name` -- [in, optional] name of the file being uploaded. MUST be null terminated.
+/// - `request_timeout_ms` -- [in] timeout in milliseconds to use for the request. This won't take
+/// the path build into account so if the path build takes forever then this request will never
+/// timeout.
+/// - `request_and_path_build_timeout_ms` -- [in] timeout in milliseconds to use for the request and
+/// path build (if required). This value takes precedence over `request_timeout_ms` if provided;
+/// the request itself will be given a timeout of this value minus however long it took to
+/// build the path. A value of `0` will be ignored and `request_timeout_ms` will be used instead.
+/// - `callback` -- [in] callback to be called with the result of the request.
+/// - `ctx` -- [in, optional] Pointer to an optional context to pass through to the callback. Set
+/// to NULL if unused.
+LIBSESSION_EXPORT void network_upload_to_server(
+        network_object* network,
+        const network_server_destination server,
+        const unsigned char* data,
+        size_t data_len,
+        const char* file_name,
+        int64_t request_timeout_ms,
+        int64_t request_and_path_build_timeout_ms,
+        network_onion_response_callback_t callback,
+        void* ctx);
+
+/// API: network/network_download_from_server
+///
+/// Downloads a file from a server.
+///
+/// Inputs:
+/// - `network` -- [in] Pointer to the network object.
+/// - `server` -- [in] struct containing information about the file to be downloaded.
+/// - `request_timeout_ms` -- [in] timeout in milliseconds to use for the request. This won't take
+/// the path build into account so if the path build takes forever then this request will never
+/// timeout.
+/// - `request_and_path_build_timeout_ms` -- [in] timeout in milliseconds to use for the request and
+/// path build (if required). This value takes precedence over `request_timeout_ms` if provided;
+/// the request itself will be given a timeout of this value minus however long it took to
+/// build the path. A value of `0` will be ignored and `request_timeout_ms` will be used instead.
+/// - `callback` -- [in] callback to be called with the result of the request.
+/// - `ctx` -- [in, optional] Pointer to an optional context to pass through to the callback. Set
+/// to NULL if unused.
+LIBSESSION_EXPORT void network_download_from_server(
+        network_object* network,
+        const network_server_destination server,
+        int64_t request_timeout_ms,
+        int64_t request_and_path_build_timeout_ms,
+        network_onion_response_callback_t callback,
+        void* ctx);
+
+/// API: network/network_get_client_version
+///
+/// Retrieves the version information for the given platform.
+///
+/// Inputs:
+/// - `network` -- [in] Pointer to the network object.
+/// - `platform` -- [in] the platform to retrieve the client version for.
+/// - `ed25519_secret` -- [in] the user's ed25519 secret key (used for blinded auth - 64 bytes).
+/// - `request_timeout_ms` -- [in] timeout in milliseconds to use for the request. This won't take
+/// the path build into account so if the path build takes forever then this request will never
+/// timeout.
+/// - `request_and_path_build_timeout_ms` -- [in] timeout in milliseconds to use for the request and
+/// path build (if required). This value takes precedence over `request_timeout_ms` if provided;
+/// the request itself will be given a timeout of this value minus however long it took to
+/// build the path. A value of `0` will be ignored and `request_timeout_ms` will be used instead.
+/// - `callback` -- [in] callback to be called with the result of the request.
+/// - `ctx` -- [in, optional] Pointer to an optional context to pass through to the callback. Set
+/// to NULL if unused.
+LIBSESSION_EXPORT void network_get_client_version( + network_object* network, + CLIENT_PLATFORM platform, + const unsigned char* ed25519_secret, /* 64 bytes */ + int64_t request_timeout_ms, + int64_t request_and_path_build_timeout_ms, + network_onion_response_callback_t callback, + void* ctx); + +#ifdef __cplusplus +} +#endif diff --git a/include/session/network.hpp b/include/session/network.hpp new file mode 100644 index 00000000..aa0d1576 --- /dev/null +++ b/include/session/network.hpp @@ -0,0 +1,754 @@ +#pragma once + +#include +#include + +#include "onionreq/builder.hpp" +#include "onionreq/key_types.hpp" +#include "platform.hpp" +#include "session/random.hpp" +#include "types.hpp" + +namespace session::network { + +namespace fs = std::filesystem; + +using network_response_callback_t = std::function> headers, + std::optional response)>; + +enum class ConnectionStatus { + unknown, + connecting, + connected, + disconnected, +}; + +enum class PathType { + standard, + upload, + download, +}; + +using swarm_id_t = uint64_t; +constexpr swarm_id_t INVALID_SWARM_ID = std::numeric_limits::max(); + +struct service_node : public oxen::quic::RemoteAddress { + public: + std::vector storage_server_version; + swarm_id_t swarm_id; + + service_node() = delete; + + template + service_node( + std::string_view remote_pk, + std::vector storage_server_version, + swarm_id_t swarm_id, + Opt&&... opts) : + oxen::quic::RemoteAddress{remote_pk, std::forward(opts)...}, + storage_server_version{storage_server_version}, + swarm_id{swarm_id} {} + + template + service_node( + ustring_view remote_pk, + std::vector storage_server_version, + swarm_id_t swarm_id, + Opt&&... opts) : + oxen::quic::RemoteAddress{remote_pk, std::forward(opts)...}, + storage_server_version{storage_server_version}, + swarm_id{swarm_id} {} + + service_node(const service_node& obj) : + oxen::quic::RemoteAddress{obj}, + storage_server_version{obj.storage_server_version}, + swarm_id{obj.swarm_id} {} + service_node& operator=(const service_node& obj) { + storage_server_version = obj.storage_server_version; + swarm_id = obj.swarm_id; + oxen::quic::RemoteAddress::operator=(obj); + _copy_internals(obj); + return *this; + } + + bool operator==(const service_node& other) const { + return static_cast(*this) == + static_cast(other) && + storage_server_version == other.storage_server_version && swarm_id == other.swarm_id; + } +}; + +struct connection_info { + service_node node; + std::shared_ptr pending_requests; + std::shared_ptr conn; + std::shared_ptr stream; + + bool is_valid() const { return conn && stream && !stream->is_closing(); }; + bool has_pending_requests() const { return (pending_requests && (*pending_requests) > 0); }; + + void add_pending_request() { + if (!pending_requests) + pending_requests = std::make_shared(0); + (*pending_requests)++; + }; + + // This is weird but since we are modifying the shared_ptr we aren't mutating + // the object so it can be a const function + void remove_pending_request() const { + if (!pending_requests) + return; + (*pending_requests)--; + }; +}; + +struct onion_path { + std::string id; + connection_info conn_info; + std::vector nodes; + uint8_t failure_count; + + bool is_valid() const { return !nodes.empty() && conn_info.is_valid(); }; + bool has_pending_requests() const { return conn_info.has_pending_requests(); } + size_t num_pending_requests() const { + if (!conn_info.pending_requests) + return 0; + return (*conn_info.pending_requests); + } + + std::string to_string() const; + + bool contains_node(const 
service_node& sn) const; + + bool operator==(const onion_path& other) const { + // The `conn_info` and failure/timeout counts can be reset for a path in a number + // of situations so just use the nodes to determine if the paths match + return nodes == other.nodes; + } +}; + +namespace detail { + swarm_id_t pubkey_to_swarm_space(const session::onionreq::x25519_pubkey& pk); + std::vector>> generate_swarms( + std::vector nodes); + + std::optional node_for_destination(onionreq::network_destination destination); + + session::onionreq::x25519_pubkey pubkey_for_destination( + onionreq::network_destination destination); + +} // namespace detail + +struct request_info { + static request_info make( + onionreq::network_destination _dest, + std::optional _original_body, + std::optional _swarm_pk, + std::chrono::milliseconds _request_timeout, + std::optional _request_and_path_build_timeout = std::nullopt, + PathType _type = PathType::standard, + std::optional _req_id = std::nullopt, + std::optional endpoint = "onion_req", + std::optional _body = std::nullopt); + + enum class RetryReason { + none, + decryption_failure, + redirect, + redirect_swarm_refresh, + }; + + std::string request_id; + session::onionreq::network_destination destination; + std::string endpoint; + std::optional body; + std::optional original_body; + std::optional swarm_pubkey; + PathType path_type; + std::chrono::milliseconds request_timeout; + std::optional request_and_path_build_timeout; + std::chrono::system_clock::time_point creation_time = std::chrono::system_clock::now(); + + /// The reason we are retrying the request (if it's a retry). Generally only used for internal + /// purposes (like receiving a `421`) in order to prevent subsequent retries. + std::optional retry_reason{}; + + bool node_destination{detail::node_for_destination(destination).has_value()}; +}; + +class Network { + private: + const bool use_testnet; + const bool should_cache_to_disk; + const bool single_path_mode; + const fs::path cache_path; + + // Disk thread state + std::mutex snode_cache_mutex; // This guards all the below: + std::condition_variable snode_cache_cv; + bool has_pending_disk_write = false; + bool shut_down_disk_thread = false; + bool need_write = false; + bool need_clear_cache = false; + + // Values persisted to disk + std::optional seed_node_cache_size; + std::vector snode_cache; + std::chrono::system_clock::time_point last_snode_cache_update{}; + + std::thread disk_write_thread; + + // General values + bool destroyed = false; + bool suspended = false; + ConnectionStatus status; + oxen::quic::Network net; + std::shared_ptr endpoint; + std::unordered_map> paths; + std::vector> paths_pending_drop; + std::vector unused_nodes; + std::unordered_map snode_failure_counts; + std::vector>> all_swarms; + std::unordered_map>> swarm_cache; + + // Snode refresh state + int snode_cache_refresh_failure_count; + int in_progress_snode_cache_refresh_count; + std::optional current_snode_cache_refresh_request_id; + std::vector> after_snode_cache_refresh; + std::optional> unused_snode_refresh_nodes; + std::shared_ptr>> snode_refresh_results; + + // First hop state + int connection_failures = 0; + std::deque unused_connections; + std::unordered_map in_progress_connections; + + // Path build state + int path_build_failures = 0; + std::deque path_build_queue; + std::unordered_map in_progress_path_builds; + + // Request state + bool has_scheduled_resume_queues = false; + std::optional request_timeout_id; + std::chrono::system_clock::time_point 
last_resume_queues_timestamp{}; + std::unordered_map>> + request_queue; + + public: + friend class TestNetwork; + friend class TestNetworkWrapper; + + // Hook to be notified whenever the network connection status changes. + std::function status_changed; + + // Hook to be notified whenever the onion request paths are updated. + std::function> paths)> paths_changed; + + // Constructs a new network with the given cache path and a flag indicating whether it should + // use testnet or mainnet, all requests should be made via a single Network instance. + Network(std::optional cache_path, + bool use_testnet, + bool single_path_mode, + bool pre_build_paths); + virtual ~Network(); + + /// API: network/suspend + /// + /// Suspends the network preventing any further requests from creating new connections and + /// paths. This function also calls the `close_connections` function. + void suspend(); + + /// API: network/resume + /// + /// Resumes the network allowing new requests to creating new connections and paths. + void resume(); + + /// API: network/close_connections + /// + /// Closes any currently active connections. + void close_connections(); + + /// API: network/clear_cache + /// + /// Clears the cached from memory and from disk (if a cache path was provided during + /// initialization). + void clear_cache(); + + /// API: network/snode_cache_size + /// + /// Retrieves the current size of the snode cache from memory (if a cache doesn't exist or + /// hasn't been loaded then this will return 0). + size_t snode_cache_size(); + + /// API: network/get_swarm + /// + /// Retrieves the swarm for the given pubkey. If there is already an entry in the cache for the + /// swarm then that will be returned, otherwise a network request will be made to retrieve the + /// swarm and save it to the cache. + /// + /// Inputs: + /// - 'swarm_pubkey' - [in] public key for the swarm. + /// - 'callback' - [in] callback to be called with the retrieved swarm (in the case of an error + /// the callback will be called with an empty list). + void get_swarm( + session::onionreq::x25519_pubkey swarm_pubkey, + std::function swarm)> callback); + + /// API: network/get_random_nodes + /// + /// Retrieves a number of random nodes from the snode pool. If the are no nodes in the pool a + /// new pool will be populated and the nodes will be retrieved from that. + /// + /// Inputs: + /// - 'count' - [in] the number of nodes to retrieve. + /// - 'callback' - [in] callback to be called with the retrieved nodes (in the case of an error + /// the callback will be called with an empty list). + void get_random_nodes( + uint16_t count, std::function nodes)> callback); + + /// API: network/send_onion_request + /// + /// Sends a request via onion routing to the provided service node or server destination. + /// + /// Inputs: + /// - `destination` -- [in] service node or server destination information. + /// - `body` -- [in] data to send to the specified destination. + /// - `swarm_pubkey` -- [in, optional] pubkey for the swarm the request is associated with. + /// Should be NULL if the request is not associated with a swarm. + /// - `handle_response` -- [in] callback to be called with the result of the request. + /// - `request_timeout` -- [in] timeout in milliseconds to use for the request. This won't take + /// the path build into account so if the path build takes forever then this request will never + /// timeout. 
+ /// - `request_and_path_build_timeout` -- [in] timeout in milliseconds to use for the request + /// and path build (if required). This value takes presedence over `request_timeout` if + /// provided, the request itself will be given a timeout of this value subtracting however long + /// it took to build the path. + /// - 'type' - [in] the type of paths to send the request across. + void send_onion_request( + onionreq::network_destination destination, + std::optional body, + std::optional swarm_pubkey, + network_response_callback_t handle_response, + std::chrono::milliseconds request_timeout, + std::optional request_and_path_build_timeout = std::nullopt, + PathType type = PathType::standard); + + /// API: network/upload_file_to_server + /// + /// Uploads a file to a given server destination. + /// + /// Inputs: + /// - 'data' - [in] the data to be uploaded to a server. + /// - `server` -- [in] the server destination to upload the file to. + /// - `file_name` -- [in, optional] optional name to use for the file. + /// - `request_timeout` -- [in] timeout in milliseconds to use for the request. This won't take + /// the path build into account so if the path build takes forever then this request will never + /// timeout. + /// - `request_and_path_build_timeout` -- [in] timeout in milliseconds to use for the request + /// and path build (if required). This value takes presedence over `request_timeout` if + /// provided, the request itself will be given a timeout of this value subtracting however long + /// it took to build the path. + /// - `handle_response` -- [in] callback to be called with the result of the request. + void upload_file_to_server( + ustring data, + onionreq::ServerDestination server, + std::optional file_name, + network_response_callback_t handle_response, + std::chrono::milliseconds request_timeout, + std::optional request_and_path_build_timeout = std::nullopt); + + /// API: network/download_file + /// + /// Download a file from a given server destination. + /// + /// Inputs: + /// - `server` -- [in] the server destination to download the file from. + /// - `request_timeout` -- [in] timeout in milliseconds to use for the request. This won't take + /// the path build into account so if the path build takes forever then this request will never + /// timeout. + /// - `request_and_path_build_timeout` -- [in] timeout in milliseconds to use for the request + /// and path build (if required). This value takes presedence over `request_timeout` if + /// provided, the request itself will be given a timeout of this value subtracting however long + /// it took to build the path. + /// - `handle_response` -- [in] callback to be called with the result of the request. + void download_file( + onionreq::ServerDestination server, + network_response_callback_t handle_response, + std::chrono::milliseconds request_timeout, + std::optional request_and_path_build_timeout = std::nullopt); + + /// API: network/download_file + /// + /// Convenience function to download a file from a given url and x25519 pubkey combination. + /// Calls through to the above `download_file` function after constructing a server destination + /// from the provided values. + /// + /// Inputs: + /// - `download_url` -- [in] the url to download the file from. + /// - `x25519_pubkey` -- [in] the server destination to download the file from. + /// - `timeout` -- [in] timeout in milliseconds to use for the request. + /// - `request_timeout` -- [in] timeout in milliseconds to use for the request. 
This won't take + /// the path build into account so if the path build takes forever then this request will never + /// timeout. + /// - `request_and_path_build_timeout` -- [in] timeout in milliseconds to use for the request + /// and path build (if required). This value takes presedence over `request_timeout` if + /// provided, the request itself will be given a timeout of this value subtracting however long + /// it took to build the path. + /// - `handle_response` -- [in] callback to be called with the result of the request. + void download_file( + std::string_view download_url, + onionreq::x25519_pubkey x25519_pubkey, + network_response_callback_t handle_response, + std::chrono::milliseconds request_timeout, + std::optional request_and_path_build_timeout = std::nullopt); + + /// API: network/get_client_version + /// + /// Retrieves the version information for the given platform. + /// + /// Inputs: + /// - `platform` -- [in] the platform to retrieve the client version for. + /// - `seckey` -- [in] the users ed25519 secret key (to generated blinded auth). + /// - `request_timeout` -- [in] timeout in milliseconds to use for the request. This won't take + /// the path build into account so if the path build takes forever then this request will never + /// timeout. + /// - `request_and_path_build_timeout` -- [in] timeout in milliseconds to use for the request + /// and path build (if required). This value takes presedence over `request_timeout` if + /// provided, the request itself will be given a timeout of this value subtracting however long + /// it took to build the path. + /// - `handle_response` -- [in] callback to be called with the result of the request. + void get_client_version( + Platform platform, + onionreq::ed25519_seckey seckey, + network_response_callback_t handle_response, + std::chrono::milliseconds request_timeout, + std::optional request_and_path_build_timeout = std::nullopt); + + private: + /// API: network/all_path_ips + /// + /// Internal function to retrieve all of the node ips current used in paths + std::vector all_path_ips() const { + std::vector result; + + for (const auto& [path_type, paths_for_type] : paths) + for (const auto& path : paths_for_type) + for (const auto& node : path.nodes) + result.emplace_back(node.to_ipv4()); + + return result; + }; + + /// API: network/update_disk_cache_throttled + /// + /// Function which can be used to notify the disk write thread that a write can be performed. + /// This function has a very basic throttling mechanism where it triggers the write a small + /// delay after it is called, any subsequent calls to the function within the same period will + /// be ignored. This is done to avoid excessive disk writes which probably aren't needed for + /// the cached network data. + virtual void update_disk_cache_throttled(bool force_immediate_write = false); + + /// API: network/disk_write_thread_loop + /// + /// Body of the disk writer which runs until signalled to stop. This is intended to run in its + /// own thread. The thread monitors a number of private variables and persists the snode pool + /// and swarm caches to disk if a `cache_path` was provided during initialization. + void disk_write_thread_loop(); + + /// API: network/load_cache_from_disk + /// + /// Loads the snode pool and swarm caches from disk if a `cache_path` was provided and cached + /// data exists. 
+    void load_cache_from_disk();
+
+    /// API: network/_close_connections
+    ///
+    /// Triggered via the close_connections function but actually contains the logic to clear out
+    /// paths, requests and connections. This function is not thread safe so it should be called
+    /// with that in mind.
+    void _close_connections();
+
+    /// API: network/update_status
+    ///
+    /// Internal function to update the connection status and trigger the `status_changed` hook if
+    /// provided; this method ignores invalid or unchanged status changes.
+    ///
+    /// Inputs:
+    /// - 'updated_status' - [in] the updated connection status.
+    void update_status(ConnectionStatus updated_status);
+
+    /// API: network/retry_delay
+    ///
+    /// A function which generates an exponential delay to wait before retrying a request/action
+    /// based on the provided failure count.
+    ///
+    /// Inputs:
+    /// - 'num_failures' - [in] the number of times the request has already failed.
+    /// - 'max_delay' - [in] the maximum amount of time to delay for.
+    virtual std::chrono::milliseconds retry_delay(
+            int num_failures,
+            std::chrono::milliseconds max_delay = std::chrono::milliseconds{5000});
+
+    /// API: network/get_endpoint
+    ///
+    /// Retrieves or creates a new endpoint pointer.
+    std::shared_ptr get_endpoint();
+
+    /// API: network/min_snode_cache_size
+    ///
+    /// When talking to testnet it's occasionally possible for the cache size to be smaller than
+    /// the `min_snode_cache_count` value (which would result in an endless loop re-fetching the
+    /// node cache) so instead this function will return the smaller of the two if we've done a
+    /// fetch from a seed node.
+    size_t min_snode_cache_size() const;
+
+    /// API: network/get_unused_nodes
+    ///
+    /// Retrieves a list of all nodes in the cache which are currently unused (ie. not present in an
+    /// existing or pending path, connection or request).
+    ///
+    /// Outputs:
+    /// - The list of unused nodes.
+    std::vector get_unused_nodes();
+
+    /// API: network/establish_connection
+    ///
+    /// Establishes a connection to the target node and triggers the callback once the connection is
+    /// established (or closed in case it fails).
+    ///
+    /// Inputs:
+    /// - 'id' - [in] id for the request or path build which triggered the call.
+    /// - `target` -- [in] the target service node to connect to.
+    /// - `timeout` -- [in, optional] optional timeout for the request, if NULL the
+    /// `quic::DEFAULT_HANDSHAKE_TIMEOUT` will be used.
+    /// - `callback` -- [in] callback to be called with connection info once the connection is
+    /// established or fails.
+    void establish_connection(
+            std::string id,
+            service_node target,
+            std::optional timeout,
+            std::function error)> callback);
+
+    /// API: network/establish_and_store_connection
+    ///
+    /// Establishes a connection to a random unused node and stores it in the `unused_connections`
+    /// list.
+    ///
+    /// Inputs:
+    /// - 'path_id' - [in] id for the path build which triggered the call.
+    virtual void establish_and_store_connection(std::string path_id);
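// Illustrative sketch only: one plausible shape for the exponential backoff that `retry_delay`
// above is documented to produce. The base delay, growth factor and exponent cap are assumptions
// of this example, not the values used by Network.
#include <algorithm>
#include <chrono>
#include <cmath>

inline std::chrono::milliseconds example_retry_delay(
        int num_failures,
        std::chrono::milliseconds max_delay = std::chrono::milliseconds{5000}) {
    using std::chrono::milliseconds;
    const milliseconds base{100};
    // Double the delay for every failure; clamp the exponent to avoid overflow and cap the
    // result at max_delay.
    int exponent = std::clamp(num_failures, 0, 16);
    auto delay = milliseconds{
            static_cast<milliseconds::rep>(base.count() * std::pow(2.0, exponent))};
    return std::min(delay, max_delay);
}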
+    /// API: network/refresh_snode_cache_complete
+    ///
+    /// This function will be called from either `refresh_snode_cache` or
+    /// `refresh_snode_cache_from_seed_nodes` and will actually update the state and persist the
+    /// updated cache to disk.
+    ///
+    /// Inputs:
+    /// - 'nodes' - [in] the nodes to use as the updated cache.
+    void refresh_snode_cache_complete(std::vector nodes);
+
+    /// API: network/refresh_snode_cache_from_seed_nodes
+    ///
+    /// This function refreshes the snode cache from a random seed node. Unlike the
+    /// `refresh_snode_cache` function this will update the cache with the response from a single
+    /// seed node since it's a trusted source.
+    ///
+    /// Inputs:
+    /// - 'request_id' - [in] id for an existing refresh_snode_cache request.
+    /// - 'reset_unused_nodes' - [in] flag to indicate whether this should reset the unused nodes
+    /// before kicking off the request.
+    virtual void refresh_snode_cache_from_seed_nodes(
+            std::string request_id, bool reset_unused_nodes);
+
+    /// API: network/refresh_snode_cache
+    ///
+    /// This function refreshes the snode cache. If the current cache is too small (or not present)
+    /// this will trigger the above `refresh_snode_cache_from_seed_nodes` function, otherwise it
+    /// will randomly pick a number of nodes from the existing cache and refresh the cache from the
+    /// intersection of the results.
+    ///
+    /// Inputs:
+    /// - 'existing_request_id' - [in, optional] id for an existing refresh_snode_cache request.
+    virtual void refresh_snode_cache(std::optional existing_request_id = std::nullopt);
+
+    /// API: network/build_path
+    ///
+    /// Build a new onion request path for the specified type. If there are no existing connections
+    /// this will open a new connection to a random service node in the snode cache.
+    ///
+    /// Inputs:
+    /// - 'path_id' - [in] id for the new path.
+    /// - `path_type` -- [in] the type of path to build.
+    virtual void build_path(std::string path_id, PathType path_type);
+
+    /// API: network/find_valid_path
+    ///
+    /// Find a random path from the provided paths which is valid for the provided request. Note:
+    /// if the Network is set up in `single_path_mode` then the path returned may include the
+    /// destination for the request.
+    ///
+    /// Inputs:
+    /// - `info` -- [in] request to select a path for.
+    /// - `paths` -- [in] paths to select from.
+    ///
+    /// Outputs:
+    /// - The possible path, if found.
+    virtual std::optional find_valid_path(
+            const request_info info, const std::vector paths);
+
+    /// API: network/build_path_if_needed
+    ///
+    /// Triggers a path build for the specified type if the total current or pending paths is below
+    /// the minimum threshold for the given type. Note: This may result in more paths than the
+    /// minimum threshold being built in order to avoid a situation where a request may never get
+    /// sent due to its destination being present in the existing path(s) for the type.
+    ///
+    /// Inputs:
+    /// - `path_type` -- [in] the type of path to be built.
+    /// - `found_valid_path` -- [in] flag indicating whether a valid path was found by calling
+    /// `find_valid_path` above.
+    virtual void build_path_if_needed(PathType path_type, bool found_valid_path);
+
+    /// API: network/get_service_nodes
+    ///
+    /// Retrieves all or a random subset of service nodes from the given node.
+    ///
+    /// Inputs:
+    /// - 'request_id' - [in] id for the request which triggered the call.
+    /// - `conn_info` -- [in] the connection info to retrieve service nodes from.
+    /// - `limit` -- [in, optional] the number of service nodes to retrieve.
+    /// - `callback` -- [in] callback to be triggered once we receive nodes. NOTE: If an error
+    /// occurs an empty list and an error will be provided.
+ void get_service_nodes( + std::string request_id, + connection_info conn_info, + std::optional limit, + std::function nodes, std::optional error)> + callback); + + /// API: network/check_request_queue_timeouts + /// + /// Checks if any of the requests in the request queue have timed out (and fails them if so). + /// + /// Inputs: + /// - 'request_timeout_id' - [in] id for the timeout loop to prevent multiple loops from being + /// scheduled. + virtual void check_request_queue_timeouts( + std::optional request_timeout_id = std::nullopt); + + /// API: network/send_request + /// + /// Send a request via the network. + /// + /// Inputs: + /// - `info` -- [in] wrapper around all of the information required to send a request. + /// - `conn` -- [in] connection information used to send the request. + /// - `handle_response` -- [in] callback to be called with the result of the request. + void send_request( + request_info info, connection_info conn, network_response_callback_t handle_response); + + /// API: network/_send_onion_request + /// + /// Internal function invoked by ::send_onion_request after request_info construction + virtual void _send_onion_request( + request_info info, network_response_callback_t handle_response); + + /// API: network/process_v3_onion_response + /// + /// Processes a v3 onion request response. + /// + /// Inputs: + /// - `builder` -- [in] the builder that was used to build the onion request. + /// - `response` -- [in] the response data returned from the destination. + /// + /// Outputs: + /// - A tuple containing the status code, headers and body of the decrypted onion request + /// response. + std::tuple< + int16_t, + std::vector>, + std::optional> + process_v3_onion_response(session::onionreq::Builder builder, std::string response); + + /// API: network/process_v4_onion_response + /// + /// Processes a v4 onion request response. + /// + /// Inputs: + /// - `builder` -- [in] the builder that was used to build the onion request. + /// - `response` -- [in] the response data returned from the destination. + /// + /// Outputs: + /// - A tuple containing the status code, headers and body of the decrypted onion request + /// response. + std::tuple< + int16_t, + std::vector>, + std::optional> + process_v4_onion_response(session::onionreq::Builder builder, std::string response); + + /// API: network/validate_response + /// + /// Processes a quic response to extract the status code and body or throw if it errored or + /// received a non-successful status code. + /// + /// Inputs: + /// - `resp` -- [in] the quic response. + /// - `is_bencoded` -- [in] flag indicating whether the response will be bencoded or JSON. + /// + /// Returns: + /// - `std::pair` -- the status code and response body (for a bencoded + /// response this is just the direct response body from quic as it simplifies consuming the + /// response elsewhere). + std::pair validate_response(oxen::quic::message resp, bool is_bencoded); + + /// API: network/drop_path_when_empty + /// + /// Flags a path to be dropped once all pending requests have finished. + /// + /// Inputs: + /// - `id` -- [in] id the request or path which triggered the path drop (if the id is a path_id + /// then the drop was triggered by the connection being dropped). + /// - `path_type` -- [in] the type of path to build. + /// - `path` -- [in] the path to be dropped. 
+ void drop_path_when_empty(std::string id, PathType path_type, onion_path path); + + /// API: network/clear_empty_pending_path_drops + /// + /// Iterates through all paths flagged to be dropped and actually drops any which are no longer + /// valid or have no more pending requests. + void clear_empty_pending_path_drops(); + + /// API: network/handle_errors + /// + /// Processes a non-success response to automatically perform any standard operations based on + /// the errors returned from the service node network (ie. updating the service node cache, + /// dropping nodes and/or onion request paths). + /// + /// Inputs: + /// - `info` -- [in] the information for the request that was made. + /// - `conn_info` -- [in] the connection info for the request that failed. + /// - `timeout` -- [in, optional] flag indicating whether the request timed out. + /// - `status_code` -- [in] the status code returned from the network. + /// - `headers` -- [in] the response headers returned from the network. + /// - `response` -- [in, optional] response data returned from the network. + /// - `handle_response` -- [in, optional] callback to be called with updated response + /// information after processing the error. + virtual void handle_errors( + request_info info, + connection_info conn_info, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response, + std::optional handle_response); +}; + +} // namespace session::network diff --git a/include/session/onionreq/builder.h b/include/session/onionreq/builder.h index 3efd7c5a..4320276e 100644 --- a/include/session/onionreq/builder.h +++ b/include/session/onionreq/builder.h @@ -6,6 +6,7 @@ extern "C" { #include #include +#include #include "../export.h" @@ -44,14 +45,6 @@ LIBSESSION_EXPORT void onion_request_builder_free(onion_request_builder_object* /// /// Wrapper around session::onionreq::Builder::onion_request_builder_set_enc_type. /// -/// Declaration: -/// ```cpp -/// void onion_request_builder_set_enc_type( -/// [in] onion_request_builder_object* builder -/// [in] ENCRYPT_TYPE enc_type -/// ); -/// ``` -/// /// Inputs: /// - `builder` -- [in] Pointer to the builder object /// - `enc_type` -- [in] The encryption type to use in the onion request @@ -63,56 +56,50 @@ LIBSESSION_EXPORT void onion_request_builder_set_enc_type( /// Wrapper around session::onionreq::Builder::set_snode_destination. ed25519_pubkey and /// x25519_pubkey are both hex strings and must both be exactly 64 characters. /// -/// Declaration: -/// ```cpp -/// void onion_request_builder_set_snode_destination( -/// [in] onion_request_builder_object* builder -/// [in] const char* ed25519_pubkey, -/// [in] const char* x25519_pubkey -/// ); -/// ``` -/// /// Inputs: /// - `builder` -- [in] Pointer to the builder object +/// - `ip` -- [in] The IP address for the snode destination +/// - `quic_port` -- [in] The Quic port request for the snode destination /// - `ed25519_pubkey` -- [in] The ed25519 public key for the snode destination -/// - `x25519_pubkey` -- [in] The x25519 public key for the snode destination LIBSESSION_EXPORT void onion_request_builder_set_snode_destination( onion_request_builder_object* builder, - const char* ed25519_pubkey, - const char* x25519_pubkey); + const uint8_t ip[4], + const uint16_t quic_port, + const char* ed25519_pubkey); /// API: onion_request_builder_set_server_destination /// /// Wrapper around session::onionreq::Builder::set_server_destination. x25519_pubkey /// is a hex string and must both be exactly 64 characters. 
/// -/// Declaration: -/// ```cpp -/// void onion_request_builder_set_server_destination( -/// [in] onion_request_builder_object* builder -/// [in] const char* host, -/// [in] const char* target, -/// [in] const char* protocol, -/// [in] uint16_t port, -/// [in] const char* x25519_pubkey -/// ); -/// ``` -/// /// Inputs: /// - `builder` -- [in] Pointer to the builder object -/// - `host` -- [in] The host for the server destination -/// - `target` -- [in] The target (endpoint) for the server destination -/// - `protocol` -- [in] The protocol to use for the -/// - `port` -- [in] The host for the server destination -/// - `x25519_pubkey` -- [in] The x25519 public key for the snode destination +/// - `protocol` -- [in] The protocol to use +/// - `host` -- [in] The server host +/// - `endpoint` -- [in] The endpoint to call +/// - `method` -- [in] The HTTP method to use +/// - `port` -- [in] The port to use +/// - `x25519_pubkey` -- [in] The x25519 public key for server LIBSESSION_EXPORT void onion_request_builder_set_server_destination( onion_request_builder_object* builder, - const char* host, - const char* target, const char* protocol, + const char* host, + const char* endpoint, + const char* method, uint16_t port, const char* x25519_pubkey); +/// API: onion_request_builder_set_destination_pubkey +/// +/// Wrapper around session::onionreq::Builder::set_destination_pubkey. +/// +/// Inputs: +/// - `builder` -- [in] Pointer to the builder object +/// - `x25519_pubkey` -- [in] The x25519 public key for server (Hex string of exactly 64 +/// characters). +LIBSESSION_EXPORT void onion_request_builder_set_destination_pubkey( + onion_request_builder_object* builder, const char* x25519_pubkey); + /// API: onion_request_builder_add_hop /// /// Wrapper around session::onionreq::Builder::add_hop. ed25519_pubkey and diff --git a/include/session/onionreq/builder.hpp b/include/session/onionreq/builder.hpp index 0ebad565..2065ff37 100644 --- a/include/session/onionreq/builder.hpp +++ b/include/session/onionreq/builder.hpp @@ -1,12 +1,52 @@ #pragma once +#include #include #include +#include #include "key_types.hpp" +namespace session::network { +struct service_node; +struct request_info; +} // namespace session::network + namespace session::onionreq { +struct ServerDestination { + std::string protocol; + std::string host; + std::string endpoint; + session::onionreq::x25519_pubkey x25519_pubkey; + std::optional port; + std::optional>> headers; + std::string method; + + ServerDestination( + std::string protocol, + std::string host, + std::string endpoint, + session::onionreq::x25519_pubkey x25519_pubkey, + std::optional port = std::nullopt, + std::optional>> headers = std::nullopt, + std::string method = "GET") : + protocol{std::move(protocol)}, + host{std::move(host)}, + endpoint{std::move(endpoint)}, + x25519_pubkey{std::move(x25519_pubkey)}, + port{std::move(port)}, + headers{std::move(headers)}, + method{std::move(method)} {} +}; + +using network_destination = std::variant; + +namespace detail { + + session::onionreq::x25519_pubkey pubkey_for_destination(network_destination destination); +} + enum class EncryptType { aes_gcm, xchacha20, @@ -26,7 +66,16 @@ inline constexpr std::string_view to_string(EncryptType type) { // Builder class for preparing onion request payloads. 
class Builder { + Builder(const network_destination& destination, + const std::vector& nodes, + const EncryptType enc_type_); + public: + static Builder make( + const network_destination& destination, + const std::vector& nodes, + const EncryptType enc_type_ = EncryptType::xchacha20); + EncryptType enc_type; std::optional destination_x25519_public_key = std::nullopt; std::optional final_hop_x25519_keypair = std::nullopt; @@ -35,33 +84,12 @@ class Builder { void set_enc_type(EncryptType enc_type_) { enc_type = enc_type_; } - void set_snode_destination(ed25519_pubkey ed25519_public_key, x25519_pubkey x25519_public_key) { - destination_x25519_public_key.reset(); - ed25519_public_key_.reset(); - destination_x25519_public_key.emplace(x25519_public_key); - ed25519_public_key_.emplace(ed25519_public_key); - } - - void set_server_destination( - std::string host, - std::string target, - std::string protocol, - std::optional port, - x25519_pubkey x25519_public_key) { - destination_x25519_public_key.reset(); - - host_.emplace(host); - target_.emplace(target); - protocol_.emplace(protocol); - - if (port) - port_.emplace(*port); - - destination_x25519_public_key.emplace(x25519_public_key); - } - + void set_destination(network_destination destination); + void set_destination_pubkey(session::onionreq::x25519_pubkey x25519_pubkey); + void add_hop(ustring_view remote_key); void add_hop(std::pair keys) { hops_.push_back(keys); } + void generate(network::request_info& info); ustring build(ustring payload); private: @@ -74,9 +102,14 @@ class Builder { // Proxied request values std::optional host_ = std::nullopt; - std::optional target_ = std::nullopt; + std::optional endpoint_ = std::nullopt; std::optional protocol_ = std::nullopt; + std::optional method_ = std::nullopt; std::optional port_ = std::nullopt; + std::optional>> headers_ = std::nullopt; + std::optional>> query_params_ = std::nullopt; + + ustring _generate_payload(std::optional body) const; }; } // namespace session::onionreq diff --git a/include/session/onionreq/hop_encryption.hpp b/include/session/onionreq/hop_encryption.hpp index fcb18136..ae57c6b5 100644 --- a/include/session/onionreq/hop_encryption.hpp +++ b/include/session/onionreq/hop_encryption.hpp @@ -16,6 +16,9 @@ class HopEncryption { public_key_{std::move(public_key)}, server_{server} {} + // Returns true if the response is long enough to be a valid response. + static bool response_long_enough(EncryptType type, size_t response_size); + // Encrypts `plaintext` message using encryption `type`. `pubkey` is the recipients public key. // `reply` should be false for a client-to-snode message, and true on a returning // snode-to-client message. 
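// Illustrative sketch only: constructing a ServerDestination (defined in onionreq/builder.hpp
// above) for a simple GET to a file server. The host, endpoint format and the all-zero pubkey
// are placeholders for this example; parse_x25519_pubkey comes from onionreq/key_types.hpp.
#include <optional>

#include "session/onionreq/builder.hpp"
#include "session/onionreq/key_types.hpp"

session::onionreq::ServerDestination make_example_destination() {
    auto pk = session::onionreq::parse_x25519_pubkey(
            "0000000000000000000000000000000000000000000000000000000000000000");
    return session::onionreq::ServerDestination{
            "http",            // protocol
            "filev2.example",  // host (placeholder)
            "/file/1234",      // endpoint (placeholder)
            pk,                // server x25519 pubkey
            std::nullopt,      // port: default for the protocol
            std::nullopt,      // headers
            "GET"};            // method
}
// The resulting destination can then be passed to Network::send_onion_request or download_file.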
diff --git a/include/session/onionreq/key_types.hpp b/include/session/onionreq/key_types.hpp index 0f776da2..5621364b 100644 --- a/include/session/onionreq/key_types.hpp +++ b/include/session/onionreq/key_types.hpp @@ -92,6 +92,7 @@ using x25519_keypair = std::pair; legacy_pubkey parse_legacy_pubkey(std::string_view pubkey_in); ed25519_pubkey parse_ed25519_pubkey(std::string_view pubkey_in); x25519_pubkey parse_x25519_pubkey(std::string_view pubkey_in); +x25519_pubkey compute_x25519_pubkey(ustring_view ed25519_pk); } // namespace session::onionreq diff --git a/include/session/onionreq/parser.hpp b/include/session/onionreq/parser.hpp index 8d2d290e..d1e7208b 100644 --- a/include/session/onionreq/parser.hpp +++ b/include/session/onionreq/parser.hpp @@ -1,7 +1,8 @@ -#include +#pragma once -#include "session/onionreq/hop_encryption.hpp" -#include "session/types.hpp" +#include + +#include "hop_encryption.hpp" namespace session::onionreq { diff --git a/include/session/onionreq/response_parser.hpp b/include/session/onionreq/response_parser.hpp index 9f0c764d..ade15e49 100644 --- a/include/session/onionreq/response_parser.hpp +++ b/include/session/onionreq/response_parser.hpp @@ -7,6 +7,9 @@ namespace session::onionreq { +constexpr auto decryption_failed_error = + "Decryption failed (both XChaCha20-Poly1305 and AES256-GCM)"sv; + class ResponseParser { public: /// Constructs a parser, parsing the given request sent to us. Throws if parsing or decryption @@ -20,6 +23,8 @@ class ResponseParser { x25519_keypair_{std::move(x25519_keypair)}, enc_type_{enc_type} {} + static bool response_long_enough(EncryptType enc_type, size_t response_size); + ustring decrypt(ustring ciphertext) const; private: diff --git a/include/session/random.hpp b/include/session/random.hpp index 72fc8b89..4d07139b 100644 --- a/include/session/random.hpp +++ b/include/session/random.hpp @@ -1,6 +1,30 @@ #pragma once -#include "types.hpp" +#include + +#include + +#include "util.hpp" + +namespace session { +/// rng type that uses llarp::randint(), which is cryptographically secure +struct CSRNG { + using result_type = uint64_t; + + static constexpr uint64_t min() { return std::numeric_limits::min(); }; + + static constexpr uint64_t max() { return std::numeric_limits::max(); }; + + uint64_t operator()() { + uint64_t i; + randombytes((uint8_t*)&i, sizeof(i)); + return i; + }; +}; + +extern CSRNG csrng; + +} // namespace session namespace session::random { @@ -15,4 +39,15 @@ namespace session::random { /// - random bytes of the specified length. ustring random(size_t size); +/// API: random/random_base32 +/// +/// Return a random base32 string with the given length. +/// +/// Inputs: +/// - `size` -- the number of characters to be generated. +/// +/// Outputs: +/// - random base32 string of the specified length. +std::string random_base32(size_t size); + } // namespace session::random diff --git a/include/session/session_encrypt.h b/include/session/session_encrypt.h index b91628e3..56b5f90f 100644 --- a/include/session/session_encrypt.h +++ b/include/session/session_encrypt.h @@ -5,6 +5,7 @@ extern "C" { #endif #include +#include #include "export.h" @@ -170,12 +171,11 @@ LIBSESSION_EXPORT bool session_decrypt_for_blinded_recipient( /// This function attempts to decrypt an ONS response. /// /// Inputs: -/// - `lowercase_name_in` -- [in] Pointer to a buffer containing the lowercase name used to trigger -/// the response. -/// - `name_len` -- [in] Length of `name_in`. 
+/// - `lowercase_name_in` -- [in] Pointer to a NULL-terminated buffer containing the lowercase name +/// used to trigger the response. /// - `ciphertext_in` -- [in] Pointer to a data buffer containing the encrypted data. /// - `ciphertext_len` -- [in] Length of `ciphertext_in`. -/// - `nonce_in` -- [in] Pointer to a data buffer containing the nonce (24 bytes). +/// - `nonce_in` -- [in, optional] Pointer to a data buffer containing the nonce (24 bytes) or NULL. /// - `session_id_out` -- [out] pointer to a buffer of at least 67 bytes where the null-terminated, /// hex-encoded session_id will be written if decryption was successful. /// @@ -183,10 +183,9 @@ LIBSESSION_EXPORT bool session_decrypt_for_blinded_recipient( /// - `bool` -- True if the session ID was successfully decrypted, false if decryption failed. LIBSESSION_EXPORT bool session_decrypt_ons_response( const char* lowercase_name_in, - size_t name_len, const unsigned char* ciphertext_in, size_t ciphertext_len, - const unsigned char* nonce_in, /* 24 bytes */ + const unsigned char* nonce_in, /* 24 bytes or NULL */ char* session_id_out /* 67 byte output buffer */); /// API: crypto/session_decrypt_push_notification @@ -214,6 +213,22 @@ LIBSESSION_EXPORT bool session_decrypt_push_notification( unsigned char** plaintext_out, size_t* plaintext_len); +/// API: crypto/compute_message_hash +/// +/// Computes the hash for a message. +/// +/// Inputs: +/// - `pubkey_hex_in` -- the pubkey as a 67 character hex string that the message will be stored in. +/// NULL terminated. +/// - `ns` -- the namespace that the message will be stored in. +/// - `data` -- the base64 encoded message data that will be stored for the message. NULL +/// terminated. +/// +/// Outputs: +/// - `std::string` -- a deterministic hash for the message. +LIBSESSION_EXPORT bool session_compute_message_hash( + const char* pubkey_hex_in, int16_t ns, const char* base64_data_in, char* hash_out); + #ifdef __cplusplus } #endif diff --git a/include/session/session_encrypt.hpp b/include/session/session_encrypt.hpp index 6a185a73..8b3e6c9d 100644 --- a/include/session/session_encrypt.hpp +++ b/include/session/session_encrypt.hpp @@ -1,5 +1,7 @@ #pragma once +#include + #include "types.hpp" // Helper functions for the "Session Protocol" encryption mechanism. This is the encryption used @@ -235,13 +237,15 @@ std::pair decrypt_from_blinded_recipient( /// Inputs: /// - `lowercase_name` -- the lowercase name which was looked to up to retrieve this response. /// - `ciphertext` -- ciphertext returned from the server. -/// - `nonce` -- the nonce returned from the server +/// - `nonce` -- the nonce returned from the server if provided. /// /// Outputs: /// - `std::string` -- the session ID (in hex) returned from the server, *if* the server returned /// a session ID. Throws on error/failure. std::string decrypt_ons_response( - std::string_view lowercase_name, ustring_view ciphertext, ustring_view nonce); + std::string_view lowercase_name, + ustring_view ciphertext, + std::optional nonce); /// API: crypto/decrypt_push_notification /// @@ -257,4 +261,18 @@ std::string decrypt_ons_response( /// successful. Throws on error/failure. ustring decrypt_push_notification(ustring_view payload, ustring_view enc_key); +/// API: crypto/compute_message_hash +/// +/// Computes the hash for a message. +/// +/// Inputs: +/// - `pubkey_hex` -- the pubkey as a 66 character hex string that the message will be stored in. +/// - `ns` -- the namespace that the message will be stored in. 
+/// - `data` -- the base64 encoded message data that will be stored for the message. +/// +/// Outputs: +/// - `std::string` -- a deterministic hash for the message. +std::string compute_message_hash( + const std::string_view pubkey_hex, int16_t ns, std::string_view data); + } // namespace session diff --git a/include/session/util.hpp b/include/session/util.hpp index 22513ff9..b6825236 100644 --- a/include/session/util.hpp +++ b/include/session/util.hpp @@ -1,5 +1,7 @@ #pragma once +#include + #include #include #include @@ -7,6 +9,7 @@ #include #include #include +#include #include #include @@ -40,6 +43,18 @@ inline const char* from_unsigned(const unsigned char* x) { inline char* from_unsigned(unsigned char* x) { return reinterpret_cast(x); } +// Helper to switch from basic_string_view to basic_string_view. Both CFrom and CTo +// must be primitive, one-byte types. +template +inline std::basic_string_view convert_sv(std::basic_string_view from) { + return {reinterpret_cast(from.data()), from.size()}; +} +// Same as above, but with a const basic_string& argument (to allow deduction of CFrom when +// using a basic_string). +template +inline std::basic_string_view convert_sv(const std::basic_string& from) { + return {reinterpret_cast(from.data()), from.size()}; +} // Helper function to switch between basic_string_view and ustring_view inline ustring_view to_unsigned_sv(std::string_view v) { return {to_unsigned(v.data()), v.size()}; @@ -78,38 +93,10 @@ inline bool string_iequal(std::string_view s1, std::string_view s2) { }); } -// C++20 starts_/ends_with backport -inline constexpr bool starts_with(std::string_view str, std::string_view prefix) { - return str.size() >= prefix.size() && str.substr(prefix.size()) == prefix; -} - -inline constexpr bool end_with(std::string_view str, std::string_view suffix) { - return str.size() >= suffix.size() && str.substr(str.size() - suffix.size()) == suffix; -} - using uc32 = std::array; using uc33 = std::array; using uc64 = std::array; -template -using string_view_char_type = std::conditional_t< - std::is_convertible_v, - char, - std::conditional_t< - std::is_convertible_v>, - unsigned char, - std::conditional_t< - std::is_convertible_v>, - std::byte, - void>>>; - -template -constexpr bool is_char_array = false; -template -inline constexpr bool is_char_array> = - std::is_same_v || std::is_same_v || - std::is_same_v; - /// Takes a container of string-like binary values and returns a vector of ustring_views viewing /// those values. This can be used on a container of any type with a `.data()` and a `.size()` /// where `.data()` is a one-byte value pointer; std::string, std::string_view, ustring, @@ -146,6 +133,29 @@ std::vector to_view_vector(const Container& c) { return to_view_vector(c.begin(), c.end()); } +/// Splits a string on some delimiter string and returns a vector of string_view's pointing into the +/// pieces of the original string. The pieces are valid only as long as the original string remains +/// valid. Leading and trailing empty substrings are not removed. If delim is empty you get back a +/// vector of string_views each viewing one character. If `trim` is true then leading and trailing +/// empty values will be suppressed. 
+/// +/// auto v = split("ab--c----de", "--"); // v is {"ab", "c", "", "de"} +/// auto v = split("abc", ""); // v is {"a", "b", "c"} +/// auto v = split("abc", "c"); // v is {"ab", ""} +/// auto v = split("abc", "c", true); // v is {"ab"} +/// auto v = split("-a--b--", "-"); // v is {"", "a", "", "b", "", ""} +/// auto v = split("-a--b--", "-", true); // v is {"a", "", "b"} +/// +std::vector split( + std::string_view str, std::string_view delim, bool trim = false); + +/// Returns protocol, host, port, path. Port can be empty; throws on unparseable values. protocol +/// and host get normalized to lower-case. Port will be null if not present in the URL, or if set +/// to the default for the protocol. Path can be empty (a single optional `/` after the domain will +/// be ignored). +std::tuple, std::optional> parse_url( + std::string_view url); + /// Truncates a utf-8 encoded string to at most `n` bytes long, but with care as to not truncate in /// the middle of a unicode codepoint. If the `n` length would shorten the string such that it /// terminates in the middle of a utf-8 encoded unicode codepoint then the string is shortened diff --git a/proto/CMakeLists.txt b/proto/CMakeLists.txt index e7a05a5c..324b116b 100644 --- a/proto/CMakeLists.txt +++ b/proto/CMakeLists.txt @@ -12,6 +12,14 @@ set_target_properties( OUTPUT_NAME session-protos SOVERSION ${LIBSESSION_LIBVERSION}) +# -Wunused-parameter triggers in the protobuf dependency +message(STATUS "Disabling -Werror for proto unused-parameter") + +if (MSVC) +else() + target_compile_options(protos PUBLIC -Wno-error=unused-parameter) +endif() + libsession_static_bundle(protos) add_library(libsession::protos ALIAS protos) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4dbc11a8..f5f2c285 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,7 +1,8 @@ add_library(common INTERFACE) target_link_libraries(common INTERFACE - oxenc::oxenc) + oxenc::oxenc + oxen::logging) target_include_directories(common INTERFACE ../include) if(WARNINGS_AS_ERRORS) @@ -34,16 +35,12 @@ macro(add_libsession_util_library name) endmacro() -if(NOT BUILD_STATIC_DEPS) - find_package(PkgConfig REQUIRED) - - if(NOT TARGET nettle) - pkg_check_modules(NETTLE nettle IMPORTED_TARGET REQUIRED) - add_library(nettle INTERFACE IMPORTED) - target_link_libraries(nettle INTERFACE PkgConfig::NETTLE) - endif() -endif() +add_libsession_util_library(util + file.cpp + logging.cpp + util.cpp +) add_libsession_util_library(crypto blinding.cpp @@ -54,7 +51,6 @@ add_libsession_util_library(crypto random.cpp session_encrypt.cpp sodium_array.cpp - util.cpp xed25519.cpp ) @@ -79,9 +75,15 @@ add_libsession_util_library(config -target_link_libraries(crypto +target_link_libraries(util PUBLIC common + oxen::logging +) + +target_link_libraries(crypto + PUBLIC + util PRIVATE libsodium::sodium-internal ) @@ -89,7 +91,6 @@ target_link_libraries(crypto target_link_libraries(config PUBLIC crypto - common libsession::protos PRIVATE libsodium::sodium-internal @@ -103,20 +104,20 @@ if(ENABLE_ONIONREQ) onionreq/key_types.cpp onionreq/parser.cpp onionreq/response_parser.cpp + network.cpp ) target_link_libraries(onionreq PUBLIC crypto - common + quic PRIVATE nlohmann_json::nlohmann_json libsodium::sodium-internal - nettle + nettle::nettle ) endif() - if(WARNINGS_AS_ERRORS AND NOT USE_LTO AND CMAKE_C_COMPILER_ID STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION MATCHES "^11\\.") # GCC 11 has an overzealous (and false) stringop-overread warning, but only when LTO is off. # Um, yeah. 
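// Illustrative sketch only: using the parse_url helper declared in session/util.hpp above
// (assumed to live in the `session` namespace alongside the other helpers in that header).
// Structured bindings avoid spelling out the exact tuple element types.
#include <iostream>

#include "session/util.hpp"

void parse_url_example() {
    // Protocol and host come back lower-cased; the port is null when it is the default for the
    // protocol; a lone trailing "/" does not count as a path.
    auto [protocol, host, port, path] = session::parse_url("HTTPS://Example.org:8080/file/1234");
    (void)path;  // "/file/1234" here; its exact representation is not shown above

    std::cout << protocol << "://" << host;  // prints "https://example.org"
    if (port)
        std::cout << ':' << *port;           // prints ":8080"
    std::cout << '\n';
}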
diff --git a/src/config.cpp b/src/config.cpp index 8073256b..2972d9fb 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -92,7 +92,6 @@ namespace { auto oldit = old.begin(), newit = new_.begin(); while (oldit != old.end() || newit != new_.end()) { - bool is_new = false; if (oldit == old.end() || (newit != new_.end() && newit->first < oldit->first)) { // newit is a new item; fall through to handle below @@ -117,7 +116,6 @@ namespace { if (o.index() != n.index()) { // The fundamental type (scalar, dict, set) changed, so we'll treat this as a // new value (which implicitly deletes a value of a wrong type when merging). - is_new = true; ++oldit; // fall through to handler below @@ -431,11 +429,9 @@ namespace { void verify_config_sig( oxenc::bt_dict_consumer dict, - ustring_view config_msg, const ConfigMessage::verify_callable& verifier, std::optional>* verified_signature, bool trust_signature) { - ustring_view to_verify, sig; if (dict.skip_until("~")) { dict.consume_signature([&](ustring_view to_verify, ustring_view sig) { if (sig.size() != 64) @@ -558,7 +554,7 @@ ConfigMessage::ConfigMessage( load_unknowns(unknown_, dict, "=", "~"); - verify_config_sig(dict, serialized, verifier, &verified_signature_, trust_signature); + verify_config_sig(dict, verifier, &verified_signature_, trust_signature); } catch (const oxenc::bt_deserialize_invalid& err) { throw config_parse_error{"Failed to parse config file: "s + err.what()}; } @@ -592,12 +588,12 @@ ConfigMessage::ConfigMessage( // prune out redundant messages (i.e. messages already included in another message's diff, and // duplicates) - for (int i = 0; i < configs.size(); i++) { + for (size_t i = 0; i < configs.size(); i++) { auto& [conf, redundant] = configs[i]; if (conf.seqno() > max_seqno) max_seqno = conf.seqno(); - for (int j = 0; !redundant && j < configs.size(); j++) { + for (size_t j = 0; !redundant && j < configs.size(); j++) { if (i == j) continue; const auto& conf2 = configs[j].first; @@ -619,7 +615,7 @@ ConfigMessage::ConfigMessage( if (curr_confs == 1) { // We have just one non-redundant config left after all that, so we become it directly as-is - for (int i = 0; i < configs.size(); i++) { + for (size_t i = 0; i < configs.size(); i++) { if (!configs[i].second) { *this = std::move(configs[i].first); unmerged_ = i; @@ -647,7 +643,7 @@ ConfigMessage::ConfigMessage( return; } - unmerged_ = -1; + unmerged_ = std::nullopt; // Clear any redundant messages. (we do it *here* rather than above because, in the // single-good-config case, above, we need the index of the good config for `unmerged_`). diff --git a/src/config/base.cpp b/src/config/base.cpp index 3a1a7c7c..de30d1e9 100644 --- a/src/config/base.cpp +++ b/src/config/base.cpp @@ -8,6 +8,8 @@ #include #include +#include +#include #include #include #include @@ -20,9 +22,14 @@ #include "session/util.hpp" using namespace std::literals; +using namespace oxen::log::literals; namespace session::config { +namespace log = oxen::log; + +auto cat = log::Cat("config"); + void ConfigBase::set_state(ConfigState s) { if (s == ConfigState::Dirty && is_readonly()) throw std::runtime_error{"Unable to make changes to a read-only config object"}; @@ -86,6 +93,10 @@ std::vector ConfigBase::merge( ustring_view{_keys.front().data(), _keys.front().size()}, unwrapped, storage_namespace()); + log::warning( + cat, + "Found double wraped message in namespace {}", + static_cast(storage_namespace())); parsed.emplace_back(h, keep_alive.emplace_back(std::move(unwrapped2))); } catch (...) 
{ parsed.emplace_back(h, keep_alive.emplace_back(std::move(unwrapped))); @@ -146,17 +157,21 @@ std::vector ConfigBase::_merge( plaintexts.emplace_back(hash, decrypt(conf, key(i), encryption_domain())); decrypted = true; } catch (const decrypt_error&) { - log(LogLevel::debug, - "Failed to decrypt message " + std::to_string(ci) + " using key " + - std::to_string(i)); + log::debug(cat, "Failed to decrypt message {} using key {}", ci, i); } } if (!decrypted) - log(LogLevel::warning, "Failed to decrypt message " + std::to_string(ci)); + log::warning( + cat, + "Failed to decrypt message {} for namespace {}", + ci, + static_cast(storage_namespace())); } - log(LogLevel::debug, - "successfully decrypted " + std::to_string(plaintexts.size()) + " of " + - std::to_string(configs.size()) + " incoming messages"); + log::debug( + cat, + "successfully decrypted {} of {} incoming messages", + plaintexts.size(), + configs.size()); for (auto& [hash, plain] : plaintexts) { // Remove prefix padding: @@ -165,13 +180,13 @@ std::vector ConfigBase::_merge( plain.resize(plain.size() - p); } if (plain.empty()) { - log(LogLevel::error, "Invalid config message: contains no data"); + log::error(cat, "Invalid config message: contains no data"); continue; } // TODO FIXME (see above) if (plain[0] == 'm') { - log(LogLevel::warning, "multi-part messages not yet supported!"); + log::warning(cat, "multi-part messages not yet supported!"); continue; } @@ -182,17 +197,18 @@ std::vector ConfigBase::_merge( decompressed && !decompressed->empty()) plain = std::move(*decompressed); else { - log(LogLevel::warning, "Invalid config message: decompression failed"); + log::warning(cat, "Invalid config message: decompression failed"); continue; } } if (plain[0] != 'd') - log(LogLevel::error, - "invalid/unsupported config message with type " + - (plain[0] >= 0x20 && plain[0] <= 0x7e - ? "'" + std::string{from_unsigned_sv(plain.substr(0, 1))} + "'" - : "0x" + oxenc::to_hex(plain.begin(), plain.begin() + 1))); + log::error( + cat, + "invalid/unsupported config message with type {}", + (plain[0] >= 0x20 && plain[0] <= 0x7e + ? "'{}'"_format(static_cast(plain[0])) + : "0x{:02x}"_format(plain[0]))); all_hashes.emplace_back(hash); all_confs.emplace_back(plain); @@ -207,7 +223,7 @@ std::vector ConfigBase::_merge( _config->signer, config_lags(), [&](size_t i, const config_error& e) { - log(LogLevel::warning, e.what()); + log::warning(cat, "{}", e.what()); assert(i > 0); // i == 0 means we can't deserialize our own serialization bad_confs.insert(i); }); @@ -217,9 +233,13 @@ std::vector ConfigBase::_merge( // might be our current config, or might be one single one of the new incoming messages). // - confs that failed to parse (we can't understand them, so leave them behind as they may be // some future message). - int superconf = new_conf->unmerged_index(); // -1 if we had to merge - for (int i = 0; i < all_hashes.size(); i++) { - if (i != superconf && !bad_confs.count(i) && !all_hashes[i].empty()) + std::optional superconf = new_conf->unmerged_index(); // nullopt if we had to merge + std::string_view superconf_hash = + superconf && *superconf < all_hashes.size() ? 
all_hashes[*superconf] : ""; + + for (size_t i = 0; i < all_hashes.size(); i++) { + if (i != superconf && !bad_confs.count(i) && !all_hashes[i].empty() && + superconf_hash != all_hashes[i]) _old_hashes.emplace(all_hashes[i]); } @@ -248,12 +268,8 @@ std::vector ConfigBase::_merge( assert(((old_seqno == 0 && mine.empty()) || _config->unmerged_index() >= 1) && _config->unmerged_index() < all_hashes.size()); set_state(ConfigState::Clean); - _curr_hash = all_hashes[_config->unmerged_index()]; + _curr_hash = all_hashes[*_config->unmerged_index()]; } - } else { - // the merging affect nothing (if it had seqno would have been incremented), so don't - // pointlessly replace the inner config object. - assert(new_conf->unmerged_index() == 0); } std::vector good_hashes; @@ -272,6 +288,17 @@ std::vector ConfigBase::current_hashes() const { return hashes; } +std::vector ConfigBase::old_hashes() { + std::vector hashes; + if (!is_dirty()) { + for (auto& old : _old_hashes) + hashes.push_back(std::move(old)); + _old_hashes.clear(); + } + + return hashes; +} + bool ConfigBase::needs_push() const { return !is_clean(); } @@ -352,7 +379,7 @@ ustring ConfigBase::make_dump() const { d.append("$", data_sv); d.append("(", _curr_hash); - d.append_list(")").append(_old_hashes.begin(), _old_hashes.end()); + d.append_list(")").extend(_old_hashes.begin(), _old_hashes.end()); extra_data(d.append_dict("+")); @@ -647,13 +674,15 @@ LIBSESSION_EXPORT config_string_list* config_merge( const unsigned char** configs, const size_t* lengths, size_t count) { - auto& config = *unbox(conf); - std::vector> confs; - confs.reserve(count); - for (size_t i = 0; i < count; i++) - confs.emplace_back(msg_hashes[i], ustring_view{configs[i], lengths[i]}); - - return make_string_list(config.merge(confs)); + return wrap_exceptions(conf, [&] { + auto& config = *unbox(conf); + std::vector> confs; + confs.reserve(count); + for (size_t i = 0; i < count; i++) + confs.emplace_back(msg_hashes[i], ustring_view{configs[i], lengths[i]}); + + return make_string_list(config.merge(confs)); + }); } LIBSESSION_EXPORT bool config_needs_push(const config_object* conf) { @@ -661,39 +690,41 @@ LIBSESSION_EXPORT bool config_needs_push(const config_object* conf) { } LIBSESSION_EXPORT config_push_data* config_push(config_object* conf) { - auto& config = *unbox(conf); - auto [seqno, data, obs] = config.push(); - - // We need to do one alloc here that holds everything: - // - the returned struct - // - pointers to the obsolete message hash strings - // - the data - // - the message hash strings - size_t buffer_size = sizeof(config_push_data) + obs.size() * sizeof(char*) + data.size(); - for (auto& o : obs) - buffer_size += o.size(); - buffer_size += obs.size(); // obs msg hash string NULL terminators - - auto* ret = static_cast(std::malloc(buffer_size)); - - ret->seqno = seqno; - - static_assert(alignof(config_push_data) >= alignof(char*)); - ret->obsolete = reinterpret_cast(ret + 1); - ret->obsolete_len = obs.size(); - - ret->config = reinterpret_cast(ret->obsolete + ret->obsolete_len); - ret->config_len = data.size(); - - std::memcpy(ret->config, data.data(), data.size()); - char* obsptr = reinterpret_cast(ret->config + ret->config_len); - for (size_t i = 0; i < obs.size(); i++) { - std::memcpy(obsptr, obs[i].c_str(), obs[i].size() + 1); - ret->obsolete[i] = obsptr; - obsptr += obs[i].size() + 1; - } + return wrap_exceptions(conf, [&] { + auto& config = *unbox(conf); + auto [seqno, data, obs] = config.push(); + + // We need to do one alloc here that holds 
everything: + // - the returned struct + // - pointers to the obsolete message hash strings + // - the data + // - the message hash strings + size_t buffer_size = sizeof(config_push_data) + obs.size() * sizeof(char*) + data.size(); + for (auto& o : obs) + buffer_size += o.size(); + buffer_size += obs.size(); // obs msg hash string NULL terminators + + auto* ret = static_cast(std::malloc(buffer_size)); + + ret->seqno = seqno; + + static_assert(alignof(config_push_data) >= alignof(char*)); + ret->obsolete = reinterpret_cast(ret + 1); + ret->obsolete_len = obs.size(); + + ret->config = reinterpret_cast(ret->obsolete + ret->obsolete_len); + ret->config_len = data.size(); + + std::memcpy(ret->config, data.data(), data.size()); + char* obsptr = reinterpret_cast(ret->config + ret->config_len); + for (size_t i = 0; i < obs.size(); i++) { + std::memcpy(obsptr, obs[i].c_str(), obs[i].size() + 1); + ret->obsolete[i] = obsptr; + obsptr += obs[i].size() + 1; + } - return ret; + return ret; + }); } LIBSESSION_EXPORT void config_confirm_pushed( @@ -701,12 +732,18 @@ LIBSESSION_EXPORT void config_confirm_pushed( unbox(conf)->confirm_pushed(seqno, msg_hash); } -LIBSESSION_EXPORT void config_dump(config_object* conf, unsigned char** out, size_t* outlen) { - assert(out && outlen); - auto data = unbox(conf)->dump(); - *outlen = data.size(); - *out = static_cast(std::malloc(data.size())); - std::memcpy(*out, data.data(), data.size()); +LIBSESSION_EXPORT bool config_dump(config_object* conf, unsigned char** out, size_t* outlen) { + return wrap_exceptions( + conf, + [&] { + assert(out && outlen); + auto data = unbox(conf)->dump(); + *outlen = data.size(); + *out = static_cast(std::malloc(data.size())); + std::memcpy(*out, data.data(), data.size()); + return true; + }, + false); } LIBSESSION_EXPORT bool config_needs_dump(const config_object* conf) { @@ -717,10 +754,15 @@ LIBSESSION_EXPORT config_string_list* config_current_hashes(const config_object* return make_string_list(unbox(conf)->current_hashes()); } +LIBSESSION_EXPORT config_string_list* config_old_hashes(config_object* conf) { + return make_string_list(unbox(conf)->old_hashes()); +} + LIBSESSION_EXPORT unsigned char* config_get_keys(const config_object* conf, size_t* len) { const auto keys = unbox(conf)->get_keys(); - assert(std::count_if(keys.begin(), keys.end(), [](const auto& k) { return k.size() == 32; }) == - keys.size()); + assert(static_cast(std::count_if(keys.begin(), keys.end(), [](const auto& k) { + return k.size() == 32; + })) == keys.size()); assert(len); *len = keys.size(); if (keys.empty()) @@ -735,11 +777,24 @@ LIBSESSION_EXPORT unsigned char* config_get_keys(const config_object* conf, size return buf; } -LIBSESSION_EXPORT void config_add_key(config_object* conf, const unsigned char* key) { - unbox(conf)->add_key({key, 32}); +LIBSESSION_EXPORT bool config_add_key(config_object* conf, const unsigned char* key) { + return wrap_exceptions( + conf, + [&] { + unbox(conf)->add_key({key, 32}); + return true; + }, + false); } -LIBSESSION_EXPORT void config_add_key_low_prio(config_object* conf, const unsigned char* key) { - unbox(conf)->add_key({key, 32}, /*high_priority=*/false); + +LIBSESSION_EXPORT bool config_add_key_low_prio(config_object* conf, const unsigned char* key) { + return wrap_exceptions( + conf, + [&] { + unbox(conf)->add_key({key, 32}, /*high_priority=*/false); + return true; + }, + false); } LIBSESSION_EXPORT int config_clear_keys(config_object* conf) { return unbox(conf)->clear_keys(); @@ -751,7 +806,11 @@ LIBSESSION_EXPORT int 
config_key_count(const config_object* conf) { return unbox(conf)->key_count(); } LIBSESSION_EXPORT bool config_has_key(const config_object* conf, const unsigned char* key) { - return unbox(conf)->has_key({key, 32}); + try { + return unbox(conf)->has_key({key, 32}); + } catch (...) { + return false; + } } LIBSESSION_EXPORT const unsigned char* config_key(const config_object* conf, size_t i) { return unbox(conf)->key(i).data(); @@ -761,12 +820,24 @@ LIBSESSION_EXPORT const char* config_encryption_domain(const config_object* conf return unbox(conf)->encryption_domain(); } -LIBSESSION_EXPORT void config_set_sig_keys(config_object* conf, const unsigned char* secret) { - unbox(conf)->set_sig_keys({secret, 64}); +LIBSESSION_EXPORT bool config_set_sig_keys(config_object* conf, const unsigned char* secret) { + return wrap_exceptions( + conf, + [&] { + unbox(conf)->set_sig_keys({secret, 64}); + return true; + }, + false); } -LIBSESSION_EXPORT void config_set_sig_pubkey(config_object* conf, const unsigned char* pubkey) { - unbox(conf)->set_sig_pubkey({pubkey, 32}); +LIBSESSION_EXPORT bool config_set_sig_pubkey(config_object* conf, const unsigned char* pubkey) { + return wrap_exceptions( + conf, + [&] { + unbox(conf)->set_sig_pubkey({pubkey, 32}); + return true; + }, + false); } LIBSESSION_EXPORT const unsigned char* config_get_sig_pubkey(const config_object* conf) { @@ -780,14 +851,4 @@ LIBSESSION_EXPORT void config_clear_sig_keys(config_object* conf) { unbox(conf)->clear_sig_keys(); } -LIBSESSION_EXPORT void config_set_logger( - config_object* conf, void (*callback)(config_log_level, const char*, void*), void* ctx) { - if (!callback) - unbox(conf)->logger = nullptr; - else - unbox(conf)->logger = [callback, ctx](LogLevel lvl, std::string msg) { - callback(static_cast(static_cast(lvl)), msg.c_str(), ctx); - }; -} - } // extern "C" diff --git a/src/config/community.cpp b/src/config/community.cpp index 4c1a32e5..8a4f1351 100644 --- a/src/config/community.cpp +++ b/src/config/community.cpp @@ -88,67 +88,6 @@ std::string community::full_url( return url; } -// returns protocol, host, port. Port can be empty; throws on unparseable values. protocol and -// host get normalized to lower-case. Port will be 0 if not present in the URL, or if set to -// the default for the protocol. The URL must not include a path (though a single optional `/` -// after the domain is accepted and ignored). 
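The community-specific parser removed here is superseded by the shared parse_url declared earlier in this patch; the main behavioural difference is the port convention, sketched below (illustrative only, not compiled):

// Removed parser (below):  returns the port as a plain integer, with 0 meaning
//                          "absent or protocol default".
// Shared parse_url:        returns an optional port with the same meaning for "empty",
//                          which is why canonical_url() further down now checks
//                          `if (port)` rather than `if (port != 0)`.
auto [proto, host, port, path] = parse_url("http://example.org");
// port is empty here (no port given), where the removed version returned port == 0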
-std::tuple parse_url(std::string_view url) { - std::tuple result{}; - auto& [proto, host, port] = result; - if (auto pos = url.find("://"); pos != std::string::npos) { - auto proto_name = url.substr(0, pos); - url.remove_prefix(proto_name.size() + 3); - if (string_iequal(proto_name, "http")) - proto = "http://"; - else if (string_iequal(proto_name, "https")) - proto = "https://"; - } - if (proto.empty()) - throw std::invalid_argument{"Invalid community URL: invalid/missing protocol://"}; - - bool next_allow_dot = false; - bool has_dot = false; - while (!url.empty()) { - auto c = url.front(); - if ((c >= '0' && c <= '9') || (c >= 'a' && c <= 'z') || c == '-') { - host += c; - next_allow_dot = true; - } else if (c >= 'A' && c <= 'Z') { - host += c + ('a' - 'A'); - next_allow_dot = true; - } else if (next_allow_dot && c == '.') { - host += '.'; - has_dot = true; - next_allow_dot = false; - } else { - break; - } - url.remove_prefix(1); - } - if (host.size() < 4 || !has_dot || host.back() == '.') - throw std::invalid_argument{"Invalid community URL: invalid hostname"}; - - if (!url.empty() && url.front() == ':') { - url.remove_prefix(1); - if (auto [p, ec] = std::from_chars(url.data(), url.data() + url.size(), port); - ec == std::errc{}) - url.remove_prefix(p - url.data()); - else - throw std::invalid_argument{"Invalid community URL: invalid port"}; - if ((port == 80 && proto == "http://") || (port == 443 && proto == "https://")) - port = 0; - } - - if (!url.empty() && url.front() == '/') - url.remove_prefix(1); - - // We don't (currently) allow a /path in a community URL - if (!url.empty()) - throw std::invalid_argument{"Invalid community URL: found unexpected trailing value"}; - - return result; -} - void community::canonicalize_url(std::string& url) { if (auto new_url = canonical_url(url); new_url != url) url = std::move(new_url); @@ -168,14 +107,17 @@ void community::canonicalize_room(std::string& room) { } std::string community::canonical_url(std::string_view url) { - const auto& [proto, host, port] = parse_url(url); + const auto& [proto, host, port, path] = parse_url(url); std::string result; result += proto; result += host; - if (port != 0) { + if (port) { result += ':'; - result += std::to_string(port); + result += std::to_string(*port); } + // We don't (currently) allow a /path in a community URL + if (path) + throw std::invalid_argument{"Invalid community URL: found unexpected trailing value"}; if (result.size() > BASE_URL_MAX_LENGTH) throw std::invalid_argument{"Invalid community URL: base URL is too long"}; return result; @@ -274,6 +216,5 @@ LIBSESSION_C_API void community_make_full_url( auto full = session::config::community::full_url(base_url, room, session::ustring_view{pubkey, 32}); assert(full.size() <= COMMUNITY_FULL_URL_MAX_LENGTH); - size_t pos = 0; std::memcpy(full_url, full.data(), full.size() + 1); } diff --git a/src/config/contacts.cpp b/src/config/contacts.cpp index 1e0437a0..e23f213a 100644 --- a/src/config/contacts.cpp +++ b/src/config/contacts.cpp @@ -41,7 +41,7 @@ contact_info::contact_info(std::string sid) : session_id{std::move(sid)} { void contact_info::set_name(std::string n) { if (n.size() > MAX_NAME_LENGTH) - name = std::move(utf8_truncate(std::move(n), MAX_NAME_LENGTH)); + name = utf8_truncate(std::move(n), MAX_NAME_LENGTH); else name = std::move(n); } @@ -181,17 +181,16 @@ std::optional Contacts::get(std::string_view pubkey_hex) const { LIBSESSION_C_API bool contacts_get( config_object* conf, contacts_contact* contact, const char* session_id) { - try { - 
conf->last_error = nullptr; - if (auto c = unbox(conf)->get(session_id)) { - c->into(*contact); - return true; - } - } catch (const std::exception& e) { - copy_c_str(conf->_error_buf, e.what()); - conf->last_error = conf->_error_buf; - } - return false; + return wrap_exceptions( + conf, + [&] { + if (auto c = unbox(conf)->get(session_id)) { + c->into(*contact); + return true; + } + return false; + }, + false); } contact_info Contacts::get_or_construct(std::string_view pubkey_hex) const { @@ -203,15 +202,13 @@ contact_info Contacts::get_or_construct(std::string_view pubkey_hex) const { LIBSESSION_C_API bool contacts_get_or_construct( config_object* conf, contacts_contact* contact, const char* session_id) { - try { - conf->last_error = nullptr; - unbox(conf)->get_or_construct(session_id).into(*contact); - return true; - } catch (const std::exception& e) { - copy_c_str(conf->_error_buf, e.what()); - conf->last_error = conf->_error_buf; - return false; - } + return wrap_exceptions( + conf, + [&] { + unbox(conf)->get_or_construct(session_id).into(*contact); + return true; + }, + false); } void Contacts::set(const contact_info& contact) { @@ -252,8 +249,14 @@ void Contacts::set(const contact_info& contact) { set_positive_int(info["j"], contact.created); } -LIBSESSION_C_API void contacts_set(config_object* conf, const contacts_contact* contact) { - unbox(conf)->set(contact_info{*contact}); +LIBSESSION_C_API bool contacts_set(config_object* conf, const contacts_contact* contact) { + return wrap_exceptions( + conf, + [&] { + unbox(conf)->set(contact_info{*contact}); + return true; + }, + false); } void Contacts::set_name(std::string_view session_id, std::string name) { diff --git a/src/config/convo_info_volatile.cpp b/src/config/convo_info_volatile.cpp index 421a8666..3355f1dc 100644 --- a/src/config/convo_info_volatile.cpp +++ b/src/config/convo_info_volatile.cpp @@ -235,10 +235,9 @@ void ConvoInfoVolatile::set_base(const convo::base& c, DictFieldProxy& info) { } void ConvoInfoVolatile::prune_stale(std::chrono::milliseconds prune) { - const int64_t cutoff = - std::chrono::duration_cast( - (std::chrono::system_clock::now() - PRUNE_HIGH).time_since_epoch()) - .count(); + const int64_t cutoff = std::chrono::duration_cast( + (std::chrono::system_clock::now() - prune).time_since_epoch()) + .count(); std::vector stale; for (auto it = begin_1to1(); it != end(); ++it) @@ -497,30 +496,27 @@ int convo_info_volatile_init( LIBSESSION_C_API bool convo_info_volatile_get_1to1( config_object* conf, convo_info_volatile_1to1* convo, const char* session_id) { - try { - conf->last_error = nullptr; - if (auto c = unbox(conf)->get_1to1(session_id)) { - c->into(*convo); - return true; - } - } catch (const std::exception& e) { - copy_c_str(conf->_error_buf, e.what()); - conf->last_error = conf->_error_buf; - } - return false; + return wrap_exceptions( + conf, + [&] { + if (auto c = unbox(conf)->get_1to1(session_id)) { + c->into(*convo); + return true; + } + return false; + }, + false); } LIBSESSION_C_API bool convo_info_volatile_get_or_construct_1to1( config_object* conf, convo_info_volatile_1to1* convo, const char* session_id) { - try { - conf->last_error = nullptr; - unbox(conf)->get_or_construct_1to1(session_id).into(*convo); - return true; - } catch (const std::exception& e) { - copy_c_str(conf->_error_buf, e.what()); - conf->last_error = conf->_error_buf; - return false; - } + return wrap_exceptions( + conf, + [&] { + unbox(conf)->get_or_construct_1to1(session_id).into(*convo); + return true; + }, + false); } 
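All of the rewritten C wrappers in this region funnel through a `wrap_exceptions(conf, lambda, fallback)` helper that is not itself shown in this excerpt; going by the config_group_keys variant added later in src/config/groups/keys.cpp, the config_object overload presumably looks something like the following sketch (an assumption, not quoted from the patch):

// Assumed shape of the config_object overload, mirroring the config_group_keys
// version added later in this patch (needs <functional>/<utility> for std::invoke
// and std::forward):
template <typename Ret, typename Call>
Ret wrap_exceptions(config_object* conf, Call&& f, Ret error_return) {
    try {
        conf->last_error = nullptr;              // clear any stale error first
        return std::invoke(std::forward<Call>(f));
    } catch (const std::exception& e) {
        copy_c_str(conf->_error_buf, e.what());  // truncate into the fixed-size buffer
        conf->last_error = conf->_error_buf;     // make it visible to the C caller
    }
    return error_return;                         // e.g. false for the bool-returning wrappers
}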
LIBSESSION_C_API bool convo_info_volatile_get_community( @@ -528,17 +524,16 @@ LIBSESSION_C_API bool convo_info_volatile_get_community( convo_info_volatile_community* og, const char* base_url, const char* room) { - try { - conf->last_error = nullptr; - if (auto c = unbox(conf)->get_community(base_url, room)) { - c->into(*og); - return true; - } - } catch (const std::exception& e) { - copy_c_str(conf->_error_buf, e.what()); - conf->last_error = conf->_error_buf; - } - return false; + return wrap_exceptions( + conf, + [&] { + if (auto c = unbox(conf)->get_community(base_url, room)) { + c->into(*og); + return true; + } + return false; + }, + false); } LIBSESSION_C_API bool convo_info_volatile_get_or_construct_community( config_object* conf, @@ -546,121 +541,129 @@ LIBSESSION_C_API bool convo_info_volatile_get_or_construct_community( const char* base_url, const char* room, unsigned const char* pubkey) { - try { - conf->last_error = nullptr; - unbox(conf) - ->get_or_construct_community(base_url, room, ustring_view{pubkey, 32}) - .into(*convo); - return true; - } catch (const std::exception& e) { - copy_c_str(conf->_error_buf, e.what()); - conf->last_error = conf->_error_buf; - return false; - } + return wrap_exceptions( + conf, + [&] { + unbox(conf) + ->get_or_construct_community(base_url, room, ustring_view{pubkey, 32}) + .into(*convo); + return true; + }, + false); } LIBSESSION_C_API bool convo_info_volatile_get_group( config_object* conf, convo_info_volatile_group* convo, const char* id) { - try { - conf->last_error = nullptr; - if (auto c = unbox(conf)->get_group(id)) { - c->into(*convo); - return true; - } - } catch (const std::exception& e) { - copy_c_str(conf->_error_buf, e.what()); - conf->last_error = conf->_error_buf; - } - return false; + return wrap_exceptions( + conf, + [&] { + if (auto c = unbox(conf)->get_group(id)) { + c->into(*convo); + return true; + } + return false; + }, + false); } LIBSESSION_C_API bool convo_info_volatile_get_or_construct_group( config_object* conf, convo_info_volatile_group* convo, const char* id) { - try { - conf->last_error = nullptr; - unbox(conf)->get_or_construct_group(id).into(*convo); - return true; - } catch (const std::exception& e) { - copy_c_str(conf->_error_buf, e.what()); - conf->last_error = conf->_error_buf; - return false; - } + return wrap_exceptions( + conf, + [&] { + unbox(conf)->get_or_construct_group(id).into(*convo); + return true; + }, + false); } LIBSESSION_C_API bool convo_info_volatile_get_legacy_group( config_object* conf, convo_info_volatile_legacy_group* convo, const char* id) { - try { - conf->last_error = nullptr; - if (auto c = unbox(conf)->get_legacy_group(id)) { - c->into(*convo); - return true; - } - } catch (const std::exception& e) { - copy_c_str(conf->_error_buf, e.what()); - conf->last_error = conf->_error_buf; - } - return false; + return wrap_exceptions( + conf, + [&] { + if (auto c = unbox(conf)->get_legacy_group(id)) { + c->into(*convo); + return true; + } + return false; + }, + false); } LIBSESSION_C_API bool convo_info_volatile_get_or_construct_legacy_group( config_object* conf, convo_info_volatile_legacy_group* convo, const char* id) { - try { - conf->last_error = nullptr; - unbox(conf)->get_or_construct_legacy_group(id).into(*convo); - return true; - } catch (const std::exception& e) { - copy_c_str(conf->_error_buf, e.what()); - conf->last_error = conf->_error_buf; - return false; - } + return wrap_exceptions( + conf, + [&] { + unbox(conf)->get_or_construct_legacy_group(id).into(*convo); + return true; + }, + 
false); } -LIBSESSION_C_API void convo_info_volatile_set_1to1( +LIBSESSION_C_API bool convo_info_volatile_set_1to1( config_object* conf, const convo_info_volatile_1to1* convo) { - unbox(conf)->set(convo::one_to_one{*convo}); -} -LIBSESSION_C_API void convo_info_volatile_set_community( + return wrap_exceptions( + conf, + [&] { + unbox(conf)->set(convo::one_to_one{*convo}); + return true; + }, + false); +} +LIBSESSION_C_API bool convo_info_volatile_set_community( config_object* conf, const convo_info_volatile_community* convo) { - unbox(conf)->set(convo::community{*convo}); -} -LIBSESSION_C_API void convo_info_volatile_set_group( + return wrap_exceptions( + conf, + [&] { + unbox(conf)->set(convo::community{*convo}); + return true; + }, + false); +} +LIBSESSION_C_API bool convo_info_volatile_set_group( config_object* conf, const convo_info_volatile_group* convo) { - unbox(conf)->set(convo::group{*convo}); -} -LIBSESSION_C_API void convo_info_volatile_set_legacy_group( + return wrap_exceptions( + conf, + [&] { + unbox(conf)->set(convo::group{*convo}); + return true; + }, + false); +} +LIBSESSION_C_API bool convo_info_volatile_set_legacy_group( config_object* conf, const convo_info_volatile_legacy_group* convo) { - unbox(conf)->set(convo::legacy_group{*convo}); + return wrap_exceptions( + conf, + [&] { + unbox(conf)->set(convo::legacy_group{*convo}); + return true; + }, + false); } LIBSESSION_C_API bool convo_info_volatile_erase_1to1(config_object* conf, const char* session_id) { - try { - return unbox(conf)->erase_1to1(session_id); - } catch (...) { - return false; - } + return wrap_exceptions( + conf, [&] { return unbox(conf)->erase_1to1(session_id); }, false); } LIBSESSION_C_API bool convo_info_volatile_erase_community( config_object* conf, const char* base_url, const char* room) { - try { - return unbox(conf)->erase_community(base_url, room); - } catch (...) { - return false; - } + return wrap_exceptions( + conf, + [&] { return unbox(conf)->erase_community(base_url, room); }, + false); } LIBSESSION_C_API bool convo_info_volatile_erase_group(config_object* conf, const char* group_id) { - try { - return unbox(conf)->erase_group(group_id); - } catch (...) { - return false; - } + return wrap_exceptions( + conf, [&] { return unbox(conf)->erase_group(group_id); }, false); } LIBSESSION_C_API bool convo_info_volatile_erase_legacy_group( config_object* conf, const char* group_id) { - try { - return unbox(conf)->erase_legacy_group(group_id); - } catch (...) 
{ - return false; - } + return wrap_exceptions( + conf, + [&] { return unbox(conf)->erase_legacy_group(group_id); }, + false); } LIBSESSION_C_API size_t convo_info_volatile_size(const config_object* conf) { diff --git a/src/config/groups/info.cpp b/src/config/groups/info.cpp index e9b57a92..4fbb2d86 100644 --- a/src/config/groups/info.cpp +++ b/src/config/groups/info.cpp @@ -171,12 +171,13 @@ LIBSESSION_C_API const char* groups_info_get_name(const config_object* conf) { /// Outputs: /// - `int` -- Returns 0 on success, non-zero on error LIBSESSION_C_API int groups_info_set_name(config_object* conf, const char* name) { - try { - unbox(conf)->set_name(name); - } catch (const std::exception& e) { - return set_error(conf, SESSION_ERR_BAD_VALUE, e); - } - return 0; + return wrap_exceptions( + conf, + [&] { + unbox(conf)->set_name(name); + return 0; + }, + static_cast(SESSION_ERR_BAD_VALUE)); } /// API: groups_info/groups_info_get_description @@ -209,12 +210,13 @@ LIBSESSION_C_API const char* groups_info_get_description(const config_object* co /// Outputs: /// - `int` -- Returns 0 on success, non-zero on error LIBSESSION_C_API int groups_info_set_description(config_object* conf, const char* description) { - try { - unbox(conf)->set_description(description); - } catch (const std::exception& e) { - return set_error(conf, SESSION_ERR_BAD_VALUE, e); - } - return 0; + return wrap_exceptions( + conf, + [&] { + unbox(conf)->set_description(description); + return 0; + }, + static_cast(SESSION_ERR_BAD_VALUE)); } /// API: groups_info/groups_info_get_pic @@ -256,13 +258,13 @@ LIBSESSION_C_API int groups_info_set_pic(config_object* conf, user_profile_pic p if (!url.empty()) key = {pic.key, 32}; - try { - unbox(conf)->set_profile_pic(url, key); - } catch (const std::exception& e) { - return set_error(conf, SESSION_ERR_BAD_VALUE, e); - } - - return 0; + return wrap_exceptions( + conf, + [&] { + unbox(conf)->set_profile_pic(url, key); + return 0; + }, + static_cast(SESSION_ERR_BAD_VALUE)); } /// API: groups_info/groups_info_get_expiry_timer diff --git a/src/config/groups/keys.cpp b/src/config/groups/keys.cpp index 6b9f584a..76c75969 100644 --- a/src/config/groups/keys.cpp +++ b/src/config/groups/keys.cpp @@ -596,7 +596,6 @@ ustring Keys::swarm_make_subaccount(std::string_view session_id, bool write, boo auto X = session_id_pk(session_id); auto& c = _sign_sk; - auto& C = *_sign_pk; auto k = subaccount_blind_factor(X); @@ -634,9 +633,6 @@ ustring Keys::swarm_subaccount_token(std::string_view session_id, bool write, bo // Similar to the above, but we only care about getting flags || kT auto X = session_id_pk(session_id); - auto& c = _sign_sk; - auto& C = *_sign_pk; - auto k = subaccount_blind_factor(X); // T = |S| @@ -1090,7 +1086,7 @@ bool Keys::load_key_message( } } - verify_config_sig(d, data, verifier_); + verify_config_sig(d, verifier_); // If this is our pending config or this has a later generation than our pending config then // drop our pending status. 
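For C callers, the conversions above leave two error conventions side by side: the bool-returning wrappers report failure via `false`, while the groups_info setters keep their int return with a SESSION_ERR_* code; in both cases the message lands in `conf->last_error`. A hedged usage sketch (variable names are placeholders, not from the patch):

/* Usage sketch only; assumes initialized config_object pointers. */
if (!contacts_set(conf, &contact))                    // bool convention: false on error
    fprintf(stderr, "contacts_set failed: %s\n",
            conf->last_error ? conf->last_error : "(unknown)");

int rc = groups_info_set_name(info_conf, "Lunch group");
if (rc != 0)                                          // int convention: SESSION_ERR_BAD_VALUE etc.
    fprintf(stderr, "groups_info_set_name failed (%d): %s\n", rc,
            info_conf->last_error ? info_conf->last_error : "(unknown)");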
@@ -1402,12 +1398,39 @@ const groups::Keys& unbox(const config_group_keys* conf) { return *static_cast(conf->internals); } -void set_error(config_group_keys* conf, std::string_view e) { - if (e.size() > 255) - e.remove_suffix(e.size() - 255); - std::memcpy(conf->_error_buf, e.data(), e.size()); - conf->_error_buf[e.size()] = 0; - conf->last_error = conf->_error_buf; +// Wraps a labmda and, if an exception is thrown, sets an error message in the internals.error +// string and updates the last_error pointer in the outer (C) config_object struct to point at it. +// +// No return value: accepts void and pointer returns; pointer returns will become nullptr on error +template +decltype(auto) wrap_exceptions(config_group_keys* conf, Call&& f) { + using Ret = std::invoke_result_t; + + try { + conf->last_error = nullptr; + return std::invoke(std::forward(f)); + } catch (const std::exception& e) { + copy_c_str(conf->_error_buf, e.what()); + conf->last_error = conf->_error_buf; + } + if constexpr (std::is_pointer_v) + return nullptr; + else + static_assert(std::is_void_v, "Don't know how to return an error value!"); +} + +// Same as above but accepts callbacks with value returns on errors: returns `f()` on success, +// `error_return` on exception +template +Ret wrap_exceptions(config_group_keys* conf, Call&& f, Ret error_return) { + try { + conf->last_error = nullptr; + return std::invoke(std::forward(f)); + } catch (const std::exception& e) { + copy_c_str(conf->_error_buf, e.what()); + conf->last_error = conf->_error_buf; + } + return error_return; } } // namespace @@ -1483,16 +1506,16 @@ LIBSESSION_C_API bool groups_keys_load_admin_key( const unsigned char* secret, config_object* info, config_object* members) { - try { - unbox(conf).load_admin_key( - ustring_view{secret, 32}, - *unbox(info), - *unbox(members)); - } catch (const std::exception& e) { - set_error(conf, e.what()); - return false; - } - return true; + return wrap_exceptions( + conf, + [&] { + unbox(conf).load_admin_key( + ustring_view{secret, 32}, + *unbox(info), + *unbox(members)); + return true; + }, + false); } LIBSESSION_C_API bool groups_keys_rekey( @@ -1504,17 +1527,18 @@ LIBSESSION_C_API bool groups_keys_rekey( assert(info && members); auto& keys = unbox(conf); ustring_view to_push; - try { - to_push = keys.rekey(*unbox(info), *unbox(members)); - } catch (const std::exception& e) { - set_error(conf, e.what()); - return false; - } - if (out && outlen) { - *out = to_push.data(); - *outlen = to_push.size(); - } - return true; + + return wrap_exceptions( + conf, + [&] { + to_push = keys.rekey(*unbox(info), *unbox(members)); + if (out && outlen) { + *out = to_push.data(); + *outlen = to_push.size(); + } + return true; + }, + false); } LIBSESSION_C_API bool groups_keys_pending_config( @@ -1537,18 +1561,18 @@ LIBSESSION_C_API bool groups_keys_load_message( config_object* info, config_object* members) { assert(data && info && members); - try { - unbox(conf).load_key_message( - msg_hash, - ustring_view{data, datalen}, - timestamp_ms, - *unbox(info), - *unbox(members)); - } catch (const std::exception& e) { - set_error(conf, e.what()); - return false; - } - return true; + return wrap_exceptions( + conf, + [&] { + unbox(conf).load_key_message( + msg_hash, + ustring_view{data, datalen}, + timestamp_ms, + *unbox(info), + *unbox(members)); + return true; + }, + false); } LIBSESSION_C_API config_string_list* groups_keys_current_hashes(const config_group_keys* conf) { @@ -1601,18 +1625,18 @@ LIBSESSION_C_API bool groups_keys_decrypt_message( 
size_t* plaintext_len) { assert(ciphertext_in && plaintext_out && plaintext_len); - try { - auto [sid, plaintext] = - unbox(conf).decrypt_message(ustring_view{ciphertext_in, ciphertext_len}); - std::memcpy(session_id, sid.c_str(), sid.size() + 1); - *plaintext_out = static_cast(std::malloc(plaintext.size())); - std::memcpy(*plaintext_out, plaintext.data(), plaintext.size()); - *plaintext_len = plaintext.size(); - return true; - } catch (const std::exception& e) { - set_error(conf, e.what()); - } - return false; + return wrap_exceptions( + conf, + [&] { + auto [sid, plaintext] = + unbox(conf).decrypt_message(ustring_view{ciphertext_in, ciphertext_len}); + std::memcpy(session_id, sid.c_str(), sid.size() + 1); + *plaintext_out = static_cast(std::malloc(plaintext.size())); + std::memcpy(*plaintext_out, plaintext.data(), plaintext.size()); + *plaintext_len = plaintext.size(); + return true; + }, + false); } LIBSESSION_C_API bool groups_keys_key_supplement( @@ -1626,16 +1650,17 @@ LIBSESSION_C_API bool groups_keys_key_supplement( std::vector session_ids; for (size_t i = 0; i < sids_len; i++) session_ids.emplace_back(sids[i]); - try { - auto msg = unbox(conf).key_supplement(session_ids); - *message = static_cast(malloc(msg.size())); - *message_len = msg.size(); - std::memcpy(*message, msg.data(), msg.size()); - return true; - } catch (const std::exception& e) { - set_error(conf, e.what()); - return false; - } + + return wrap_exceptions( + conf, + [&] { + auto msg = unbox(conf).key_supplement(session_ids); + *message = static_cast(malloc(msg.size())); + *message_len = msg.size(); + std::memcpy(*message, msg.data(), msg.size()); + return true; + }, + false); } LIBSESSION_EXPORT int groups_keys_current_generation(config_group_keys* conf) { @@ -1649,15 +1674,15 @@ LIBSESSION_C_API bool groups_keys_swarm_make_subaccount_flags( bool del, unsigned char* sign_value) { assert(sign_value); - try { - auto val = unbox(conf).swarm_make_subaccount(session_id, write, del); - assert(val.size() == 100); - std::memcpy(sign_value, val.data(), val.size()); - return true; - } catch (const std::exception& e) { - set_error(conf, e.what()); - return false; - } + return wrap_exceptions( + conf, + [&] { + auto val = unbox(conf).swarm_make_subaccount(session_id, write, del); + assert(val.size() == 100); + std::memcpy(sign_value, val.data(), val.size()); + return true; + }, + false); } LIBSESSION_C_API bool groups_keys_swarm_make_subaccount( @@ -1687,10 +1712,14 @@ LIBSESSION_C_API bool groups_keys_swarm_verify_subaccount( const char* group_id, const unsigned char* session_ed25519_secretkey, const unsigned char* signing_value) { - return groups::Keys::swarm_verify_subaccount( - group_id, - ustring_view{session_ed25519_secretkey, 64}, - ustring_view{signing_value, 100}); + try { + return groups::Keys::swarm_verify_subaccount( + group_id, + ustring_view{session_ed25519_secretkey, 64}, + ustring_view{signing_value, 100}); + } catch (...) 
{ + return false; + } } LIBSESSION_C_API bool groups_keys_swarm_subaccount_sign( @@ -1703,20 +1732,23 @@ LIBSESSION_C_API bool groups_keys_swarm_subaccount_sign( char* subaccount_sig, char* signature) { assert(msg && signing_value && subaccount && subaccount_sig && signature); - try { - auto auth = unbox(conf).swarm_subaccount_sign( - ustring_view{msg, msg_len}, ustring_view{signing_value, 100}); - assert(auth.subaccount.size() == 48); - assert(auth.subaccount_sig.size() == 88); - assert(auth.signature.size() == 88); - std::memcpy(subaccount, auth.subaccount.c_str(), auth.subaccount.size() + 1); - std::memcpy(subaccount_sig, auth.subaccount_sig.c_str(), auth.subaccount_sig.size() + 1); - std::memcpy(signature, auth.signature.c_str(), auth.signature.size() + 1); - return true; - } catch (const std::exception& e) { - set_error(conf, e.what()); - return false; - } + return wrap_exceptions( + conf, + [&] { + auto auth = unbox(conf).swarm_subaccount_sign( + ustring_view{msg, msg_len}, ustring_view{signing_value, 100}); + assert(auth.subaccount.size() == 48); + assert(auth.subaccount_sig.size() == 88); + assert(auth.signature.size() == 88); + std::memcpy(subaccount, auth.subaccount.c_str(), auth.subaccount.size() + 1); + std::memcpy( + subaccount_sig, + auth.subaccount_sig.c_str(), + auth.subaccount_sig.size() + 1); + std::memcpy(signature, auth.signature.c_str(), auth.signature.size() + 1); + return true; + }, + false); } LIBSESSION_C_API bool groups_keys_swarm_subaccount_sign_binary( @@ -1729,20 +1761,20 @@ LIBSESSION_C_API bool groups_keys_swarm_subaccount_sign_binary( unsigned char* subaccount_sig, unsigned char* signature) { assert(msg && signing_value && subaccount && subaccount_sig && signature); - try { - auto auth = unbox(conf).swarm_subaccount_sign( - ustring_view{msg, msg_len}, ustring_view{signing_value, 100}, true); - assert(auth.subaccount.size() == 36); - assert(auth.subaccount_sig.size() == 64); - assert(auth.signature.size() == 64); - std::memcpy(subaccount, auth.subaccount.data(), 36); - std::memcpy(subaccount_sig, auth.subaccount_sig.data(), 64); - std::memcpy(signature, auth.signature.data(), 64); - return true; - } catch (const std::exception& e) { - set_error(conf, e.what()); - return false; - } + return wrap_exceptions( + conf, + [&] { + auto auth = unbox(conf).swarm_subaccount_sign( + ustring_view{msg, msg_len}, ustring_view{signing_value, 100}, true); + assert(auth.subaccount.size() == 36); + assert(auth.subaccount_sig.size() == 64); + assert(auth.signature.size() == 64); + std::memcpy(subaccount, auth.subaccount.data(), 36); + std::memcpy(subaccount_sig, auth.subaccount_sig.data(), 64); + std::memcpy(signature, auth.signature.data(), 64); + return true; + }, + false); } LIBSESSION_C_API bool groups_keys_swarm_subaccount_token_flags( @@ -1751,15 +1783,15 @@ LIBSESSION_C_API bool groups_keys_swarm_subaccount_token_flags( bool write, bool del, unsigned char* token) { - try { - auto tok = unbox(conf).swarm_subaccount_token(session_id, write, del); - assert(tok.size() == 36); - std::memcpy(token, tok.data(), 36); - return true; - } catch (const std::exception& e) { - set_error(conf, e.what()); - return false; - } + return wrap_exceptions( + conf, + [&] { + auto tok = unbox(conf).swarm_subaccount_token(session_id, write, del); + assert(tok.size() == 36); + std::memcpy(token, tok.data(), 36); + return true; + }, + false); } LIBSESSION_C_API bool groups_keys_swarm_subaccount_token( diff --git a/src/config/groups/members.cpp b/src/config/groups/members.cpp index 
b3df32d1..0aadcbfc 100644 --- a/src/config/groups/members.cpp +++ b/src/config/groups/members.cpp @@ -18,7 +18,7 @@ void Members::extra_data(oxenc::bt_dict_producer&& extra) const { if (pending_send_ids.empty()) return; - extra.append_list("pending_send_ids").append(pending_send_ids.begin(), pending_send_ids.end()); + extra.append_list("pending_send_ids").extend(pending_send_ids.begin(), pending_send_ids.end()); } void Members::load_extra_data(oxenc::bt_dict_consumer&& extra) { @@ -193,7 +193,7 @@ member::member(const config_group_member& m) : session_id{m.session_id, 66} { ? m.invited : 0; promotion_status = (m.promoted == STATUS_SENT || m.promoted == STATUS_FAILED || - m.invited == STATUS_NOT_SENT) + m.promoted == STATUS_NOT_SENT) ? m.promoted : 0; removed_status = (m.removed == REMOVED_MEMBER || m.removed == REMOVED_MEMBER_AND_MESSAGES) @@ -293,34 +293,37 @@ LIBSESSION_C_API int groups_members_init( LIBSESSION_C_API bool groups_members_get( config_object* conf, config_group_member* member, const char* session_id) { - try { - conf->last_error = nullptr; - if (auto c = unbox(conf)->get(session_id)) { - c->into(*member); - return true; - } - } catch (const std::exception& e) { - copy_c_str(conf->_error_buf, e.what()); - conf->last_error = conf->_error_buf; - } - return false; + return wrap_exceptions( + conf, + [&] { + if (auto c = unbox(conf)->get(session_id)) { + c->into(*member); + return true; + } + return false; + }, + false); } LIBSESSION_C_API bool groups_members_get_or_construct( config_object* conf, config_group_member* member, const char* session_id) { - try { - conf->last_error = nullptr; - unbox(conf)->get_or_construct(session_id).into(*member); - return true; - } catch (const std::exception& e) { - copy_c_str(conf->_error_buf, e.what()); - conf->last_error = conf->_error_buf; - return false; - } -} - -LIBSESSION_C_API void groups_members_set(config_object* conf, const config_group_member* member) { - unbox(conf)->set(groups::member{*member}); + return wrap_exceptions( + conf, + [&] { + unbox(conf)->get_or_construct(session_id).into(*member); + return true; + }, + false); +} + +LIBSESSION_C_API bool groups_members_set(config_object* conf, const config_group_member* member) { + return wrap_exceptions( + conf, + [&] { + unbox(conf)->set(groups::member{*member}); + return true; + }, + false); } LIBSESSION_C_API GROUP_MEMBER_STATUS diff --git a/src/config/internal.cpp b/src/config/internal.cpp index c5b4421c..9ad61eba 100644 --- a/src/config/internal.cpp +++ b/src/config/internal.cpp @@ -211,7 +211,7 @@ std::optional zstd_decompress(ustring_view data, size_t max_size) { ZSTD_initDStream(zds); ZSTD_inBuffer input{/*.src=*/data.data(), /*.size=*/data.size(), /*.pos=*/0}; std::array out_buf; - ZSTD_outBuffer output{/*.dst=*/out_buf.data(), /*.size=*/out_buf.size()}; + ZSTD_outBuffer output{/*.dst=*/out_buf.data(), /*.size=*/out_buf.size(), /*.pos=*/0}; ustring decompressed; diff --git a/src/config/internal.hpp b/src/config/internal.hpp index 8872ea4c..c338cbc3 100644 --- a/src/config/internal.hpp +++ b/src/config/internal.hpp @@ -74,14 +74,6 @@ template return c_wrapper_init_generic(conf, error, ed25519_pubkey, ed25519_secretkey, dump); } -template -void copy_c_str(char (&dest)[N], std::string_view src) { - if (src.size() >= N) - src.remove_suffix(src.size() - N - 1); - std::memcpy(dest, src.data(), src.size()); - dest[src.size()] = 0; -} - // Copies a container of std::strings into a self-contained malloc'ed config_string_list for // returning to C code with the strings and pointers 
of the string list in the same malloced space, // hanging off the end (so that everything, including string values, is freed by a single `free()`). diff --git a/src/config/protos.cpp b/src/config/protos.cpp index affcc83c..ef1bcbdc 100644 --- a/src/config/protos.cpp +++ b/src/config/protos.cpp @@ -138,7 +138,6 @@ ustring unwrap_config(ustring_view ed25519_sk, ustring_view data, config::Namesp if (!req.ParseFromArray(data.data(), data.size())) throw std::runtime_error{"Failed to parse WebSocketMessage"}; - const auto& msg_type = req.type(); if (req.type() != WebSocketProtos::WebSocketMessage_Type_REQUEST) throw std::runtime_error{"Error: received invalid WebSocketRequest"}; diff --git a/src/config/user_groups.cpp b/src/config/user_groups.cpp index a0487e65..031f4fde 100644 --- a/src/config/user_groups.cpp +++ b/src/config/user_groups.cpp @@ -310,7 +310,7 @@ std::optional UserGroups::get_community( og.load(*info_dict); if (!pubkey.empty()) og.set_pubkey(pubkey); - return std::move(og); + return og; } return std::nullopt; } @@ -647,17 +647,16 @@ int user_groups_init( LIBSESSION_C_API bool user_groups_get_community( config_object* conf, ugroups_community_info* comm, const char* base_url, const char* room) { - try { - conf->last_error = nullptr; - if (auto c = unbox(conf)->get_community(base_url, room)) { - c->into(*comm); - return true; - } - } catch (const std::exception& e) { - copy_c_str(conf->_error_buf, e.what()); - conf->last_error = conf->_error_buf; - } - return false; + return wrap_exceptions( + conf, + [&] { + if (auto c = unbox(conf)->get_community(base_url, room)) { + c->into(*comm); + return true; + } + return false; + }, + false); } LIBSESSION_C_API bool user_groups_get_or_construct_community( config_object* conf, @@ -665,41 +664,38 @@ LIBSESSION_C_API bool user_groups_get_or_construct_community( const char* base_url, const char* room, unsigned const char* pubkey) { - try { - conf->last_error = nullptr; - unbox(conf) - ->get_or_construct_community(base_url, room, ustring_view{pubkey, 32}) - .into(*comm); - return true; - } catch (const std::exception& e) { - copy_c_str(conf->_error_buf, e.what()); - conf->last_error = conf->_error_buf; - return false; - } + return wrap_exceptions( + conf, + [&] { + unbox(conf) + ->get_or_construct_community(base_url, room, ustring_view{pubkey, 32}) + .into(*comm); + return true; + }, + false); } LIBSESSION_C_API bool user_groups_get_group( config_object* conf, ugroups_group_info* group, const char* group_id) { - try { - conf->last_error = nullptr; - if (auto g = unbox(conf)->get_group(group_id)) { - g->into(*group); - return true; - } - } catch (const std::exception& e) { - set_error(conf, e.what()); - } - return false; + return wrap_exceptions( + conf, + [&] { + if (auto g = unbox(conf)->get_group(group_id)) { + g->into(*group); + return true; + } + return false; + }, + false); } LIBSESSION_C_API bool user_groups_get_or_construct_group( config_object* conf, ugroups_group_info* group, const char* group_id) { - try { - conf->last_error = nullptr; - unbox(conf)->get_or_construct_group(group_id).into(*group); - return true; - } catch (const std::exception& e) { - set_error(conf, e.what()); - return false; - } + return wrap_exceptions( + conf, + [&] { + unbox(conf)->get_or_construct_group(group_id).into(*group); + return true; + }, + false); } LIBSESSION_C_API void ugroups_legacy_group_free(ugroups_legacy_group_info* group) { @@ -711,50 +707,59 @@ LIBSESSION_C_API void ugroups_legacy_group_free(ugroups_legacy_group_info* group LIBSESSION_C_API 
ugroups_legacy_group_info* user_groups_get_legacy_group( config_object* conf, const char* id) { - try { - conf->last_error = nullptr; + return wrap_exceptions(conf, [&] { auto group = std::make_unique(); group->_internal = nullptr; if (auto c = unbox(conf)->get_legacy_group(id)) { std::move(c)->into(*group); return group.release(); } - } catch (const std::exception& e) { - copy_c_str(conf->_error_buf, e.what()); - conf->last_error = conf->_error_buf; - } - return nullptr; + return static_cast(nullptr); + }); } LIBSESSION_C_API ugroups_legacy_group_info* user_groups_get_or_construct_legacy_group( config_object* conf, const char* id) { - try { - conf->last_error = nullptr; + return wrap_exceptions(conf, [&] { auto group = std::make_unique(); group->_internal = nullptr; unbox(conf)->get_or_construct_legacy_group(id).into(*group); return group.release(); - } catch (const std::exception& e) { - copy_c_str(conf->_error_buf, e.what()); - conf->last_error = conf->_error_buf; - return nullptr; - } + }); } LIBSESSION_C_API void user_groups_set_community( config_object* conf, const ugroups_community_info* comm) { unbox(conf)->set(community_info{*comm}); } -LIBSESSION_C_API void user_groups_set_group(config_object* conf, const ugroups_group_info* group) { - unbox(conf)->set(group_info{*group}); +LIBSESSION_C_API bool user_groups_set_group(config_object* conf, const ugroups_group_info* group) { + return wrap_exceptions( + conf, + [&] { + unbox(conf)->set(group_info{*group}); + return true; + }, + false); } -LIBSESSION_C_API void user_groups_set_legacy_group( +LIBSESSION_C_API bool user_groups_set_legacy_group( config_object* conf, const ugroups_legacy_group_info* group) { - unbox(conf)->set(legacy_group_info{*group}); + return wrap_exceptions( + conf, + [&] { + unbox(conf)->set(legacy_group_info{*group}); + return true; + }, + false); } -LIBSESSION_C_API void user_groups_set_free_legacy_group( +LIBSESSION_C_API bool user_groups_set_free_legacy_group( config_object* conf, ugroups_legacy_group_info* group) { - unbox(conf)->set(legacy_group_info{std::move(*group)}); + return wrap_exceptions( + conf, + [&] { + unbox(conf)->set(legacy_group_info{std::move(*group)}); + return true; + }, + false); } LIBSESSION_C_API bool user_groups_erase_community( diff --git a/src/config/user_profile.cpp b/src/config/user_profile.cpp index 96906b74..f01cded6 100644 --- a/src/config/user_profile.cpp +++ b/src/config/user_profile.cpp @@ -48,12 +48,13 @@ void UserProfile::set_name_truncated(std::string new_name) { set_name(utf8_truncate(std::move(new_name), contact_info::MAX_NAME_LENGTH)); } LIBSESSION_C_API int user_profile_set_name(config_object* conf, const char* name) { - try { - unbox(conf)->set_name(name); - } catch (const std::exception& e) { - return set_error(conf, SESSION_ERR_BAD_VALUE, e); - } - return 0; + return wrap_exceptions( + conf, + [&] { + unbox(conf)->set_name(name); + return 0; + }, + static_cast(SESSION_ERR_BAD_VALUE)); } profile_pic UserProfile::get_profile_pic() const { @@ -90,13 +91,13 @@ LIBSESSION_C_API int user_profile_set_pic(config_object* conf, user_profile_pic if (!url.empty()) key = {pic.key, 32}; - try { - unbox(conf)->set_profile_pic(url, key); - } catch (const std::exception& e) { - return set_error(conf, SESSION_ERR_BAD_VALUE, e); - } - - return 0; + return wrap_exceptions( + conf, + [&] { + unbox(conf)->set_profile_pic(url, key); + return 0; + }, + static_cast(SESSION_ERR_BAD_VALUE)); } void UserProfile::set_nts_priority(int priority) { diff --git a/src/curve25519.cpp b/src/curve25519.cpp 
index 81870cc3..1c8f11b3 100644 --- a/src/curve25519.cpp +++ b/src/curve25519.cpp @@ -10,9 +10,9 @@ namespace session::curve25519 { -std::pair, std::array> curve25519_key_pair() { +std::pair, std::array> curve25519_key_pair() { std::array curve_pk; - std::array curve_sk; + std::array curve_sk; crypto_box_keypair(curve_pk.data(), curve_sk.data()); return {curve_pk, curve_sk}; diff --git a/src/file.cpp b/src/file.cpp new file mode 100644 index 00000000..8f0ee656 --- /dev/null +++ b/src/file.cpp @@ -0,0 +1,38 @@ +#include +#include + +namespace session { + +std::ofstream open_for_writing(const fs::path& filename) { + std::ofstream out; + out.exceptions(std::ios_base::failbit | std::ios_base::badbit); + out.open(filename, std::ios_base::binary | std::ios_base::out | std::ios_base::trunc); + out.exceptions(std::ios_base::badbit); + return out; +} + +std::ifstream open_for_reading(const fs::path& filename) { + std::ifstream in; + in.exceptions(std::ios_base::failbit | std::ios_base::badbit); + in.open(filename, std::ios::binary | std::ios::in); + in.exceptions(std::ios_base::badbit); + return in; +} + +std::string read_whole_file(const fs::path& filename) { + auto in = open_for_reading(filename); + std::string contents; + in.seekg(0, std::ios::end); + auto size = in.tellg(); + in.seekg(0, std::ios::beg); + contents.resize(size); + in.read(contents.data(), size); + return contents; +} + +void write_whole_file(const fs::path& filename, std::string_view contents) { + auto out = open_for_writing(filename); + out.write(contents.data(), static_cast(contents.size())); +} + +} // namespace session diff --git a/src/hash.cpp b/src/hash.cpp index 09968b57..0b87599c 100644 --- a/src/hash.cpp +++ b/src/hash.cpp @@ -14,9 +14,11 @@ ustring hash(const size_t size, ustring_view msg, std::optional ke if (key && key->size() > crypto_generichash_blake2b_BYTES_MAX) throw std::invalid_argument{"Invalid key: expected less than 65 bytes"}; + auto result_code = 0; ustring result; result.resize(size); - crypto_generichash_blake2b( + + result_code = crypto_generichash_blake2b( result.data(), size, msg.data(), @@ -24,6 +26,9 @@ ustring hash(const size_t size, ustring_view msg, std::optional ke key ? key->data() : nullptr, key ? 
key->size() : 0); + if (result_code != 0) + throw std::runtime_error{"Hash generation failed"}; + return result; } diff --git a/src/logging.cpp b/src/logging.cpp new file mode 100644 index 00000000..779b1f66 --- /dev/null +++ b/src/logging.cpp @@ -0,0 +1,106 @@ +#include "session/logging.hpp" + +#include + +#include +#include +#include + +#include "oxen/log/level.hpp" +#include "session/export.h" + +namespace session { + +namespace log = oxen::log; + +LogLevel::LogLevel(spdlog::level::level_enum lvl) : level{static_cast(lvl)} {} + +spdlog::level::level_enum LogLevel::spdlog_level() const { + return static_cast(level); +} + +std::string_view LogLevel::to_string() const { + return log::to_string(spdlog_level()); +} + +void add_logger(std::function cb) { + log::add_sink(std::make_shared(std::move(cb))); +} +void add_logger( + std::function cb) { + log::add_sink(std::make_shared(std::move(cb))); +} + +void manual_log(std::string_view msg) { + log::info(oxen::log::Cat("manual"), "{}", msg); +} + +void logger_reset_level(LogLevel level) { + log::reset_level(level.spdlog_level()); +} +void logger_set_level_default(LogLevel level) { + log::set_level_default(level.spdlog_level()); +} +LogLevel logger_get_level_default() { + return log::get_level_default(); +} +void logger_set_level(std::string cat_name, LogLevel level) { + log::set_level(std::move(cat_name), level.spdlog_level()); +} +LogLevel logger_get_level(std::string cat_name) { + return log::get_level(std::move(cat_name)); +} + +void clear_loggers() { + log::clear_sinks(); +} + +} // namespace session + +extern "C" { + +LIBSESSION_C_API void session_add_logger_simple(void (*callback)(const char* msg, size_t msglen)) { + assert(callback); + session::add_logger( + [cb = std::move(callback)](std::string_view msg) { cb(msg.data(), msg.size()); }); +} + +LIBSESSION_C_API void session_add_logger_full(void (*callback)( + const char* msg, size_t msglen, const char* cat, size_t cat_len, LOG_LEVEL level)) { + assert(callback); + session::add_logger( + [cb = std::move(callback)]( + std::string_view msg, std::string_view category, session::LogLevel level) { + cb(msg.data(), + msg.size(), + category.data(), + category.size(), + static_cast(level.level)); + }); +} + +LIBSESSION_C_API void session_logger_reset_level(LOG_LEVEL level) { + oxen::log::reset_level(static_cast(level)); +} +LIBSESSION_C_API void session_logger_set_level_default(LOG_LEVEL level) { + oxen::log::set_level_default(static_cast(level)); +} +LIBSESSION_C_API LOG_LEVEL session_logger_get_level_default() { + return static_cast(oxen::log::get_level_default()); +} +LIBSESSION_C_API void session_logger_set_level(const char* cat_name, LOG_LEVEL level) { + oxen::log::set_level(cat_name, static_cast(level)); +} +LIBSESSION_C_API LOG_LEVEL session_logger_get_level(const char* cat_name) { + return static_cast(oxen::log::get_level(cat_name)); +} + +LIBSESSION_C_API void session_manual_log(const char* msg) { + session::manual_log(msg); +} + +LIBSESSION_C_API void session_clear_loggers() { + session::clear_loggers(); +} + +} // extern "C" \ No newline at end of file diff --git a/src/network.cpp b/src/network.cpp new file mode 100644 index 00000000..37b60da8 --- /dev/null +++ b/src/network.cpp @@ -0,0 +1,3262 @@ +#include "session/network.hpp" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "session/blinding.hpp" +#include "session/ed25519.hpp" +#include "session/export.h" +#include 
"session/file.hpp" +#include "session/network.h" +#include "session/onionreq/builder.h" +#include "session/onionreq/builder.hpp" +#include "session/onionreq/key_types.hpp" +#include "session/onionreq/response_parser.hpp" +#include "session/util.hpp" + +using namespace oxen; +using namespace session::onionreq; +using namespace std::literals; +using namespace oxen::log::literals; + +namespace session::network { + +namespace { + + inline auto cat = log::Cat("network"); + + class load_cache_exception : public std::runtime_error { + public: + load_cache_exception(std::string message) : std::runtime_error(message) {} + }; + class status_code_exception : public std::runtime_error { + public: + int16_t status_code; + std::vector> headers; + + status_code_exception( + int16_t status_code, + std::vector> headers, + std::string message) : + std::runtime_error(message), status_code{status_code}, headers{headers} {} + }; + + constexpr int16_t error_network_suspended = -10001; + constexpr int16_t error_building_onion_request = -10002; + constexpr int16_t error_path_build_timeout = -10003; + + const std::pair content_type_plain_text = { + "Content-Type", "text/plain; charset=UTF-8"}; + const std::pair content_type_json = { + "Content-Type", "application/json"}; + + // The amount of time the snode cache can be used before it needs to be refreshed/ + constexpr auto snode_cache_expiration_duration = 2h; + + // The smallest size the snode cache can get to before we need to fetch more. + constexpr size_t min_snode_cache_count = 12; + + // The number of snodes to use to refresh the cache. + constexpr int num_snodes_to_refresh_cache_from = 3; + + // The number of snodes (including the guard snode) in a path. + constexpr uint8_t path_size = 3; + + // The number of times a path can fail before it's replaced. + constexpr uint16_t path_failure_threshold = 3; + + // The number of times a snode can fail before it's replaced. 
+ constexpr uint16_t snode_failure_threshold = 3; + + // The frequency to check if queued requests have timed out due to a pending path build + constexpr auto queued_request_path_build_timeout_frequency = 250ms; + + const fs::path default_cache_path{u8"."}, file_testnet{u8"testnet"}, + file_snode_pool{u8"snode_pool"}; + const std::vector legacy_files{ + u8"snode_pool_updated", u8"swarm", u8"snode_failure_counts"}; + + constexpr auto node_not_found_prefix = "502 Bad Gateway\n\nNext node not found: "sv; + constexpr auto node_not_found_prefix_no_status = "Next node not found: "sv; + constexpr auto ALPN = "oxenstorage"sv; + constexpr auto ONION = "onion_req"; + + enum class PathSelectionBehaviour { + random, + new_or_least_busy, + }; + + std::string path_type_name(PathType path_type, bool single_path_mode) { + if (single_path_mode) + return "single_path"; + + switch (path_type) { + case PathType::standard: return "standard"; + case PathType::upload: return "upload"; + case PathType::download: return "download"; + } + return "standard"; // Default + } + + // The minimum number of paths we want to maintain + uint8_t min_path_count(PathType path_type, bool single_path_mode) { + if (single_path_mode) + return 1; + + switch (path_type) { + case PathType::standard: return 2; + case PathType::upload: return 2; + case PathType::download: return 2; + } + return 2; // Default + } + + PathSelectionBehaviour path_selection_behaviour(PathType path_type) { + switch (path_type) { + case PathType::standard: return PathSelectionBehaviour::random; + case PathType::upload: return PathSelectionBehaviour::new_or_least_busy; + case PathType::download: return PathSelectionBehaviour::new_or_least_busy; + } + return PathSelectionBehaviour::random; // Default + } + + /// Converts a string such as "1.2.3" to a vector of ints {1,2,3}. Throws if something + /// in/around the .'s isn't parseable as an integer.
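For reference, the version parsing described in the comment above can be sketched with just the standard library. This is an illustrative helper only (parse_version_sketch is a made-up name); the actual implementation that follows uses session::split and quic::parse_int:

    #include <sstream>
    #include <stdexcept>
    #include <string>
    #include <vector>

    // Illustrative sketch: "1.2.3" -> {1, 2, 3}; "2.8.0" -> {2, 8}; "0.0.0" -> {0}.
    std::vector<int> parse_version_sketch(const std::string& vers) {
        std::vector<int> result;
        std::string piece;
        for (std::istringstream ss{vers}; std::getline(ss, piece, '.');) {
            size_t used = 0;
            result.push_back(std::stoi(piece, &used));  // throws on empty/non-numeric pieces
            if (used != piece.size())
                throw std::invalid_argument{"Invalid version: " + vers};
        }
        // Trim trailing zeroes, but never below a single component
        while (result.size() > 1 && result.back() == 0)
            result.pop_back();
        return result;
    }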
+ std::vector parse_version(std::string_view vers, bool trim_trailing_zero = true) { + auto v_s = session::split(vers, "."); + std::vector result; + for (const auto& piece : v_s) + if (!quic::parse_int(piece, result.emplace_back())) + throw std::invalid_argument{"Invalid version"}; + + // Remove any trailing `0` values (but ensure we at least end up with a "0" version) + if (trim_trailing_zero) + while (result.size() > 1 && result.back() == 0) + result.pop_back(); + + return result; + } + + service_node node_from_json(nlohmann::json json) { + auto pk_ed = json["pubkey_ed25519"].get(); + if (pk_ed.size() != 64 || !oxenc::is_hex(pk_ed)) + throw std::invalid_argument{ + "Invalid service node json: pubkey_ed25519 is not a valid, hex pubkey"}; + + // When parsing a node from JSON it'll generally be from the 'get_swarm` endpoint or a 421 + // error neither of which contain the `storage_server_version` - luckily we don't need the + // version for these two cases so can just default it to `0` + std::vector storage_server_version = {0}; + if (json.contains("storage_server_version")) { + if (json["storage_server_version"].is_array()) { + if (json["storage_server_version"].size() > 0) { + // Convert the version to a string and parse it back into a version code to + // ensure the version formats remain consistent throughout + storage_server_version = json["storage_server_version"].get>(); + storage_server_version = + parse_version("{}"_format(fmt::join(storage_server_version, "."))); + } + } else + storage_server_version = + parse_version(json["storage_server_version"].get()); + } + + std::string ip; + if (json.contains("public_ip")) + ip = json["public_ip"].get(); + else + ip = json["ip"].get(); + + uint16_t port; + if (json.contains("storage_lmq_port")) + port = json["storage_lmq_port"].get(); + else + port = json["port_omq"].get(); + + swarm_id_t swarm_id = INVALID_SWARM_ID; + if (json.contains("swarm_id")) + swarm_id = json["swarm_id"].get(); + + return {oxenc::from_hex(pk_ed), storage_server_version, swarm_id, ip, port}; + } + + service_node node_from_disk(std::string_view str, bool can_ignore_version = false) { + // Format is "{ip}|{port}|{version}|{ed_pubkey}|{swarm_id}" + auto parts = split(str, "|"); + if (parts.size() != 5) + throw std::invalid_argument("Invalid service node serialisation: {}"_format(str)); + if (parts[3].size() != 64 || !oxenc::is_hex(parts[3])) + throw std::invalid_argument{ + "Invalid service node serialisation: pubkey is not hex or has wrong size"}; + + uint16_t port; + if (!quic::parse_int(parts[1], port)) + throw std::invalid_argument{"Invalid service node serialization: invalid port"}; + + std::vector storage_server_version = parse_version(parts[2]); + if (!can_ignore_version && storage_server_version == std::vector{0}) + throw std::invalid_argument{"Invalid service node serialization: invalid version"}; + + swarm_id_t swarm_id = INVALID_SWARM_ID; + quic::parse_int(parts[4], swarm_id); + + return { + oxenc::from_hex(parts[3]), // ed25519_pubkey + storage_server_version, // storage_server_version + swarm_id, // swarm_id + std::string(parts[0]), // ip + port, // port + }; + } + + const std::vector seed_nodes_testnet{ + node_from_disk("144.76.164.202|35400|2.8.0|" + "decaf007f26d3d6f9b845ad031ffdf6d04638c25bb10b8fffbbe99135303c4b9|"sv)}; + const std::vector seed_nodes_mainnet{ + node_from_disk("144.76.164.202|20200|2.8.0|" + "1f000f09a7b07828dcb72af7cd16857050c10c02bd58afb0e38111fb6cda1fef|"sv), + node_from_disk("88.99.102.229|20201|2.8.0|" + 
"1f101f0acee4db6f31aaa8b4df134e85ca8a4878efaef7f971e88ab144c1a7ce|"sv), + node_from_disk("195.16.73.17|20202|2.8.0|" + "1f202f00f4d2d4acc01e20773999a291cf3e3136c325474d159814e06199919f|"sv), + node_from_disk("104.194.11.120|20203|2.8.0|" + "1f303f1d7523c46fa5398826740d13282d26b5de90fbae5749442f66afb6d78b|"sv), + node_from_disk("104.194.8.115|20204|2.8.0|" + "1f604f1c858a121a681d8f9b470ef72e6946ee1b9c5ad15a35e16b50c28db7b0|"sv)}; + constexpr auto file_server = "filev2.getsession.org"sv; + constexpr auto file_server_pubkey = + "da21e1d886c6fbaea313f75298bd64aab03a97ce985b46bb2dad9f2089c8ee59"sv; + + std::string node_to_disk(service_node node) { + // Format is "{ip}|{port}|{version}|{ed_pubkey}|{swarm_id}" + auto ed25519_pubkey_hex = oxenc::to_hex(node.view_remote_key()); + + return fmt::format( + "{}|{}|{}|{}|{}", + node.host(), + node.port(), + "{}"_format(fmt::join(node.storage_server_version, ".")), + ed25519_pubkey_hex, + node.swarm_id); + } + + session::onionreq::x25519_pubkey compute_xpk(ustring_view ed25519_pk) { + std::array xpk; + if (0 != crypto_sign_ed25519_pk_to_curve25519(xpk.data(), ed25519_pk.data())) + throw std::runtime_error{ + "An error occured while attempting to convert Ed25519 pubkey to X25519; " + "is the pubkey valid?"}; + return session::onionreq::x25519_pubkey::from_bytes({xpk.data(), 32}); + } + + std::string consume_string(oxenc::bt_dict_consumer dict, std::string_view key) { + if (!dict.skip_until(key)) + throw std::invalid_argument{ + "Unable to find entry in dict for key '" + std::string(key) + "'"}; + return dict.consume_string(); + } + + template + auto consume_integer(oxenc::bt_dict_consumer dict, std::string_view key) { + if (!dict.skip_until(key)) + throw std::invalid_argument{ + "Unable to find entry in dict for key '" + std::string(key) + "'"}; + return dict.next_integer().second; + } +} // namespace + +namespace detail { + swarm_id_t pubkey_to_swarm_space(const session::onionreq::x25519_pubkey& pk) { + swarm_id_t res = 0; + for (size_t i = 0; i < 4; i++) { + swarm_id_t buf; + std::memcpy(&buf, pk.data() + i * 8, 8); + res ^= buf; + } + oxenc::big_to_host_inplace(res); + + return res; + } + + std::vector>> generate_swarms( + std::vector nodes) { + std::vector>> result; + std::unordered_map> _grouped_nodes; + + for (const auto& node : nodes) + _grouped_nodes[node.swarm_id].push_back(node); + + for (auto& [swarm_id, nodes] : _grouped_nodes) + result.emplace_back(swarm_id, std::move(nodes)); + + std::sort(result.begin(), result.end(), [](const auto& a, const auto& b) { + return a.first < b.first; + }); + return result; + } + + std::optional node_for_destination(network_destination destination) { + if (auto* dest = std::get_if(&destination)) + return *dest; + + return std::nullopt; + } + + nlohmann::json get_service_nodes_params(std::optional limit) { + nlohmann::json params{ + {"active_only", true}, + {"fields", + {{"public_ip", true}, + {"pubkey_ed25519", true}, + {"storage_lmq_port", true}, + {"storage_server_version", true}, + {"swarm_id", true}}}}; + + if (limit) + params["limit"] = *limit; + + return params; + } + + std::vector process_get_service_nodes_response( + oxenc::bt_list_consumer result_bencode) { + std::vector result; + result_bencode.skip_value(); // Skip the status code (already validated) + auto response_dict = result_bencode.consume_dict_consumer(); + response_dict.skip_until("result"); + + auto result_dict = response_dict.consume_dict_consumer(); + result_dict.skip_until("service_node_states"); + + // Process the node list + auto node = 
result_dict.consume_list_consumer(); + + while (!node.is_finished()) { + auto node_consumer = node.consume_dict_consumer(); + auto pubkey_ed25519 = oxenc::from_hex(consume_string(node_consumer, "pubkey_ed25519")); + auto public_ip = consume_string(node_consumer, "public_ip"); + auto storage_lmq_port = consume_integer(node_consumer, "storage_lmq_port"); + + std::vector storage_server_version; + node_consumer.skip_until("storage_server_version"); + auto version_consumer = node_consumer.consume_list_consumer(); + auto swarm_id = consume_integer(node_consumer, "swarm_id"); + + while (!version_consumer.is_finished()) { + storage_server_version.emplace_back(version_consumer.consume_integer()); + } + + result.emplace_back( + pubkey_ed25519, storage_server_version, swarm_id, public_ip, storage_lmq_port); + } + + return result; + } + + std::vector process_get_service_nodes_response(nlohmann::json response_json) { + if (!response_json.contains("result") || !response_json["result"].is_object()) + throw std::runtime_error{"JSON missing result field."}; + + nlohmann::json result_json = response_json["result"]; + if (!result_json.contains("service_node_states") || + !result_json["service_node_states"].is_array()) + throw std::runtime_error{"JSON missing service_node_states field."}; + + std::vector result; + for (auto& snode : result_json["service_node_states"]) + result.emplace_back(node_from_json(snode)); + + return result; + } + + void log_retry_result_if_needed(request_info info, bool single_path_mode) { + if (!info.retry_reason) + return; + + // For debugging purposes if the error was a redirect retry then + // we want to log that the retry was successful as this will + // help identify how often we are receiving incorrect errors + auto reason = "unknown retry"; + + switch (*info.retry_reason) { + case request_info::RetryReason::none: reason = "unknown retry"; break; + case request_info::RetryReason::redirect: reason = "421 retry"; break; + case request_info::RetryReason::decryption_failure: reason = "decryption error"; break; + case request_info::RetryReason::redirect_swarm_refresh: + reason = "421 swarm refresh retry"; + break; + } + + log::info( + cat, + "Received valid response after {} in request {} for {}.", + reason, + info.request_id, + path_type_name(info.path_type, single_path_mode)); + } + + std::vector convert_service_nodes( + std::vector nodes) { + std::vector converted_nodes; + for (auto& node : nodes) { + auto ed25519_pubkey_hex = oxenc::to_hex(node.view_remote_key()); + auto ipv4 = node.to_ipv4(); + network_service_node converted_node; + converted_node.ip[0] = (ipv4.addr >> 24) & 0xFF; + converted_node.ip[1] = (ipv4.addr >> 16) & 0xFF; + converted_node.ip[2] = (ipv4.addr >> 8) & 0xFF; + converted_node.ip[3] = ipv4.addr & 0xFF; + strncpy(converted_node.ed25519_pubkey_hex, ed25519_pubkey_hex.c_str(), 64); + converted_node.ed25519_pubkey_hex[64] = '\0'; // Ensure null termination + converted_node.quic_port = node.port(); + converted_nodes.push_back(converted_node); + } + + return converted_nodes; + } + + ServerDestination convert_server_destination(const network_server_destination server) { + std::optional>> headers; + if (server.headers_size > 0) { + headers = std::vector>{}; + + for (size_t i = 0; i < server.headers_size; i++) + headers->emplace_back(server.headers[i], server.header_values[i]); + } + + return ServerDestination{ + server.protocol, + server.host, + server.endpoint, + x25519_pubkey::from_hex({server.x25519_pubkey, 64}), + server.port, + headers, + server.method}; + } +} 
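The pubkey_to_swarm_space() helper above is what ties a client pubkey to the sorted swarm list: the 32-byte x25519 key is folded into the same 64-bit space that the swarm ids occupy. A self-contained sketch of that fold, assuming only the standard library (the patch itself uses memcpy plus oxenc::big_to_host_inplace for the endian handling):

    #include <array>
    #include <cstddef>
    #include <cstdint>

    // XOR the four 8-byte chunks of the pubkey together, reading each chunk as a
    // big-endian 64-bit integer; the result is the key's position in "swarm space".
    uint64_t swarm_space_sketch(const std::array<unsigned char, 32>& pk) {
        uint64_t res = 0;
        for (size_t i = 0; i < 4; i++) {
            uint64_t chunk = 0;
            for (size_t j = 0; j < 8; j++)
                chunk = (chunk << 8) | pk[i * 8 + j];  // most significant byte first
            res ^= chunk;
        }
        return res;
    }

get_swarm() further down then picks whichever swarm id is numerically closest to this value, with wrap-around at the ends of the sorted list.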
// namespace detail + +request_info request_info::make( + onionreq::network_destination _dest, + std::optional _original_body, + std::optional _swarm_pk, + std::chrono::milliseconds _request_timeout, + std::optional _request_and_path_build_timeout, + PathType _type, + std::optional _req_id, + std::optional _ep, + std::optional _body) { + return request_info{ + _req_id.value_or("R-{}"_format(random::random_base32(4))), + std::move(_dest), + _ep.value_or(ONION), + std::move(_body), + std::move(_original_body), + std::move(_swarm_pk), + _type, + _request_timeout, + _request_and_path_build_timeout}; +} + +std::string onion_path::to_string() const { + std::vector node_descriptions; + std::transform( + nodes.begin(), + nodes.end(), + std::back_inserter(node_descriptions), + [](const service_node& node) { return node.to_string(); }); + + return "{}"_format(fmt::join(node_descriptions, ", ")); +} + +bool onion_path::contains_node(const service_node& sn) const { + for (auto& n : nodes) { + if (n == sn) + return true; + } + + return false; +} + +// MARK: Initialization + +Network::Network( + std::optional cache_path, + bool use_testnet, + bool single_path_mode, + bool pre_build_paths) : + use_testnet{use_testnet}, + should_cache_to_disk{cache_path}, + single_path_mode{single_path_mode}, + cache_path{cache_path.value_or(default_cache_path)} { + // Load the cache from disk and start the disk write thread + if (should_cache_to_disk) { + load_cache_from_disk(); + disk_write_thread = std::thread{&Network::disk_write_thread_loop, this}; + } + + // Kick off a separate thread to build paths (may as well kick this off early) + if (pre_build_paths) + for (int i = 0; i < min_path_count(PathType::standard, single_path_mode); ++i) { + auto path_id = "P-{}"_format(random::random_base32(4)); + in_progress_path_builds[path_id] = PathType::standard; + net.call_soon([this, path_id] { build_path(path_id, PathType::standard); }); + } +} + +Network::~Network() { + // We need to explicitly close the connections at the start of the destructor to prevent ban + // memory errors due to complex logic with the quic::Network instance + destroyed = true; + _close_connections(); + + { + std::lock_guard lock{snode_cache_mutex}; + shut_down_disk_thread = true; + } + update_disk_cache_throttled(true); + if (disk_write_thread.joinable()) + disk_write_thread.join(); +} + +// MARK: Cache Management + +void Network::load_cache_from_disk() { + try { + // If the cache is for the wrong network then delete everything + auto testnet_stub = cache_path / file_testnet; + if (use_testnet != fs::exists(testnet_stub) && fs::exists(testnet_stub)) + fs::remove_all(cache_path); + + // Remove any legacy files (don't want to leave old data around) + for (const auto& path : legacy_files) { + auto path_to_remove = cache_path / path; + fs::remove_all(path_to_remove); + } + + // If we are using testnet then create a file to indicate that + if (use_testnet) + write_whole_file(testnet_stub); + + // Load the snode pool + if (auto pool_path = cache_path / file_snode_pool; fs::exists(pool_path)) { + auto ftime = fs::last_write_time(pool_path); + last_snode_cache_update = + std::chrono::time_point_cast( + ftime - fs::file_time_type::clock::now() + + std::chrono::system_clock::now()); + + auto file = open_for_reading(pool_path); + std::vector loaded_cache; + std::string line; + auto invalid_entries = 0; + + while (std::getline(file, line)) { + try { + loaded_cache.push_back(node_from_disk(line)); + } catch (...) 
{ + ++invalid_entries; + } + } + + if (invalid_entries > 0) + log::warning(cat, "Skipped {} invalid entries in snode cache.", invalid_entries); + + snode_cache = loaded_cache; + all_swarms = detail::generate_swarms(loaded_cache); + } + + log::info( + cat, + "Loaded cache of {} snodes, {} swarms.", + snode_cache.size(), + all_swarms.size()); + } catch (const std::exception& e) { + log::error(cat, "Failed to load snode cache, will rebuild ({}).", e.what()); + + if (fs::exists(cache_path)) + fs::remove_all(cache_path); + } +} + +void Network::update_disk_cache_throttled(bool force_immediate_write) { + // If we are forcing an immediate write then just notify the disk write thread and reset the + // pending write flag + if (force_immediate_write) { + snode_cache_cv.notify_one(); + has_pending_disk_write = false; + return; + } + + if (has_pending_disk_write) + return; + + has_pending_disk_write = true; + net.call_later(1s, [this]() { + snode_cache_cv.notify_one(); + has_pending_disk_write = false; + }); +} + +void Network::disk_write_thread_loop() { + std::unique_lock lock{snode_cache_mutex}; + while (true) { + snode_cache_cv.wait( + lock, [this] { return need_write || need_clear_cache || shut_down_disk_thread; }); + + if (need_write) { + // Make a local copy so that we can release the lock and not + // worry about other threads wanting to change things + auto snode_cache_write = snode_cache; + + lock.unlock(); + { + try { + // Create the cache directories if needed + fs::create_directories(cache_path); + + // If we are using testnet then create a file to indicate that + if (use_testnet) { + auto testnet_stub = cache_path / file_testnet; + write_whole_file(testnet_stub); + } + + // Save the snode pool to disk + auto pool_path = cache_path / file_snode_pool, pool_tmp = pool_path; + pool_tmp += u8"_new"; + + { + std::stringstream ss; + for (auto& snode : snode_cache_write) + ss << node_to_disk(snode) << '\n'; + + std::ofstream file(pool_tmp, std::ios::binary); + file << ss.rdbuf(); + } + + fs::rename(pool_tmp, pool_path); + need_write = false; + + log::debug(cat, "Finished writing snode cache to disk."); + } catch (const std::exception& e) { + log::error(cat, "Failed to write snode cache: {}", e.what()); + } + } + lock.lock(); + } + if (need_clear_cache) { + snode_cache = {}; + + lock.unlock(); + if (fs::exists(cache_path)) + fs::remove_all(cache_path); + lock.lock(); + need_clear_cache = false; + } + if (shut_down_disk_thread) + return; + } +} + +void Network::clear_cache() { + net.call([this]() mutable { + { + std::lock_guard lock{snode_cache_mutex}; + need_clear_cache = true; + } + update_disk_cache_throttled(true); + }); +} + +size_t Network::snode_cache_size() { + return net.call_get([this]() -> size_t { return snode_cache.size(); }); +} + +// MARK: Connection + +void Network::suspend() { + net.call([this]() mutable { + suspended = true; + close_connections(); + log::info(cat, "Suspended."); + }); +} + +void Network::resume() { + net.call([this]() mutable { + suspended = false; + log::info(cat, "Resumed."); + }); +} + +void Network::close_connections() { + net.call([this]() mutable { _close_connections(); }); +} + +void Network::_close_connections() { + // Explicitly reset the endpoint to close all connections + endpoint.reset(); + + // Cancel any pending requests (they can't succeed once the connection is closed) + for (const auto& [path_type, path_type_requests] : request_queue) + for (const auto& [info, callback] : path_type_requests) + callback( + false, + false, + 
error_network_suspended, + {content_type_plain_text}, + "Network is suspended."); + + // Clear all storage of requests, paths and connections so that we are in a fresh state on + // relaunch + request_queue.clear(); + paths.clear(); + path_build_queue.clear(); + paths_pending_drop.clear(); + unused_connections.clear(); + in_progress_connections.clear(); + snode_refresh_results.reset(); + current_snode_cache_refresh_request_id = std::nullopt; + + update_status(ConnectionStatus::disconnected); + log::info(cat, "Closed all connections."); +} + +void Network::update_status(ConnectionStatus updated_status) { + // Ignore updates which don't change the status + if (status == updated_status) + return; + + // If we are already 'connected' then ignore 'connecting' status changes (if we drop one path + // and build another in the background this can happen) + if (status == ConnectionStatus::connected && updated_status == ConnectionStatus::connecting) + return; + + // Store the updated status + status = updated_status; + + if (!status_changed) + return; + + status_changed(updated_status); +} + +std::chrono::milliseconds Network::retry_delay( + int num_failures, std::chrono::milliseconds max_delay) { + return std::chrono::milliseconds(std::min( + max_delay.count(), + static_cast(100 * std::pow(2, num_failures)))); +} + +std::shared_ptr Network::get_endpoint() { + return net.call_get([this]() mutable { + if (!endpoint) + endpoint = net.endpoint(quic::Address{"0.0.0.0", 0}, quic::opt::alpns{ALPN}); + + return endpoint; + }); +} + +// MARK: Request Queues and Path Building + +size_t Network::min_snode_cache_size() const { + if (!seed_node_cache_size) + return min_snode_cache_count; + + // If the seed node cache size is somehow smaller than `min_snode_cache_count` (ie. 
Testnet + // having issues) then the minimum size should be the full cache size (minus enough to build a + // path) or at least the size of a path + auto min_path_size = static_cast(path_size); + return std::min( + std::max(min_path_size, *seed_node_cache_size - min_path_size), min_snode_cache_count); +} + +std::vector Network::get_unused_nodes() { + if (snode_cache.size() < min_snode_cache_size()) + return {}; + + // Exclude any IPs that are already in use from existing paths + std::vector node_ips_to_exlude = all_path_ips(); + + // Exclude unused connections + for (const auto& conn_info : unused_connections) + node_ips_to_exlude.emplace_back(conn_info.node.to_ipv4()); + + // Exclude in progress connections + for (const auto& [request_id, node] : in_progress_connections) + node_ips_to_exlude.emplace_back(node.to_ipv4()); + + // Exclude pending requests + for (const auto& [path_type, path_type_requests] : request_queue) + for (const auto& [info, callback] : path_type_requests) + if (auto* dest = std::get_if(&info.destination)) + node_ips_to_exlude.emplace_back(dest->to_ipv4()); + + // Exclude any nodes which have surpassed the failure threshold + for (const auto& [node_string, failure_count] : snode_failure_counts) + if (failure_count >= snode_failure_threshold) { + size_t colon_pos = node_string.find(':'); + + if (colon_pos != std::string::npos) + node_ips_to_exlude.emplace_back(quic::ipv4{node_string.substr(0, colon_pos)}); + else + node_ips_to_exlude.emplace_back(quic::ipv4{node_string}); + } + + // Populate the unused nodes with any nodes in the cache which shouldn't be excluded + std::vector result; + + if (node_ips_to_exlude.empty()) + result = snode_cache; + else + std::copy_if( + snode_cache.begin(), + snode_cache.end(), + std::back_inserter(result), + [&node_ips_to_exlude](const auto& node) { + return std::find( + node_ips_to_exlude.begin(), + node_ips_to_exlude.end(), + node.to_ipv4()) == node_ips_to_exlude.end(); + }); + + // Shuffle the `result` so anything that uses it would get random nodes + std::shuffle(result.begin(), result.end(), csrng); + + return result; +} + +void Network::establish_connection( + std::string id, + service_node target, + std::optional timeout, + std::function error)> callback) { + log::trace(cat, "{} called for {}.", __PRETTY_FUNCTION__, id); + auto currently_suspended = net.call_get([this]() -> bool { return suspended; }); + + // If the network is currently suspended then don't try to open a connection + if (currently_suspended) + return callback( + {target, std::make_shared(0), nullptr, nullptr}, "Network is suspended."); + + auto conn_key_pair = ed25519::ed25519_key_pair(); + auto creds = quic::GNUTLSCreds::make_from_ed_seckey(from_unsigned_sv(conn_key_pair.second)); + auto cb_called = std::make_shared(); + auto cb = std::make_shared)>>( + std::move(callback)); + auto conn_promise = std::promise>(); + auto conn_future = conn_promise.get_future().share(); + auto handshake_timeout = + timeout ? 
std::optional{quic::opt::handshake_timeout{ + std::chrono::duration_cast(*timeout)}} + : std::nullopt; + + auto c = get_endpoint()->connect( + target, + creds, + quic::opt::keep_alive{10s}, + handshake_timeout, + [this, id, target, cb, cb_called, conn_future](quic::connection_interface&) mutable { + log::trace(cat, "Connection established for {}.", id); + + // Just in case, call it within a `net.call` + net.call([&] { + std::call_once(*cb_called, [&]() { + if (cb) { + auto conn = conn_future.get(); + (*cb)({target, + std::make_shared(0), + conn, + conn->open_stream()}, + std::nullopt); + cb.reset(); + } + }); + }); + }, + [this, target, id, cb, cb_called, conn_future]( + quic::connection_interface& conn, uint64_t error_code) mutable { + if (error_code == static_cast(NGTCP2_ERR_HANDSHAKE_TIMEOUT)) + log::info( + cat, + "Unable to establish connection to {} for {}.", + target.to_string(), + id); + else + log::info(cat, "Connection to {} closed for {}.", target.to_string(), id); + + // Just in case, call it within a `net.call` + net.call([&] { + // Trigger the callback first before updating the paths in case this was + // triggered when try to establish a connection + std::call_once(*cb_called, [&]() { + if (cb) { + (*cb)({target, std::make_shared(0), nullptr, nullptr}, + std::nullopt); + cb.reset(); + } + }); + + // If the Network instance has been `destroyed` (ie. it's destructor has been + // called) then don't do any of the following logic as it'll likely result in + // undefined behaviours and crashes + if (destroyed) + return; + + // Remove the connection from `unused_connection` if present + std::erase_if(unused_connections, [&conn, &target](auto& unused_conn) { + return (unused_conn.node == target && unused_conn.conn && + unused_conn.conn->reference_id() == conn.reference_id()); + }); + + // If this connection is being used in an existing path then we should drop it + // (as the path is no longer valid) + for (const auto& [path_type, paths_for_type] : paths) { + for (const auto& path : paths_for_type) { + if (!path.nodes.empty() && path.nodes.front() == target && + path.conn_info.conn && + conn.reference_id() == path.conn_info.conn->reference_id()) { + drop_path_when_empty(id, path_type, path); + break; + } + } + } + + // Since a connection was closed we should also clear any pending path drops + // in case this connection was one of those + clear_empty_pending_path_drops(); + + // If the connection failed with a handshake timeout then the node is + // unreachable, either due to a device network issue or because the node + // is down so set the failure count to the failure threshold so it won't + // be used for subsequent requests + if (error_code == static_cast(NGTCP2_ERR_HANDSHAKE_TIMEOUT)) + snode_failure_counts[target.to_string()] = snode_failure_threshold; + }); + }); + + conn_promise.set_value(c); +} + +void Network::establish_and_store_connection(std::string path_id) { + // If we are suspended then don't try to establish a new connection + if (suspended) + return; + + // If we haven't set a connection status yet then do so now + if (status == ConnectionStatus::unknown) + update_status(ConnectionStatus::connecting); + + // Re-populate the unused nodes if it ends up being empty + if (unused_nodes.empty()) + unused_nodes = get_unused_nodes(); + + // If there aren't enough unused nodes then trigger a cache refresh + if (unused_nodes.size() < min_snode_cache_size()) { + log::trace( + cat, + "Unable to establish new connection due to lack of unused nodes, refreshing snode " + "cache 
({}).", + path_id); + return net.call_soon([this, path_id]() { refresh_snode_cache(path_id); }); + } + + // Otherwise check if it's been too long since the last cache update and, if so, trigger a + // refresh + auto cache_lifetime = std::chrono::duration_cast( + std::chrono::system_clock::now() - last_snode_cache_update); + + if (cache_lifetime < 0s || cache_lifetime > snode_cache_expiration_duration) + net.call_soon([this]() { refresh_snode_cache(); }); + + // If there are no in progress connections then reset the failure count + if (in_progress_connections.empty()) + connection_failures = 0; + + // Grab a node from the `unused_nodes` list to establish a connection to + auto target_node = unused_nodes.back(); + unused_nodes.pop_back(); + + // Try to establish a new connection to the target (this has a 3s handshake timeout as we + // wouldn't want to use any nodes which take longer than that anyway) + log::info(cat, "Establishing connection to {} for {}.", target_node.to_string(), path_id); + in_progress_connections.emplace(path_id, target_node); + + establish_connection( + path_id, + target_node, + 3s, + [this, target_node, path_id](connection_info info, std::optional) { + // If we failed to get a connection then try again after a delay (may as well try + // indefinitely because there is no way to recover from this issue) + if (!info.is_valid()) { + connection_failures++; + auto connection_retry_delay = retry_delay(connection_failures); + log::error( + cat, + "Failed to connect to {}, will try another after {}ms.", + target_node.to_string(), + connection_retry_delay.count()); + return net.call_later(connection_retry_delay, [this, path_id]() { + establish_and_store_connection(path_id); + }); + } + + // We were able to connect to the node so add it to the unused_connections queue + log::info(cat, "Connection to {} valid for {}.", target_node.to_string(), path_id); + unused_connections.emplace_back(info); + + // Kick off the next pending path build since we now have a valid connection + if (!path_build_queue.empty()) { + in_progress_path_builds[path_id] = path_build_queue.front(); + net.call_soon([this, path_type = path_build_queue.front(), path_id]() { + build_path(path_id, path_type); + }); + path_build_queue.pop_front(); + } + + // If there are still pending path builds but no in progress connections then kick + // off enough additional connections for remaining builds (this shouldn't happen but + // better to be safe and avoid a situation where a path build gets orphaned) + if (!path_build_queue.empty() && in_progress_connections.empty()) + for ([[maybe_unused]] const auto& _ : path_build_queue) + net.call_soon([this]() { + auto conn_id = "EC-{}"_format(random::random_base32(4)); + establish_and_store_connection(conn_id); + }); + }); +} + +void Network::refresh_snode_cache_complete(std::vector nodes) { + // Shuffle the nodes so we don't have a specific order + std::shuffle(nodes.begin(), nodes.end(), csrng); + + // Update the disk cache if the snode pool was updated + { + std::lock_guard lock{snode_cache_mutex}; + snode_cache = nodes; + last_snode_cache_update = std::chrono::system_clock::now(); + need_write = true; + } + update_disk_cache_throttled(); + + // Reset the cache refresh state + current_snode_cache_refresh_request_id = std::nullopt; + snode_cache_refresh_failure_count = 0; + in_progress_snode_cache_refresh_count = 0; + unused_snode_refresh_nodes = std::nullopt; + snode_refresh_results.reset(); + + // Reset the snode failure counts (assume if the snode refresh includes + // 
nodes then they are valid) + snode_failure_counts.clear(); + + // Since we've updated the snode cache the swarm cache could be invalid + // so we need to regenerate it (the resulting `all_swarms` needs to be + // stored in ascending order as it is required for the logic to find the + // appropriate swarm for a given pubkey) + all_swarms.clear(); + swarm_cache.clear(); + all_swarms = detail::generate_swarms(nodes); + + // Run any post-refresh processes + for (const auto& callback : after_snode_cache_refresh) + net.call_soon([cb = std::move(callback)]() { cb(); }); + after_snode_cache_refresh.clear(); + + // Resume any queued path builds + for (const auto& path_type : path_build_queue) { + auto path_id = "P-{}"_format(random::random_base32(4)); + in_progress_path_builds[path_id] = path_type; + net.call_soon([this, path_type, path_id]() { build_path(path_id, path_type); }); + } + path_build_queue.clear(); +} + +void Network::refresh_snode_cache_from_seed_nodes(std::string request_id, bool reset_unused_nodes) { + if (suspended) { + log::info(cat, "Ignoring snode cache refresh as network is suspended ({}).", request_id); + return; + } + + // Only allow a single cache refresh at a time (this gets cleared in `_close_connections` so if + // it happens to loop after going to, and returning from, the background a subsequent refresh + // won't be blocked) + if (current_snode_cache_refresh_request_id && + current_snode_cache_refresh_request_id != request_id) { + log::info( + cat, + "Snode cache refresh from seed node {} ignored as it doesn't match the current " + "refresh id ({}).", + request_id, + current_snode_cache_refresh_request_id.value_or("NULL")); + return; + } + + // If the unused nodes is empty then reset it (if we are refreshing from seed nodes it means the + // local cache is not usable so we are just going to have to call this endlessly until it works) + if (reset_unused_nodes || !unused_snode_refresh_nodes || unused_snode_refresh_nodes->empty()) { + log::info( + cat, + "Existing cache is insufficient, refreshing from seed nodes ({}).", + request_id); + + // Shuffle to ensure we pick random nodes to fetch from + unused_snode_refresh_nodes = (use_testnet ? 
seed_nodes_testnet : seed_nodes_mainnet); + std::shuffle(unused_snode_refresh_nodes->begin(), unused_snode_refresh_nodes->end(), csrng); + } + + auto target_node = unused_snode_refresh_nodes->back(); + unused_snode_refresh_nodes->pop_back(); + + establish_connection( + request_id, + target_node, + 3s, + [this, request_id](connection_info info, std::optional) { + // If we failed to get a connection then try again after a delay (may as well try + // indefinitely because there is no way to recover from this issue) + if (!info.is_valid()) { + snode_cache_refresh_failure_count++; + auto cache_refresh_retry_delay = retry_delay(snode_cache_refresh_failure_count); + log::error( + cat, + "Failed to connect to seed node to refresh snode cache, will retry " + "after {}ms ({}).", + cache_refresh_retry_delay.count(), + request_id); + return net.call_later(cache_refresh_retry_delay, [this, request_id]() { + refresh_snode_cache_from_seed_nodes(request_id, false); + }); + } + + get_service_nodes( + request_id, + info, + std::nullopt, + [this, request_id]( + std::vector nodes, std::optional error) { + // If we got no nodes then we will need to try again + if (nodes.empty()) { + snode_cache_refresh_failure_count++; + auto cache_refresh_retry_delay = + retry_delay(snode_cache_refresh_failure_count); + log::error( + cat, + "Failed to retrieve nodes from seed node to refresh cache " + "due to error: {}, will retry after {}ms ({}).", + error.value_or("Unknown Error"), + cache_refresh_retry_delay.count(), + request_id); + return net.call_later( + cache_refresh_retry_delay, [this, request_id]() { + refresh_snode_cache_from_seed_nodes(request_id, false); + }); + } + + log::info( + cat, + "Refreshing snode cache from seed nodes completed with {} " + "nodes ({}).", + nodes.size(), + request_id); + seed_node_cache_size = nodes.size(); + refresh_snode_cache_complete(nodes); + }); + }); +} + +void Network::refresh_snode_cache(std::optional existing_request_id) { + auto request_id = existing_request_id.value_or("RSC-{}"_format(random::random_base32(4))); + + if (suspended) { + log::info(cat, "Ignoring snode cache refresh as network is suspended ({}).", request_id); + return; + } + + // Only allow a single cache refresh at a time (this gets cleared in `_close_connections` so if + // it happens to loop after going to, and returning from, the background a subsequent refresh + // won't be blocked) + if (current_snode_cache_refresh_request_id && + current_snode_cache_refresh_request_id != request_id) { + log::info( + cat, + "Snode cache refresh {} ignored due to in progress refresh ({}).", + request_id, + current_snode_cache_refresh_request_id.value_or("NULL")); + return; + } + + // We are starting a new cache refresh so store an identifier for it (we also initialise + // `snode_refresh_results` so we can use it to track the results from the different requests) + if (!current_snode_cache_refresh_request_id) { + log::info(cat, "Refreshing snode cache ({}).", request_id); + current_snode_cache_refresh_request_id = request_id; + snode_refresh_results = std::make_shared>>(); + } + + // If we don't have enough nodes in the unused nodes then refresh it + if (unused_nodes.size() < min_snode_cache_size()) + unused_nodes = get_unused_nodes(); + + // If we still don't have enough nodes in the unused nodes it likely means we didn't + // have enough nodes in the cache so instead just fetch from the seed nodes (which is + // a trusted source so we can update the cache from a single response) + if (unused_nodes.size() < 
min_snode_cache_size()) + return refresh_snode_cache_from_seed_nodes(request_id, true); + + // Target an unused node and increment the in progress refresh counter + auto target_node = unused_nodes.back(); + unused_nodes.pop_back(); + in_progress_snode_cache_refresh_count++; + + // If there are still more concurrent refresh_snode_cache requests we want to trigger then + // trigger the next one to run in the next run loop + if (in_progress_snode_cache_refresh_count < num_snodes_to_refresh_cache_from) + net.call_soon([this, request_id]() { refresh_snode_cache(request_id); }); + + // Prepare and send the request to retrieve service nodes + nlohmann::json payload{ + {"method", "oxend_request"}, + {"params", + {{"endpoint", "get_service_nodes"}, + {"params", detail::get_service_nodes_params(std::nullopt)}}}, + }; + auto info = request_info::make( + target_node, + ustring{quic::to_usv(payload.dump())}, + std::nullopt, + quic::DEFAULT_TIMEOUT, + std::nullopt, + PathType::standard, + request_id); + _send_onion_request( + info, + [this, request_id]( + bool success, + bool timeout, + int16_t, + std::vector>, + std::optional response) { + // If the 'snode_refresh_results' value doesn't exist it means we have already + // completed/cancelled this snode cache refresh and have somehow gotten into an + // invalid state, so just ignore this request + if (!snode_refresh_results) { + log::warning( + cat, + "Ignoring snode cache response after cache update already completed " + "({}).", + request_id); + return; + } + + try { + if (!success || timeout || !response) + throw std::runtime_error{response.value_or("Unknown error.")}; + + nlohmann::json response_json = nlohmann::json::parse(*response); + std::vector result = + detail::process_get_service_nodes_response(response_json); + snode_refresh_results->emplace_back(result); + + // Update the in progress request count + in_progress_snode_cache_refresh_count--; + } catch (const std::exception& e) { + // The request failed so increment the failure counter and retry after a short + // delay + snode_cache_refresh_failure_count++; + + auto cache_refresh_retry_delay = retry_delay(snode_cache_refresh_failure_count); + log::error( + cat, + "Failed to retrieve nodes from one target when refreshing cache due to " + "error: {} Will try another target after {}ms ({}).", + e.what(), + cache_refresh_retry_delay.count(), + request_id); + return net.call_later(cache_refresh_retry_delay, [this, request_id]() { + refresh_snode_cache(request_id); + }); + } + + // If we haven't received all results then do nothing + if (snode_refresh_results->size() != num_snodes_to_refresh_cache_from) { + log::info( + cat, + "Received snode cache refresh result {}/{} ({}).", + snode_refresh_results->size(), + num_snodes_to_refresh_cache_from, + request_id); + return; + } + + auto any_nodes_request_failed = std::any_of( + snode_refresh_results->begin(), + snode_refresh_results->end(), + [](const auto& n) { return n.empty(); }); + + // If the current cache is still usable just send a warning and don't bother + // retrying + if (any_nodes_request_failed) { + log::warning(cat, "Failed to refresh snode cache ({}).", request_id); + current_snode_cache_refresh_request_id = std::nullopt; + snode_cache_refresh_failure_count = 0; + in_progress_snode_cache_refresh_count = 0; + snode_refresh_results.reset(); + return; + } + + // Sort the vectors (so make it easier to find the intersection) + for (auto& nodes : *snode_refresh_results) + std::stable_sort(nodes.begin(), nodes.end()); + + auto nodes = 
(*snode_refresh_results)[0]; + + // If we triggered multiple requests then get the intersection of all vectors + if (snode_refresh_results->size() > 1) { + for (size_t i = 1; i < snode_refresh_results->size(); ++i) { + std::vector temp; + std::set_intersection( + nodes.begin(), + nodes.end(), + (*snode_refresh_results)[i].begin(), + (*snode_refresh_results)[i].end(), + std::back_inserter(temp), + [](const auto& a, const auto& b) { return a == b; }); + nodes = std::move(temp); + } + } + + log::info( + cat, + "Refreshing snode cache completed with {} nodes ({}).", + nodes.size(), + request_id); + refresh_snode_cache_complete(nodes); + }); +} + +void Network::build_path(std::string path_id, PathType path_type) { + if (suspended) { + log::info(cat, "Ignoring build_path call as network is suspended."); + return; + } + + auto path_name = path_type_name(path_type, single_path_mode); + + // If we don't have an unused connection for the first hop then enqueue the path build and + // establish a new connection + if (unused_connections.empty()) { + log::info( + cat, + "No unused connections available to build {} path, creating new connection for {}.", + path_name, + path_id); + path_build_queue.emplace_back(path_type); + in_progress_path_builds.erase(path_id); + return net.call_soon([this, path_id]() { establish_and_store_connection(path_id); }); + } + + // Reset the unused nodes list if it's too small + if (unused_nodes.size() < path_size) + unused_nodes = get_unused_nodes(); + + // If we still don't have enough unused nodes then we need to refresh the cache + if (unused_nodes.size() < path_size) { + log::info( + cat, "Re-queing {} path build due to insufficient nodes ({}).", path_name, path_id); + path_build_failures = 0; + path_build_queue.emplace_back(path_type); + in_progress_path_builds.erase(path_id); + return net.call_soon([this]() { refresh_snode_cache(); }); + } + + // Build the path + log::info(cat, "Building {} path ({}).", path_name, path_id); + in_progress_path_builds[path_id] = path_type; + + auto conn_info = std::move(unused_connections.front()); + unused_connections.pop_front(); + std::vector path_nodes = {conn_info.node}; + + while (path_nodes.size() < path_size) { + if (unused_nodes.empty()) { + // Log the error and try build again after a slight delay + log::info( + cat, + "Unable to build {} path due to lack of suitable unused nodes ({}).", + path_name, + path_id); + + // Delay the next path build attempt based on the error we received + path_build_failures++; + unused_connections.push_front(std::move(conn_info)); + auto delay = retry_delay(path_build_failures); + net.call_later(delay, [this, path_id, path_type]() { build_path(path_id, path_type); }); + return; + } + + // Grab the next unused node to continue building the path + auto node = unused_nodes.back(); + unused_nodes.pop_back(); + + // Ensure we don't put two nodes with the same IP into the same path + auto snode_with_ip_it = std::find_if( + path_nodes.begin(), path_nodes.end(), [&node](const auto& existing_node) { + return existing_node.to_ipv4() == node.to_ipv4(); + }); + + if (snode_with_ip_it == path_nodes.end()) + path_nodes.push_back(node); + } + + // Store the new path + auto path = onion_path{path_id, std::move(conn_info), path_nodes, 0}; + paths[path_type].emplace_back(path); + in_progress_path_builds.erase(path_id); + + // Log that a path was built + log::info( + cat, + "Built new onion request path [{}], now have {} {} path(s) ({}).", + path.to_string(), + paths[path_type].size(), + path_name, + path_id); + + 
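Connection attempts, snode cache refreshes and the path build retries above all share the same exponential backoff, Network::retry_delay: a 100ms base doubled per failure and clamped to a maximum. A minimal equivalent is sketched below; the 3-second cap is an assumption for illustration only, since the real default argument is declared elsewhere:

    #include <algorithm>
    #include <chrono>
    #include <cmath>

    // Sketch of the retry schedule: 200ms after the first failure, then 400ms, 800ms, ...
    // clamped to max_delay (the callers increment the failure count before calling).
    std::chrono::milliseconds backoff_sketch(
            int num_failures, std::chrono::milliseconds max_delay = std::chrono::seconds{3}) {
        auto delay = std::chrono::milliseconds{
                static_cast<long long>(100 * std::pow(2, num_failures))};
        return std::min(delay, max_delay);
    }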
// If the connection info is valid and it's a standard path then update the + // connection status to connected + if (path_type == PathType::standard) { + update_status(ConnectionStatus::connected); + + // If a paths_changed callback was provided then call it + if (paths_changed) { + std::vector> raw_paths; + for (const auto& path : paths[path_type]) + raw_paths.emplace_back(path.nodes); + + paths_changed(raw_paths); + } + } + + // Remove the nodes from unused_nodes which have the same IPs as nodes in + // the final path + std::vector path_ips; + for (const auto& node : path_nodes) + path_ips.emplace_back(node.to_ipv4()); + + std::erase_if(unused_nodes, [&path_ips](const auto& node) { + return std::find(path_ips.begin(), path_ips.end(), node.to_ipv4()) != path_ips.end(); + }); + + // If there are pending requests which this path is valid for then resume them + std::erase_if(request_queue[path_type], [this, &path](const auto& request) { + if (!find_valid_path(request.first, {path})) + return false; + + net.call_soon([this, info = request.first, cb = std::move(request.second)]() { + _send_onion_request(std::move(info), std::move(cb)); + }); + return true; + }); + + // If there are still pending requests and there are no pending path builds for them then kick + // off a subsequent path build in an effort to resume the remaining requests + if (!request_queue[path_type].empty()) { + auto additional_path_id = "P-{}"_format(random::random_base32(4)); + in_progress_path_builds[additional_path_id] = path_type; + net.call_soon([this, path_type, additional_path_id] { + build_path(additional_path_id, path_type); + }); + } else + request_queue.erase(path_type); +} + +std::optional Network::find_valid_path( + const request_info info, const std::vector paths) { + if (paths.empty()) + return std::nullopt; + + // Only include paths with valid connections as options + std::vector possible_paths; + std::copy_if( + paths.begin(), paths.end(), std::back_inserter(possible_paths), [&](const auto& path) { + return path.is_valid(); + }); + + // If the request destination is a node then only select a path that doesn't include the IP of + // the destination + if (auto target = detail::node_for_destination(info.destination)) { + std::vector ip_excluded_paths; + std::copy_if( + possible_paths.begin(), + possible_paths.end(), + std::back_inserter(ip_excluded_paths), + [&](const onion_path& p) { return not p.contains_node(*target); }); + + if (single_path_mode && ip_excluded_paths.empty()) + log::warning( + cat, + "Path should have been excluded due to matching IP for {} but network is in " + "single path mode.", + info.request_id); + else + possible_paths = ip_excluded_paths; + } + + if (possible_paths.empty()) + return std::nullopt; + + // Randomise the possible paths (if all paths are equal for the PathSelectionBehaviour then we + // want a random one to be selected) + std::shuffle(possible_paths.begin(), possible_paths.end(), csrng); + + // Select from the possible paths based on the desired behaviour + auto behaviour = path_selection_behaviour(info.path_type); + switch (behaviour) { + case PathSelectionBehaviour::new_or_least_busy: { + auto min_num_paths = min_path_count(info.path_type, single_path_mode); + std::sort( + possible_paths.begin(), possible_paths.end(), [](const auto& a, const auto& b) { + return a.num_pending_requests() < b.num_pending_requests(); + }); + + // If we have already have the min number of paths for this path type, or there is + // a path with no pending requests then return the first path 
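As a summary of the new_or_least_busy behaviour being applied here, a stand-alone sketch under assumed names (path_sketch and pick_least_busy are not from the patch; the real code operates on onion_path and num_pending_requests()):

    #include <algorithm>
    #include <cstddef>
    #include <optional>
    #include <random>
    #include <vector>

    struct path_sketch {
        int pending_requests = 0;  // stand-in for onion_path::num_pending_requests()
    };

    // Shuffle so ties break randomly, then prefer the least busy path; returning nullopt
    // signals "build a fresh path instead" when every path is busy and we are still
    // below the minimum path count for this path type.
    std::optional<path_sketch> pick_least_busy(std::vector<path_sketch> paths, size_t min_paths) {
        if (paths.empty())
            return std::nullopt;
        static std::mt19937 rng{std::random_device{}()};
        std::shuffle(paths.begin(), paths.end(), rng);
        std::sort(paths.begin(), paths.end(), [](const auto& a, const auto& b) {
            return a.pending_requests < b.pending_requests;
        });
        if (paths.size() >= min_paths || paths.front().pending_requests == 0)
            return paths.front();
        return std::nullopt;
    }

The patch's own version of this check continues directly below.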
+ if (paths.size() >= min_num_paths || possible_paths.front().num_pending_requests() == 0) + return possible_paths.front(); + + // Otherwise we want to build a new path (for this PathSelectionBehaviour the assumption + // is that it'd be faster to build a new path and send the request along that rather + // than use an existing path) + return std::nullopt; + } + + // Random is the default behaviour + case PathSelectionBehaviour::random: return possible_paths.front(); + default: return possible_paths.front(); + } +}; + +void Network::build_path_if_needed(PathType path_type, bool found_path) { + const auto current_paths = paths[path_type]; + + // In `single_path_mode` we never build additional paths + if (current_paths.size() > 0 && single_path_mode) + return; + + // We only want to enqueue a new path build if: + // - We don't have the minimum number of paths for the specified type + // - We don't have any pending builds + // - The current paths are unsuitable for the request + auto min_paths = min_path_count(path_type, single_path_mode); + + // If we have enough existing paths and found a valid path then no need to build more paths + if (found_path && current_paths.size() >= min_paths) + return; + + // Get the number of pending paths + auto queued = std::count(path_build_queue.begin(), path_build_queue.end(), path_type); + auto in_progress = std::count_if( + in_progress_path_builds.begin(), + in_progress_path_builds.end(), + [&path_type](const auto& build) { return build.second == path_type; }); + auto pending_paths = (queued + in_progress); + + // If we don't have enough current + pending paths, or the request couldn't be sent then + // kick off a new path build + if ((current_paths.size() + pending_paths) < min_paths || (!found_path && pending_paths == 0)) { + auto path_id = "P-{}"_format(random::random_base32(4)); + build_path(path_id, path_type); + } +} + + // MARK: Direct Requests + +void Network::get_service_nodes( + std::string request_id, + connection_info conn_info, + std::optional limit, + std::function nodes, std::optional error)> + callback) { + log::trace(cat, "{} called for {}.", __PRETTY_FUNCTION__, request_id); + + if (!conn_info.is_valid()) + return callback({}, "Connection is not valid."); + + oxenc::bt_dict_producer payload; + payload.append("endpoint", "get_service_nodes"); + payload.append("params", detail::get_service_nodes_params(limit).dump()); + + conn_info.add_pending_request(); + conn_info.stream->command( + "oxend_request", + payload.view(), + [this, request_id, conn_info, cb = std::move(callback)](quic::message resp) { + log::trace(cat, "{} got response for {}.", __PRETTY_FUNCTION__, request_id); + std::vector result; + conn_info.remove_pending_request(); + + try { + auto [status_code, body] = validate_response(resp, true); + oxenc::bt_list_consumer result_bencode{body}; + result = detail::process_get_service_nodes_response(result_bencode); + } catch (const std::exception& e) { + return cb({}, e.what()); + } + + // Output the result + cb(result, std::nullopt); + + // After completing a request we should try to clear any pending path drops (just in + // case this request was the final one on a pending path drop) + if (!conn_info.has_pending_requests()) + clear_empty_pending_path_drops(); + }); +} + + // MARK: Swarm Management + +void Network::get_swarm( + session::onionreq::x25519_pubkey swarm_pubkey, + std::function swarm)> callback) { + log::trace(cat, "{} called for {}.", __PRETTY_FUNCTION__, swarm_pubkey.hex()); + + net.call([this, swarm_pubkey, cb =
std::move(callback)]() { + // If we have a cached swarm then return it + auto cached_swarm = swarm_cache[swarm_pubkey.hex()]; + if (!cached_swarm.second.empty()) + return cb(cached_swarm.first, cached_swarm.second); + + // If we have no snode cache or no swarms then we need to rebuild the cache (which will also + // rebuild the swarms) and run this request again + if (snode_cache.empty() || all_swarms.empty()) { + after_snode_cache_refresh.emplace_back([this, swarm_pubkey, cb = std::move(cb)]() { + get_swarm(swarm_pubkey, std::move(cb)); + }); + return net.call_soon([this]() { refresh_snode_cache(); }); + } + + // If there is only a single swarm then return it + if (all_swarms.size() == 1) + return cb(all_swarms.front().first, all_swarms.front().second); + + // Generate a swarm_id for the pubkey + const swarm_id_t swarm_id = detail::pubkey_to_swarm_space(swarm_pubkey); + + // Find the right boundary, i.e. first swarm with swarm_id >= res + auto right_it = std::lower_bound( + all_swarms.begin(), all_swarms.end(), swarm_id, [](const auto& s, uint64_t v) { + return s.first < v; + }); + + if (right_it == all_swarms.end()) + // res is > the top swarm_id, meaning it is big and in the wrapping space between last + // and first elements. + right_it = all_swarms.begin(); + + // Our "left" is the one just before that (with wraparound, if right is the first swarm) + auto left_it = std::prev(right_it == all_swarms.begin() ? all_swarms.end() : right_it); + + uint64_t dright = right_it->first - swarm_id; + uint64_t dleft = swarm_id - left_it->first; + auto swarm = &*(dright < dleft ? right_it : left_it); + + // Update the cache with the result + log::info( + cat, + "Found swarm with {} nodes for {}, adding to cache.", + swarm->second.size(), + swarm_pubkey.hex()); + swarm_cache[swarm_pubkey.hex()] = *swarm; + cb(swarm->first, swarm->second); + }); +} + +// MARK: Node Retrieval + +void Network::get_random_nodes( + uint16_t count, std::function nodes)> callback) { + auto request_id = "R-{}"_format(random::random_base32(4)); + log::trace(cat, "{} called for {}.", __PRETTY_FUNCTION__, request_id); + + net.call([this, request_id, count, cb = std::move(callback)]() mutable { + // If we don't have sufficient unused nodes then regenerate it + if (unused_nodes.size() < count) + unused_nodes = get_unused_nodes(); + + // If we still don't have sufficient nodes then we need to refresh the snode cache + if (unused_nodes.size() < count) { + after_snode_cache_refresh.emplace_back( + [this, count, cb = std::move(cb)]() { get_random_nodes(count, cb); }); + return net.call_soon([this]() { refresh_snode_cache(); }); + } + + // Otherwise callback with the requested random number of nodes + auto random_nodes = + std::vector(unused_nodes.begin(), unused_nodes.begin() + count); + unused_nodes.erase(unused_nodes.begin(), unused_nodes.begin() + count); + cb(random_nodes); + }); +} + +// MARK: Request Handling + +void Network::check_request_queue_timeouts(std::optional request_timeout_id_) { + // If the network is suspended (or destroyed) then don't bother checking for timeouts + if (suspended || destroyed) + return; + + // If there is an existing timeout checking loop then we don't want to start a second + if (request_timeout_id != request_timeout_id_) + return; + + // If there wasn't an existing loop id then set it here + if (!request_timeout_id) + request_timeout_id = "RT-{}"_format(random::random_base32(4)); + + // Timeout and remove any pending requests which should timeout based on path build time + auto 
has_remaining_timeout_requests = false; + auto time_now = std::chrono::system_clock::now(); + + for (auto& [path_type, requests_for_path] : request_queue) + std::erase_if( + requests_for_path, + [&has_remaining_timeout_requests, &time_now](const auto& request) { + // If the request doesn't have a path build timeout then ignore it + if (!request.first.request_and_path_build_timeout) + return false; + + auto duration = std::chrono::duration_cast( + time_now - request.first.creation_time); + + if (duration > *request.first.request_and_path_build_timeout) { + request.second( + false, + true, + error_path_build_timeout, + {content_type_plain_text}, + "Timed out waiting for path build."); + return true; + } + + has_remaining_timeout_requests = true; + return false; + }); + + // If there are no more timeout requests then stop looping here + if (!has_remaining_timeout_requests) { + request_timeout_id = std::nullopt; + return; + } + + // Otherwise schedule the next check + net.call_later(queued_request_path_build_timeout_frequency, [this]() { + check_request_queue_timeouts(request_timeout_id); + }); +} + +void Network::send_request( + request_info info, connection_info conn_info, network_response_callback_t handle_response) { + log::trace(cat, "{} called for {}.", __PRETTY_FUNCTION__, info.request_id); + + if (!conn_info.is_valid()) + return handle_response( + false, false, -1, {content_type_plain_text}, "Network is unreachable."); + + quic::bstring_view payload{}; + + if (info.body) + payload = convert_sv(*info.body); + + // Calculate the remaining timeout + std::chrono::milliseconds timeout = info.request_timeout; + + if (info.request_and_path_build_timeout) { + auto elapsed_time = std::chrono::duration_cast( + std::chrono::system_clock::now() - info.creation_time); + + timeout = *info.request_and_path_build_timeout - elapsed_time; + + // If the timeout was somehow negative then just fail the request (no point continuing if + // we have already timed out) + if (timeout < std::chrono::milliseconds(0)) + return handle_response( + false, + true, + error_path_build_timeout, + {content_type_plain_text}, + "Path Build Timed Out."); + } + + conn_info.add_pending_request(); + conn_info.stream->command( + info.endpoint, + payload, + timeout, + [this, info, conn_info, cb = std::move(handle_response)](quic::message resp) { + log::trace(cat, "{} got response for {}.", __PRETTY_FUNCTION__, info.request_id); + + std::pair result; + auto& [status_code, body] = result; + conn_info.remove_pending_request(); + + try { + result = validate_response(resp, false); + } catch (const status_code_exception& e) { + return handle_errors( + info, + conn_info, + resp.timed_out, + e.status_code, + e.headers, + e.what(), + cb); + } catch (const std::exception& e) { + return handle_errors( + info, + conn_info, + resp.timed_out, + -1, + {content_type_plain_text}, + e.what(), + cb); + } + + cb(true, false, status_code, {}, body); + + // After completing a request we should try to clear any pending path drops (just in + // case this request was the final one on a pending path drop) + if (!conn_info.has_pending_requests()) + clear_empty_pending_path_drops(); + }); +} + +void Network::send_onion_request( + onionreq::network_destination destination, + std::optional body, + std::optional swarm_pubkey, + network_response_callback_t handle_response, + std::chrono::milliseconds request_timeout, + std::optional request_and_path_build_timeout, + PathType type) { + _send_onion_request( + request_info::make( + std::move(destination), + 
std::move(body), + std::move(swarm_pubkey), + request_timeout, + request_and_path_build_timeout, + type), + std::move(handle_response)); +} + +void Network::_send_onion_request(request_info info, network_response_callback_t handle_response) { + auto path_name = path_type_name(info.path_type, single_path_mode); + log::trace(cat, "{} called for {} path ({}).", __PRETTY_FUNCTION__, path_name, info.request_id); + + // Try to retrieve a valid path for this request, if we can't get one then add the request to + // the queue to be run once a path for it has successfully been built + auto path = net.call_get([this, info]() { + auto result = find_valid_path(info, paths[info.path_type]); + net.call_soon([this, path_type = info.path_type, found_path = result.has_value()]() { + build_path_if_needed(path_type, found_path); + }); + return result; + }); + + if (!path) { + return net.call([this, info = std::move(info), cb = std::move(handle_response)]() { + // If the network is suspended then fail immediately + if (suspended) + return cb( + false, + false, + error_network_suspended, + {content_type_plain_text}, + "Network is suspended."); + + request_queue[info.path_type].emplace_back(std::move(info), std::move(cb)); + + // If the request has a path_build_timeout then start the timeout check loop + if (info.request_and_path_build_timeout) + net.call_later(queued_request_path_build_timeout_frequency, [this]() { + check_request_queue_timeouts(); + }); + }); + } + + log::trace(cat, "{} got {} path for {}.", __PRETTY_FUNCTION__, path_name, info.request_id); + + // Construct the onion request + auto builder = Builder::make(info.destination, path->nodes); + try { + builder.generate(info); + } catch (const std::exception& e) { + log::warning(cat, "Builder exception: {}", e.what()); + return handle_response( + false, false, error_building_onion_request, {content_type_plain_text}, e.what()); + } + + // Actually send the request + send_request( + info, + path->conn_info, + [this, + builder = std::move(builder), + info, + path = *path, + cb = std::move(handle_response)]( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + log::trace(cat, "{} got response for {}.", __PRETTY_FUNCTION__, info.request_id); + + // If the request was reported as a failure or a timeout then we + // will have already handled the errors so just trigger the callback + if (!success || timeout) + return cb(success, timeout, status_code, headers, response); + + try { + // Ensure the response is long enough to be processed, if not + // then handle it as an error + if (!ResponseParser::response_long_enough(builder.enc_type, response->size())) + throw status_code_exception{ + status_code, + {content_type_plain_text}, + "Response is too short to be an onion request response: " + + *response}; + + // Otherwise, process the onion request response + std::tuple< + int16_t, + std::vector>, + std::optional> + processed_response; + + // The SnodeDestination runs via V3 onion requests and the + // ServerDestination runs via V4 + if (std::holds_alternative(info.destination)) + processed_response = process_v3_onion_response(builder, *response); + else if (std::holds_alternative(info.destination)) + processed_response = process_v4_onion_response(builder, *response); + + // If we got a non 2xx status code, return the error + auto& [processed_status_code, processed_headers, processed_body] = + processed_response; + if (processed_status_code < 200 || processed_status_code > 299) + throw 
status_code_exception{ + processed_status_code, + {content_type_plain_text}, + processed_body.value_or("Request returned " + "non-success status " + "code.")}; + + // For debugging purposes we want to add a log if this was a successful request + // after we did an automatic retry + detail::log_retry_result_if_needed(info, single_path_mode); + + // Try process the body in case it was a batch request which + // failed + std::optional results; + if (processed_body) { + try { + auto processed_body_json = nlohmann::json::parse(*processed_body); + + // If it wasn't a batch/sequence request then assume it + // was successful and return no error + if (processed_body_json.contains("results")) + results = processed_body_json["results"]; + } catch (...) { + } + } + + // If there was no 'results' array then it wasn't a batch + // request so we can stop here and return + if (!results) + return cb( + true, + false, + processed_status_code, + processed_headers, + processed_body); + + // Otherwise we want to check if all of the results have the + // same status code and, if so, handle that failure case + // (default the 'error_body' to the 'processed_body' in case we + // don't get an explicit error) + int16_t single_status_code = -1; + std::vector> single_headers = { + content_type_plain_text}; + std::optional error_body = processed_body; + for (const auto& result : results->items()) { + if (result.value().contains("code") && result.value()["code"].is_number() && + (single_status_code == -1 || + result.value()["code"].get() != single_status_code)) + single_status_code = result.value()["code"].get(); + else { + // Either there was no code, or the code was different + // from a former code in which case there wasn't an + // individual detectable error (ie. it needs specific + // handling) so return no error + single_status_code = 200; + break; + } + + if (result.value().contains("headers")) { + single_headers = {}; + auto header_vals = result.value()["headers"]; + + for (auto it = header_vals.begin(); it != header_vals.end(); ++it) + single_headers.emplace_back(it.key(), it.value()); + } + + if (result.value().contains("body") && result.value()["body"].is_string()) + error_body = result.value()["body"].get(); + } + + // If all results contained the same error then handle it as a + // single error + if (single_status_code < 200 || single_status_code > 299) + throw status_code_exception{ + single_status_code, + single_headers, + error_body.value_or("Sub-request returned " + "non-success status code.")}; + + // Otherwise some requests succeeded and others failed so + // succeed with the processed data + return cb( + true, false, processed_status_code, processed_headers, processed_body); + } catch (const status_code_exception& e) { + handle_errors( + info, path.conn_info, false, e.status_code, e.headers, e.what(), cb); + } catch (const std::exception& e) { + handle_errors( + info, + path.conn_info, + false, + -1, + {content_type_plain_text}, + e.what(), + cb); + } + }); +} + +void Network::upload_file_to_server( + ustring data, + onionreq::ServerDestination server, + std::optional file_name, + network_response_callback_t handle_response, + std::chrono::milliseconds request_timeout, + std::optional request_and_path_build_timeout) { + std::vector> headers; + std::unordered_set existing_keys; + + if (server.headers) + for (auto& [key, value] : *server.headers) { + headers.emplace_back(key, value); + existing_keys.insert(key); + } + + // Add the required headers if they weren't provided + if 
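// A sketch of the batch-response scan described by the comments above: a batch/sequence
// response carries a "results" array, and only when every sub-result reports the same
// non-2xx "code" is the whole request treated as that single error.  This standalone
// version implements the check as the comments describe it and omits the per-result
// header/body collection done by the real code.
#include <nlohmann/json.hpp>
#include <optional>
#include <string>

// Returns the shared status code if all sub-results carry the same numeric "code",
// otherwise nullopt (mixed or missing codes mean there is no single detectable error).
inline std::optional<int> uniform_batch_status(const std::string& body) {
    auto parsed = nlohmann::json::parse(body, nullptr, /*allow_exceptions=*/false);
    if (parsed.is_discarded() || !parsed.contains("results") || !parsed["results"].is_array())
        return std::nullopt;  // not a batch/sequence response

    std::optional<int> shared;
    for (const auto& result : parsed["results"]) {
        if (!result.contains("code") || !result["code"].is_number())
            return std::nullopt;
        int code = result["code"].get<int>();
        if (shared && *shared != code)
            return std::nullopt;  // mixed codes: no single error to report
        shared = code;
    }
    return shared;
}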
(existing_keys.find("Content-Disposition") == existing_keys.end()) + headers.emplace_back( + "Content-Disposition", + (file_name ? "attachment; filename=\"{}\""_format(*file_name) : "attachment")); + + if (existing_keys.find("Content-Type") == existing_keys.end()) + headers.emplace_back("Content-Type", "application/octet-stream"); + + send_onion_request( + ServerDestination{ + server.protocol, + server.host, + server.endpoint, + server.x25519_pubkey, + server.port, + headers, + server.method}, + data, + std::nullopt, + handle_response, + request_timeout, + request_and_path_build_timeout, + PathType::upload); +} + +void Network::download_file( + std::string_view download_url, + session::onionreq::x25519_pubkey x25519_pubkey, + network_response_callback_t handle_response, + std::chrono::milliseconds request_timeout, + std::optional request_and_path_build_timeout) { + const auto& [proto, host, port, path] = parse_url(download_url); + + if (!path) + throw std::invalid_argument{"Invalid URL provided: Missing path"}; + + download_file( + ServerDestination{proto, host, *path, x25519_pubkey, port, std::nullopt, "GET"}, + handle_response, + request_timeout, + request_and_path_build_timeout); +} + +void Network::download_file( + onionreq::ServerDestination server, + network_response_callback_t handle_response, + std::chrono::milliseconds request_timeout, + std::optional request_and_path_build_timeout) { + send_onion_request( + server, + std::nullopt, + std::nullopt, + handle_response, + request_timeout, + request_and_path_build_timeout, + PathType::download); +} + +void Network::get_client_version( + Platform platform, + onionreq::ed25519_seckey seckey, + network_response_callback_t handle_response, + std::chrono::milliseconds request_timeout, + std::optional request_and_path_build_timeout) { + std::string endpoint; + + switch (platform) { + case Platform::android: endpoint = "/session_version?platform=android"; break; + case Platform::desktop: endpoint = "/session_version?platform=desktop"; break; + case Platform::ios: endpoint = "/session_version?platform=ios"; break; + } + + // Generate the auth signature + auto blinded_keys = blind_version_key_pair(to_unsigned_sv(seckey.view())); + auto timestamp = std::chrono::duration_cast( + (std::chrono::system_clock::now()).time_since_epoch()) + .count(); + auto signature = blind_version_sign(to_unsigned_sv(seckey.view()), platform, timestamp); + auto pubkey = x25519_pubkey::from_hex(file_server_pubkey); + std::string blinded_pk_hex; + blinded_pk_hex.reserve(66); + blinded_pk_hex += "07"; + oxenc::to_hex( + blinded_keys.first.begin(), + blinded_keys.first.end(), + std::back_inserter(blinded_pk_hex)); + + auto headers = std::vector>{}; + headers.emplace_back("X-FS-Pubkey", blinded_pk_hex); + headers.emplace_back("X-FS-Timestamp", "{}"_format(timestamp)); + headers.emplace_back("X-FS-Signature", oxenc::to_base64(signature)); + + send_onion_request( + ServerDestination{ + "http", std::string(file_server), endpoint, pubkey, 80, headers, "GET"}, + std::nullopt, + pubkey, + handle_response, + request_timeout, + request_and_path_build_timeout, + PathType::standard); +} + +// MARK: Response Handling + +std::tuple>, std::optional> +Network::process_v3_onion_response(Builder builder, std::string response) { + std::string base64_iv_and_ciphertext; + try { + nlohmann::json response_json = nlohmann::json::parse(response); + + if (!response_json.contains("result") || !response_json["result"].is_string()) + throw std::runtime_error{"JSON missing result field."}; + + 
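// A sketch of the header-defaulting done by upload_file_to_server() above: caller-supplied
// headers are kept as-is, and Content-Disposition / Content-Type are only added when the
// caller did not provide them.  The Headers alias is an illustrative stand-in for the
// (name, value) pair list the library passes around.
#include <optional>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

using Headers = std::vector<std::pair<std::string, std::string>>;

inline Headers with_upload_defaults(Headers headers, std::optional<std::string> file_name) {
    std::unordered_set<std::string> existing;
    for (const auto& [key, value] : headers)
        existing.insert(key);

    if (!existing.count("Content-Disposition"))
        headers.emplace_back(
                "Content-Disposition",
                file_name ? "attachment; filename=\"" + *file_name + "\"" : "attachment");

    if (!existing.count("Content-Type"))
        headers.emplace_back("Content-Type", "application/octet-stream");

    return headers;
}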
base64_iv_and_ciphertext = response_json["result"].get(); + } catch (...) { + base64_iv_and_ciphertext = response; + } + + if (!oxenc::is_base64(base64_iv_and_ciphertext)) + throw std::runtime_error{"Invalid base64 encoded IV and ciphertext."}; + + ustring iv_and_ciphertext; + oxenc::from_base64( + base64_iv_and_ciphertext.begin(), + base64_iv_and_ciphertext.end(), + std::back_inserter(iv_and_ciphertext)); + auto parser = ResponseParser(builder); + auto result = parser.decrypt(iv_and_ciphertext); + auto result_json = nlohmann::json::parse(result); + int16_t status_code; + std::vector> headers; + std::string body; + + if (result_json.contains("status_code") && result_json["status_code"].is_number()) + status_code = result_json["status_code"].get(); + else if (result_json.contains("status") && result_json["status"].is_number()) + status_code = result_json["status"].get(); + else + throw std::runtime_error{"Invalid JSON response, missing required status_code field."}; + + if (result_json.contains("headers")) { + auto header_vals = result_json["headers"]; + + for (auto it = header_vals.begin(); it != header_vals.end(); ++it) + headers.emplace_back(it.key(), it.value()); + } + + if (result_json.contains("body") && result_json["body"].is_string()) + body = result_json["body"].get(); + else + body = result_json.dump(); + + return {status_code, headers, body}; +} + +std::tuple>, std::optional> +Network::process_v4_onion_response(Builder builder, std::string response) { + ustring response_data{to_unsigned(response.data()), response.size()}; + auto parser = ResponseParser(builder); + auto result = parser.decrypt(response_data); + + // Process the bencoded response + oxenc::bt_list_consumer result_bencode{result}; + + if (result_bencode.is_finished() || !result_bencode.is_string()) + throw std::runtime_error{"Invalid bencoded response"}; + + auto response_info_string = result_bencode.consume_string(); + int16_t status_code; + std::vector> headers; + nlohmann::json response_info_json = nlohmann::json::parse(response_info_string); + + if (response_info_json.contains("code") && response_info_json["code"].is_number()) + status_code = response_info_json["code"].get(); + else + throw std::runtime_error{"Invalid JSON response, missing required code field."}; + + if (response_info_json.contains("headers")) { + auto header_vals = response_info_json["headers"]; + + for (auto it = header_vals.begin(); it != header_vals.end(); ++it) + headers.emplace_back(it.key(), it.value()); + } + + if (result_bencode.is_finished()) + return {status_code, headers, std::nullopt}; + + return {status_code, headers, result_bencode.consume_string()}; +} + +// MARK: Error Handling + +std::pair Network::validate_response(quic::message resp, bool is_bencoded) { + std::string body = resp.body_str(); + + if (resp.timed_out) + throw std::runtime_error{"Timed out"}; + if (resp.is_error()) + throw std::runtime_error{body.empty() ? 
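// A sketch of the first step of process_v3_onion_response() above: the encrypted payload
// arrives either wrapped in JSON as {"result": "<base64>"} or as a bare base64 string, so
// the JSON "result" field is preferred with the raw body as the fallback.  Decryption of
// the decoded bytes (done by ResponseParser in the real code) is omitted here.
#include <nlohmann/json.hpp>
#include <oxenc/base64.h>
#include <iterator>
#include <stdexcept>
#include <string>
#include <vector>

inline std::vector<unsigned char> extract_v3_ciphertext(const std::string& response) {
    std::string b64 = response;
    auto parsed = nlohmann::json::parse(response, nullptr, /*allow_exceptions=*/false);
    if (!parsed.is_discarded() && parsed.contains("result") && parsed["result"].is_string())
        b64 = parsed["result"].get<std::string>();

    if (!oxenc::is_base64(b64))
        throw std::runtime_error{"Invalid base64 encoded IV and ciphertext."};

    std::vector<unsigned char> iv_and_ciphertext;
    oxenc::from_base64(b64.begin(), b64.end(), std::back_inserter(iv_and_ciphertext));
    return iv_and_ciphertext;
}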
"Unknown error" : body}; + + if (is_bencoded) { + // Process the bencoded response + oxenc::bt_list_consumer result_bencode{body}; + + if (result_bencode.is_finished() || !result_bencode.is_integer()) + throw std::runtime_error{"Invalid bencoded response"}; + + // If we have a status code that is not in the 2xx range, return the error + auto status_code = result_bencode.consume_integer(); + + if (status_code < 200 || status_code > 299) { + if (result_bencode.is_finished() || !result_bencode.is_string()) + throw status_code_exception{ + status_code, + {content_type_plain_text}, + "Request failed with status code: " + std::to_string(status_code)}; + + throw status_code_exception{ + status_code, {content_type_plain_text}, result_bencode.consume_string()}; + } + + // Can't convert the data to a string so just return the response body itself + return {status_code, body}; + } + + // Default to a 200 success if the response is empty but didn't timeout or error + int16_t status_code = 200; + std::pair content_type; + std::string response_string; + + try { + nlohmann::json response_json = nlohmann::json::parse(body); + content_type = content_type_json; + + if (response_json.is_array() && response_json.size() == 2) { + status_code = response_json[0].get(); + response_string = response_json[1].dump(); + } else + response_string = body; + } catch (...) { + response_string = body; + content_type = content_type_plain_text; + } + + if (status_code < 200 || status_code > 299) + throw status_code_exception{status_code, {content_type}, response_string}; + + return {status_code, response_string}; +} + +void Network::drop_path_when_empty(std::string id, PathType path_type, onion_path path) { + paths_pending_drop.emplace_back(path, path_type); + paths[path_type].erase( + std::remove(paths[path_type].begin(), paths[path_type].end(), path), + paths[path_type].end()); + + std::string reason; + if (id == path.id) + reason = "connection being closed"; + else + reason = "failure threshold passed with {} failure"_format(id); + + log::info( + cat, + "Flagging path {} [{}] to be dropped due to {}, now have {} {} paths(s).", + path.id, + path.to_string(), + reason, + paths[path_type].size(), + path_type_name(path_type, single_path_mode)); + + // Clear any paths which are waiting to be dropped + clear_empty_pending_path_drops(); +} + +void Network::clear_empty_pending_path_drops() { + auto remaining_standard_paths = 0; + std::erase_if(paths_pending_drop, [this, &remaining_standard_paths](const auto& path_info) { + // If the path is no longer valid then we can drop it + if (!path_info.first.has_pending_requests()) { + log::info( + cat, + "Removing flagged {} path {} that {}: [{}].", + path_type_name(path_info.second, single_path_mode), + path_info.first.id, + (path_info.first.is_valid() ? 
"has no remaining requests" + : "is no longer valid"), + path_info.first.to_string()); + return true; + } + remaining_standard_paths++; + return false; + }); + + // Update the network status if we've removed all standard paths + if (remaining_standard_paths == 0 && paths[PathType::standard].empty()) + update_status(ConnectionStatus::disconnected); +} + +void Network::handle_errors( + request_info info, + connection_info conn_info, + bool timeout_, + int16_t status_code_, + std::vector> headers_, + std::optional response, + std::optional handle_response) { + bool timeout = timeout_; + auto status_code = status_code_; + auto headers = headers_; + auto path_name = path_type_name(info.path_type, single_path_mode); + + // There is an issue which can occur where we get invalid data back and are unable to decrypt + // it, if we do see this behaviour then we want to retry the request on the off chance it + // resolves itself + // + // When testing this case the retry always resulted in a 421 error, if that occurs we want to go + // through the standard 421 behaviour (which, in this case, would involve a 3rd retry against + // another node in the swarm to confirm the redirect) + if (!info.retry_reason && response && *response == session::onionreq::decryption_failed_error) { + log::info( + cat, + "Received decryption failure in request {} on {} path, retrying.", + info.request_id, + path_name); + auto updated_info = info; + updated_info.retry_reason = request_info::RetryReason::decryption_failure; + return net.call_soon([this, updated_info, cb = std::move(*handle_response)]() { + _send_onion_request(updated_info, std::move(cb)); + }); + } + + // A number of server errors can return HTML data but no status code, we want to extract those + // cases so they can be handled properly below + if (status_code == -1 && response) { + const std::unordered_map> response_map = { + {"400 Bad Request", {400, false}}, + {"403 Forbidden", {403, false}}, + {"500 Internal Server Error", {500, false}}, + {"502 Bad Gateway", {502, false}}, + {"503 Service Unavailable", {503, false}}, + {"504 Gateway Timeout", {504, true}}, + }; + + for (const auto& [prefix, result] : response_map) { + if (response->starts_with(prefix)) { + status_code = result.first; + timeout = (timeout || result.second); + } + } + } + + // In trace mode log all error info + log::trace( + cat, + "Received network error in request {} on {} path, status_code: {}, timeout: {}, " + "response: {}", + info.request_id, + path_name, + status_code, + timeout, + response.value_or("(No Response)")); + + // A timeout could be caused because the destination is unreachable rather than the the path + // (eg. 
if a user has an old SOGS which is no longer running on their device they will get a + // timeout) so if we timed out while sending a proxied request we assume something is wrong on + // the server side and don't update the path/snode state + if (!info.node_destination && timeout) { + if (handle_response) + return (*handle_response)(false, true, status_code, headers, response); + return; + } + + switch (status_code) { + // A 404 or a 400 is likely due to a bad/missing SOGS or file so + // shouldn't mark a path or snode as invalid + case 400: + case 404: + if (handle_response) + return (*handle_response)(false, false, status_code, headers, response); + return; + + // The user's clock is out of sync with the service node network (a + // snode will return 406, but V4 onion requests returns a 425) + case 406: + case 425: + if (handle_response) + return (*handle_response)(false, false, status_code, headers, response); + return; + + // The snode is reporting that it isn't associated with the given public key anymore. If + // this is the first 421 then we want to try another node in the swarm (just in case it + // was reported incorrectly). If this is the second occurrence of the 421 then the + // client needs to update the swarm (if the response contains updated swarm data), or + // increment the path failure count. + case 421: + try { + // If there is no response handler or no swarm information was provided then we + // should just replace the swarm + auto target = detail::node_for_destination(info.destination); + + if (!handle_response || !info.swarm_pubkey || !target) + throw std::invalid_argument{"Unable to handle redirect."}; + + switch (info.retry_reason.value_or(request_info::RetryReason::none)) { + // If this was the first 421 then we want to retry using another node in the + // swarm to get confirmation that we should switch to a different swarm + case request_info::RetryReason::none: + case request_info::RetryReason::decryption_failure: { + auto cached_swarm = swarm_cache[info.swarm_pubkey->hex()]; + + if (cached_swarm.second.empty()) + throw std::invalid_argument{ + "Unable to handle redirect due to lack of swarm."}; + + std::vector swarm_copy; + std::copy_if( + cached_swarm.second.begin(), + cached_swarm.second.end(), + std::back_inserter(swarm_copy), + [&target = *target](const auto& node) { return node != target; }); + std::shuffle(swarm_copy.begin(), swarm_copy.end(), csrng); + + if (swarm_copy.empty()) + throw std::invalid_argument{"No other nodes in the swarm."}; + + log::info( + cat, + "Received 421 error in request {} on {} path, retrying once before " + "updating swarm.", + info.request_id, + path_name); + auto updated_info = info; + updated_info.destination = swarm_copy.front(); + updated_info.retry_reason = request_info::RetryReason::redirect; + return net.call_soon( + [this, updated_info, cb = std::move(*handle_response)]() { + _send_onion_request(updated_info, std::move(cb)); + }); + } + + // If we got a second 421 then it's likely that our cached swarm is out of date + // so we need to refresh our snode cache, regenerate our swarm and try one more + // time + case request_info::RetryReason::redirect: + log::info( + cat, + "Received second 421 error in request {} on {} path, refreshing " + "snode cache before trying one final time.", + info.request_id, + path_name); + after_snode_cache_refresh.emplace_back([this, + swarm_pubkey = info.swarm_pubkey, + info, + status_code, + headers, + response, + cb = std::move( + *handle_response)]() { + get_swarm( + *swarm_pubkey, + 
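// A sketch of the HTML-error fallback in handle_errors() above: some upstream failures come
// back with no status code but a recognisable "<code> <reason>" body, so the body prefix is
// mapped to a numeric status, with "504 Gateway Timeout" additionally treated as a timeout.
// Uses C++20 std::string_view::starts_with, matching the project standard.
#include <string_view>

struct MappedError {
    int status_code = -1;
    bool timed_out = false;
};

inline MappedError map_html_error(std::string_view body) {
    struct KnownError {
        std::string_view prefix;
        int status_code;
        bool timed_out;
    };
    static constexpr KnownError known[] = {
            {"400 Bad Request", 400, false},
            {"403 Forbidden", 403, false},
            {"500 Internal Server Error", 500, false},
            {"502 Bad Gateway", 502, false},
            {"503 Service Unavailable", 503, false},
            {"504 Gateway Timeout", 504, true},
    };

    for (const auto& e : known)
        if (body.starts_with(e.prefix))
            return {e.status_code, e.timed_out};
    return {};  // unrecognised: leave the status code as "unknown" (-1)
}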
[this, + info, + status_code, + headers, + response, + cb = std::move(cb)]( + swarm_id_t, std::vector swarm) { + auto target = + detail::node_for_destination(info.destination); + + std::vector swarm_copy; + std::copy_if( + swarm.begin(), + swarm.end(), + std::back_inserter(swarm_copy), + [&target = *target](const auto& node) { + return node != target; + }); + std::shuffle(swarm_copy.begin(), swarm_copy.end(), csrng); + + // If there are no nodes in the swarm then don't bother + // trying again + if (swarm_copy.empty()) { + log::info( + cat, + "Second 421 retry for request {} resulted in " + "another 421 and had no other nodes in the " + "swarm.", + info.request_id); + return cb(false, false, status_code, headers, response); + } + + auto updated_info = info; + updated_info.retry_reason = + request_info::RetryReason::redirect_swarm_refresh; + updated_info.destination = swarm_copy.front(); + net.call_soon([this, updated_info, cb = std::move(cb)]() { + _send_onion_request(updated_info, std::move(cb)); + }); + }); + }); + return net.call_soon([this, request_id = info.request_id]() { + refresh_snode_cache(request_id); + }); + + // If we got a 421 after refreshing the swarm then there is some bigger issue + // (ie. our local swarm generation logic differs from the server or we are + // getting invalid swarm ids back when updating our cache) so the best we can + // do is handle this like any other error + case request_info::RetryReason::redirect_swarm_refresh: + log::info( + cat, + "Received another 421 for request {} after refreshing the snode " + "cache, failing request.", + info.request_id); + break; + + default: break; // Unhandled case should just behave like any other error + } + } catch (...) { + } + + // If we weren't able to retry or redirect the swarm then handle this like any other + // error + break; + + case 500: + case 504: + // If we are making a proxied request to a server then assume 500 errors are occurring + // on the server rather than in the service node network and don't update the path/snode + // state + if (!info.node_destination) { + if (handle_response) + return (*handle_response)(false, timeout, status_code, headers, response); + return; + } + break; + + default: break; + } + + // Retrieve the path for the connection_info (no paths share the same guard node so we can use + // that to find it) + std::optional path; + auto is_active_path = true; + + auto path_it = std::find_if( + paths[info.path_type].begin(), + paths[info.path_type].end(), + [guard_node = conn_info.node](const auto& path) { + return !path.nodes.empty() && path.nodes.front() == guard_node; + }); + + // Try to retrieve the path this request was on, if it's not in an active or pending drop path + // then log a warning (as this shouldn't be possible) and call the callback + if (path_it != paths[info.path_type].end()) + path = *path_it; + else { + auto path_pending_drop_it = std::find_if( + paths_pending_drop.begin(), + paths_pending_drop.end(), + [guard_node = conn_info.node](const auto& path_info) { + return !path_info.first.nodes.empty() && + path_info.first.nodes.front() == guard_node; + }); + + if (path_pending_drop_it == paths_pending_drop.end()) { + log::warning( + cat, + "Request {} failed but {} path with guard {} already dropped.", + info.request_id, + path_name, + conn_info.node.to_string()); + + if (handle_response) + (*handle_response)(false, timeout, status_code, headers, response); + return; + } + path = path_pending_drop_it->first; + is_active_path = false; + } + + // Update the failure counts 
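// A sketch of the 421 retry selection used above: pick a different node from the (cached or
// freshly refreshed) swarm, excluding the node that returned the redirect, and retry against
// it.  The node type and RNG are illustrative stand-ins; the real code uses service_node and
// the library's shared CSPRNG (csrng) rather than std::mt19937_64.
#include <algorithm>
#include <iterator>
#include <optional>
#include <random>
#include <vector>

template <typename Node>
std::optional<Node> pick_alternative_node(const std::vector<Node>& swarm, const Node& exclude) {
    std::vector<Node> candidates;
    std::copy_if(
            swarm.begin(), swarm.end(), std::back_inserter(candidates),
            [&exclude](const Node& n) { return n != exclude; });

    if (candidates.empty())
        return std::nullopt;  // no other nodes in the swarm: caller must surface the error

    std::mt19937_64 rng{std::random_device{}()};  // stand-in for session::csrng
    std::shuffle(candidates.begin(), candidates.end(), rng);
    return candidates.front();
}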
and paths + auto updated_path = *path; + bool found_invalid_node = false; + + if (response) { + std::optional ed25519PublicKey; + + // Check if the response has one of the 'node_not_found' prefixes + if (response->starts_with(node_not_found_prefix)) + ed25519PublicKey = {response->data() + node_not_found_prefix.size()}; + else if (response->starts_with(node_not_found_prefix_no_status)) + ed25519PublicKey = {response->data() + node_not_found_prefix_no_status.size()}; + + // If we found a result then try to extract the pubkey and process it + if (ed25519PublicKey && ed25519PublicKey->size() == 64 && + oxenc::is_hex(*ed25519PublicKey)) { + session::onionreq::ed25519_pubkey edpk = + session::onionreq::ed25519_pubkey::from_hex(*ed25519PublicKey); + auto edpk_view = to_unsigned_sv(edpk.view()); + + auto snode_it = std::find_if( + updated_path.nodes.begin(), + updated_path.nodes.end(), + [&edpk_view](const auto& node) { return node.view_remote_key() == edpk_view; }); + + if (snode_it != updated_path.nodes.end()) { + found_invalid_node = true; + + // If we get an explicit node failure then we should just immediately drop it and + // try to repair the existing path by replacing the bad node with another one + snode_failure_counts[snode_it->to_string()] = snode_failure_threshold; + + try { + // If the node that's gone bad is the guard node then we just have to + // drop the path + if (snode_it == updated_path.nodes.begin()) + throw std::runtime_error{"Cannot recover if guard node is bad"}; + + if (unused_nodes.empty()) + throw std::runtime_error{"No remaining nodes"}; + + auto target_node = unused_nodes.back(); + unused_nodes.pop_back(); + + std::replace( + updated_path.nodes.begin(), + updated_path.nodes.end(), + *snode_it, + target_node); + log::info( + cat, + "Found bad node ({}) in {} path, replacing node ({}).", + *ed25519PublicKey, + path_name, + updated_path.id); + } catch (...) 
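// A sketch of the "node not found" extraction above: when the error body starts with one of
// the known node_not_found prefixes, the 64-character hex ed25519 key that follows
// identifies the bad snode, which is then swapped out of the path (or the path dropped if
// it was the guard).  The prefix is taken as a parameter here rather than hard-coding the
// library's constant.
#include <oxenc/hex.h>
#include <optional>
#include <string>
#include <string_view>

inline std::optional<std::string> bad_node_pubkey_hex(
        std::string_view response, std::string_view node_not_found_prefix) {
    if (!response.starts_with(node_not_found_prefix))
        return std::nullopt;

    auto key = response.substr(node_not_found_prefix.size());
    if (key.size() == 64 && oxenc::is_hex(key))
        return std::string{key};
    return std::nullopt;
}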
{ + // There aren't enough unused nodes remaining so we need to drop the + // path + updated_path.failure_count = path_failure_threshold; + log::info( + cat, + "Unable to replace bad node ({}) in {} path ({}).", + *ed25519PublicKey, + path_name, + updated_path.id); + } + } + } + } + + // If we didn't find the specific node or the paths connection was closed then increment the + // path failure count + if (!found_invalid_node || !updated_path.conn_info.is_valid()) { + updated_path.failure_count += 1; + + // If the path has failed too many times we want to drop the guard snode (marking it as + // invalid) and increment the failure count of each node in the path) + if (updated_path.failure_count >= path_failure_threshold) { + for (auto& it : updated_path.nodes) + ++snode_failure_counts[it.to_string()]; + + // Set the failure count of the guard node to match the threshold so we don't use it + // again until we refresh the cache + snode_failure_counts[updated_path.nodes[0].to_string()] = snode_failure_threshold; + } else if (updated_path.nodes.size() < path_size) + // triggered when trying to establish a new path and, as such, we should increase + // the failure count of the guard node since it is probably invalid + ++snode_failure_counts[updated_path.nodes[0].to_string()]; + } + + // Drop the path if invalid (and currently an active path) + if (is_active_path) { + if (updated_path.failure_count >= path_failure_threshold) + drop_path_when_empty(info.request_id, info.path_type, *path_it); + else + std::replace( + paths[info.path_type].begin(), + paths[info.path_type].end(), + *path_it, + updated_path); + } + + if (handle_response) + (*handle_response)(false, timeout, status_code, headers, response); +} + +} // namespace session::network + +// MARK: C API + +namespace { + +inline session::network::Network& unbox(network_object* network_) { + assert(network_ && network_->internals); + return *static_cast(network_->internals); +} + +inline bool set_error(char* error, const std::exception& e) { + if (!error) + return false; + + std::string msg = e.what(); + if (msg.size() > 255) + msg.resize(255); + std::memcpy(error, msg.c_str(), msg.size() + 1); + return false; +} + +} // namespace + +extern "C" { + +using namespace session; +using namespace session::network; + +LIBSESSION_C_API bool network_init( + network_object** network, + const char* cache_path_, + bool use_testnet, + bool single_path_mode, + bool pre_build_paths, + char* error) { + try { + std::optional cache_path; + if (cache_path_) + cache_path = cache_path_; + + auto n = std::make_unique( + cache_path, use_testnet, single_path_mode, pre_build_paths); + auto n_object = std::make_unique(); + + n_object->internals = n.release(); + *network = n_object.release(); + return true; + } catch (const std::exception& e) { + return set_error(error, e); + } +} + +LIBSESSION_C_API void network_free(network_object* network) { + delete static_cast(network->internals); + delete network; +} + +LIBSESSION_C_API void network_suspend(network_object* network) { + unbox(network).suspend(); +} + +LIBSESSION_C_API void network_resume(network_object* network) { + unbox(network).resume(); +} + +LIBSESSION_C_API void network_close_connections(network_object* network) { + unbox(network).close_connections(); +} + +LIBSESSION_C_API void network_clear_cache(network_object* network) { + unbox(network).clear_cache(); +} + +LIBSESSION_C_API size_t network_get_snode_cache_size(network_object* network) { + return unbox(network).snode_cache_size(); +} + +LIBSESSION_C_API void 
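// A usage sketch of the C API initialisation exposed above: network_init() fills in a
// network_object* on success and, on failure, writes a NUL-terminated message into the
// caller-supplied error buffer (set_error truncates to 255 characters, so 256 bytes is
// enough).  The header path and cache directory are assumptions for illustration only.
#include <cstdio>
#include "session/network.h"  // assumed location of the C declarations shown above

int main() {
    network_object* net = nullptr;
    char error[256] = {0};  // 255 characters of message + NUL terminator

    if (!network_init(
                &net,
                "/tmp/session-cache",  // hypothetical cache path
                /*use_testnet=*/false,
                /*single_path_mode=*/false,
                /*pre_build_paths=*/true,
                error)) {
        std::fprintf(stderr, "network_init failed: %s\n", error);
        return 1;
    }

    // ... use the other network_* calls with `net` ...

    network_free(net);
    return 0;
}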
network_set_status_changed_callback( + network_object* network, void (*callback)(CONNECTION_STATUS status, void* ctx), void* ctx) { + if (!callback) + unbox(network).status_changed = nullptr; + else + unbox(network).status_changed = [cb = std::move(callback), ctx](ConnectionStatus status) { + cb(static_cast(status), ctx); + }; +} + +LIBSESSION_C_API void network_set_paths_changed_callback( + network_object* network, + void (*callback)(onion_request_path* paths, size_t paths_len, void* ctx), + void* ctx) { + if (!callback) + unbox(network).paths_changed = nullptr; + else + unbox(network).paths_changed = [cb = std::move(callback), + ctx](std::vector> paths) { + size_t paths_mem_size = 0; + for (auto& nodes : paths) + paths_mem_size += + sizeof(onion_request_path) + (sizeof(network_service_node) * nodes.size()); + + // Allocate the memory for the onion_request_paths* array + auto* c_paths_array = static_cast(std::malloc(paths_mem_size)); + for (size_t i = 0; i < paths.size(); ++i) { + auto c_nodes = network::detail::convert_service_nodes(paths[i]); + + // Allocate memory that persists outside the loop + size_t node_array_size = sizeof(network_service_node) * c_nodes.size(); + auto* c_nodes_array = + static_cast(std::malloc(node_array_size)); + std::copy(c_nodes.begin(), c_nodes.end(), c_nodes_array); + new (c_paths_array + i) onion_request_path{c_nodes_array, c_nodes.size()}; + } + + cb(c_paths_array, paths.size(), ctx); + }; +} + +LIBSESSION_C_API void network_get_swarm( + network_object* network, + const char* swarm_pubkey_hex, + void (*callback)(network_service_node* nodes, size_t nodes_len, void*), + void* ctx) { + assert(swarm_pubkey_hex && callback); + unbox(network).get_swarm( + x25519_pubkey::from_hex({swarm_pubkey_hex, 64}), + [cb = std::move(callback), ctx](swarm_id_t, std::vector nodes) { + auto c_nodes = network::detail::convert_service_nodes(nodes); + cb(c_nodes.data(), c_nodes.size(), ctx); + }); +} + +LIBSESSION_C_API void network_get_random_nodes( + network_object* network, + uint16_t count, + void (*callback)(network_service_node*, size_t, void*), + void* ctx) { + assert(callback); + unbox(network).get_random_nodes( + count, [cb = std::move(callback), ctx](std::vector nodes) { + auto c_nodes = network::detail::convert_service_nodes(nodes); + cb(c_nodes.data(), c_nodes.size(), ctx); + }); +} + +LIBSESSION_C_API void network_send_onion_request_to_snode_destination( + network_object* network, + const network_service_node node, + const unsigned char* body_, + size_t body_size, + const char* swarm_pubkey_hex, + int64_t request_timeout_ms, + int64_t request_and_path_build_timeout_ms, + network_onion_response_callback_t callback, + void* ctx) { + assert(callback); + + try { + std::optional body; + if (body_size > 0) + body = {body_, body_size}; + + std::optional swarm_pubkey; + if (swarm_pubkey_hex) + swarm_pubkey = x25519_pubkey::from_hex({swarm_pubkey_hex, 64}); + + std::optional request_and_path_build_timeout; + if (request_and_path_build_timeout_ms > 0) + request_and_path_build_timeout = + std::chrono::milliseconds{request_and_path_build_timeout_ms}; + + std::array ip; + std::memcpy(ip.data(), node.ip, ip.size()); + + unbox(network).send_onion_request( + service_node{ + oxenc::from_hex({node.ed25519_pubkey_hex, 64}), + {0}, + INVALID_SWARM_ID, + "{}"_format(fmt::join(ip, ".")), + node.quic_port}, + body, + swarm_pubkey, + [cb = std::move(callback), ctx]( + bool success, + bool timeout, + int status_code, + std::vector> headers, + std::optional response) { + std::vector cHeaders; 
+ std::vector cHeaderValues; + cHeaders.reserve(headers.size()); + cHeaderValues.reserve(headers.size()); + + for (const auto& [header, value] : headers) { + cHeaders.push_back(header.c_str()); + cHeaderValues.push_back(value.c_str()); + } + + if (response) + cb(success, + timeout, + status_code, + cHeaders.data(), + cHeaderValues.data(), + headers.size(), + (*response).c_str(), + (*response).size(), + ctx); + else + cb(success, + timeout, + status_code, + cHeaders.data(), + cHeaderValues.data(), + headers.size(), + nullptr, + 0, + ctx); + }, + std::chrono::milliseconds{request_timeout_ms}, + request_and_path_build_timeout); + } catch (const std::exception& e) { + callback(false, false, -1, nullptr, nullptr, 0, e.what(), std::strlen(e.what()), ctx); + } +} + +LIBSESSION_C_API void network_send_onion_request_to_server_destination( + network_object* network, + const network_server_destination server, + const unsigned char* body_, + size_t body_size, + int64_t request_timeout_ms, + int64_t request_and_path_build_timeout_ms, + network_onion_response_callback_t callback, + void* ctx) { + assert(server.method && server.protocol && server.host && server.endpoint && + server.x25519_pubkey && callback); + + try { + std::optional body; + if (body_size > 0) + body = {body_, body_size}; + + std::optional request_and_path_build_timeout; + if (request_and_path_build_timeout_ms > 0) + request_and_path_build_timeout = + std::chrono::milliseconds{request_and_path_build_timeout_ms}; + + unbox(network).send_onion_request( + network::detail::convert_server_destination(server), + body, + std::nullopt, + [cb = std::move(callback), ctx]( + bool success, + bool timeout, + int status_code, + std::vector> headers, + std::optional response) { + std::vector cHeaders; + std::vector cHeaderValues; + cHeaders.reserve(headers.size()); + cHeaderValues.reserve(headers.size()); + + for (const auto& [header, value] : headers) { + cHeaders.push_back(header.c_str()); + cHeaderValues.push_back(value.c_str()); + } + + if (response) + cb(success, + timeout, + status_code, + cHeaders.data(), + cHeaderValues.data(), + headers.size(), + (*response).c_str(), + (*response).size(), + ctx); + else + cb(success, + timeout, + status_code, + cHeaders.data(), + cHeaderValues.data(), + headers.size(), + nullptr, + 0, + ctx); + }, + std::chrono::milliseconds{request_timeout_ms}, + request_and_path_build_timeout); + } catch (const std::exception& e) { + callback(false, false, -1, nullptr, nullptr, 0, e.what(), std::strlen(e.what()), ctx); + } +} + +LIBSESSION_C_API void network_upload_to_server( + network_object* network, + const network_server_destination server, + const unsigned char* data, + size_t data_len, + const char* file_name_, + int64_t request_timeout_ms, + int64_t request_and_path_build_timeout_ms, + network_onion_response_callback_t callback, + void* ctx) { + assert(data && server.method && server.protocol && server.host && server.endpoint && + server.x25519_pubkey && callback); + + try { + std::optional file_name; + if (file_name_) + file_name = file_name_; + + std::optional request_and_path_build_timeout; + if (request_and_path_build_timeout_ms > 0) + request_and_path_build_timeout = + std::chrono::milliseconds{request_and_path_build_timeout_ms}; + + unbox(network).upload_file_to_server( + {data, data_len}, + network::detail::convert_server_destination(server), + file_name, + [cb = std::move(callback), ctx]( + bool success, + bool timeout, + int status_code, + std::vector> headers, + std::optional response) { + std::vector 
cHeaders; + std::vector cHeaderValues; + cHeaders.reserve(headers.size()); + cHeaderValues.reserve(headers.size()); + + for (const auto& [header, value] : headers) { + cHeaders.push_back(header.c_str()); + cHeaderValues.push_back(value.c_str()); + } + + if (response) + cb(success, + timeout, + status_code, + cHeaders.data(), + cHeaderValues.data(), + headers.size(), + (*response).c_str(), + (*response).size(), + ctx); + else + cb(success, + timeout, + status_code, + cHeaders.data(), + cHeaderValues.data(), + headers.size(), + nullptr, + 0, + ctx); + }, + std::chrono::milliseconds{request_timeout_ms}, + request_and_path_build_timeout); + } catch (const std::exception& e) { + callback(false, false, -1, nullptr, nullptr, 0, e.what(), std::strlen(e.what()), ctx); + } +} + +LIBSESSION_C_API void network_download_from_server( + network_object* network, + const network_server_destination server, + int64_t request_timeout_ms, + int64_t request_and_path_build_timeout_ms, + network_onion_response_callback_t callback, + void* ctx) { + assert(server.method && server.protocol && server.host && server.endpoint && + server.x25519_pubkey && callback); + + try { + std::optional request_and_path_build_timeout; + if (request_and_path_build_timeout_ms > 0) + request_and_path_build_timeout = + std::chrono::milliseconds{request_and_path_build_timeout_ms}; + + unbox(network).download_file( + network::detail::convert_server_destination(server), + [cb = std::move(callback), ctx]( + bool success, + bool timeout, + int status_code, + std::vector> headers, + std::optional response) { + std::vector cHeaders; + std::vector cHeaderValues; + cHeaders.reserve(headers.size()); + cHeaderValues.reserve(headers.size()); + + for (const auto& [header, value] : headers) { + cHeaders.push_back(header.c_str()); + cHeaderValues.push_back(value.c_str()); + } + + if (response) + cb(success, + timeout, + status_code, + cHeaders.data(), + cHeaderValues.data(), + headers.size(), + (*response).c_str(), + (*response).size(), + ctx); + else + cb(success, + timeout, + status_code, + cHeaders.data(), + cHeaderValues.data(), + headers.size(), + nullptr, + 0, + ctx); + }, + std::chrono::milliseconds{request_timeout_ms}, + request_and_path_build_timeout); + } catch (const std::exception& e) { + callback(false, false, -1, nullptr, nullptr, 0, e.what(), std::strlen(e.what()), ctx); + } +} + +LIBSESSION_C_API void network_get_client_version( + network_object* network, + CLIENT_PLATFORM platform, + const unsigned char* ed25519_secret, + int64_t request_timeout_ms, + int64_t request_and_path_build_timeout_ms, + network_onion_response_callback_t callback, + void* ctx) { + assert(platform && callback); + + try { + std::optional request_and_path_build_timeout; + if (request_and_path_build_timeout_ms > 0) + request_and_path_build_timeout = + std::chrono::milliseconds{request_and_path_build_timeout_ms}; + + unbox(network).get_client_version( + static_cast(platform), + onionreq::ed25519_seckey::from_bytes({ed25519_secret, 64}), + [cb = std::move(callback), ctx]( + bool success, + bool timeout, + int status_code, + std::vector> headers, + std::optional response) { + std::vector cHeaders; + std::vector cHeaderValues; + cHeaders.reserve(headers.size()); + cHeaderValues.reserve(headers.size()); + + for (const auto& [header, value] : headers) { + cHeaders.push_back(header.c_str()); + cHeaderValues.push_back(value.c_str()); + } + + if (response) + cb(success, + timeout, + status_code, + cHeaders.data(), + cHeaderValues.data(), + headers.size(), + 
(*response).c_str(), + (*response).size(), + ctx); + else + cb(success, + timeout, + status_code, + cHeaders.data(), + cHeaderValues.data(), + headers.size(), + nullptr, + 0, + ctx); + }, + std::chrono::milliseconds{request_timeout_ms}, + request_and_path_build_timeout); + } catch (const std::exception& e) { + callback(false, false, -1, nullptr, nullptr, 0, e.what(), std::strlen(e.what()), ctx); + } +} + +} // extern "C" diff --git a/src/onionreq/builder.cpp b/src/onionreq/builder.cpp index 27e39bf0..9b9ade16 100644 --- a/src/onionreq/builder.cpp +++ b/src/onionreq/builder.cpp @@ -1,8 +1,11 @@ #include "session/onionreq/builder.hpp" +#include #include +#include #include #include +#include #include #include #include @@ -12,19 +15,38 @@ #include #include -#include +#include #include #include +#include +#include #include "session/export.h" +#include "session/network.hpp" #include "session/onionreq/builder.h" #include "session/onionreq/hop_encryption.hpp" #include "session/onionreq/key_types.hpp" #include "session/util.hpp" #include "session/xed25519.hpp" +using namespace std::literals; +using namespace oxen::log::literals; +using session::ustring_view; + namespace session::onionreq { +namespace detail { + session::onionreq::x25519_pubkey pubkey_for_destination(network_destination destination) { + if (auto* dest = std::get_if(&destination)) + return compute_x25519_pubkey(dest->view_remote_key()); + + if (auto* dest = std::get_if(&destination)) + return dest->x25519_pubkey; + + throw std::runtime_error{"Invalid destination."}; + } +} // namespace detail + namespace { ustring encode_size(uint32_t s) { @@ -43,6 +65,95 @@ EncryptType parse_enc_type(std::string_view enc_type) { throw std::runtime_error{"Invalid encryption type " + std::string{enc_type}}; } +Builder Builder::make( + const network_destination& destination, + const std::vector& nodes, + const EncryptType enc_type_) { + return Builder{destination, nodes, enc_type_}; +} + +Builder::Builder( + const network_destination& destination, + const std::vector& nodes, + const EncryptType enc_type_) : + enc_type{enc_type_}, + destination_x25519_public_key{detail::pubkey_for_destination(destination)} { + set_destination(destination); + for (auto& n : nodes) + add_hop(n.view_remote_key()); +} + +void Builder::add_hop(ustring_view remote_key) { + hops_.push_back({ed25519_pubkey::from_bytes(remote_key), compute_x25519_pubkey(remote_key)}); +} + +void Builder::set_destination(network_destination destination) { + ed25519_public_key_.reset(); + + if (auto* dest = std::get_if(&destination)) + ed25519_public_key_.emplace(ed25519_pubkey::from_bytes(dest->view_remote_key())); + else if (auto* dest = std::get_if(&destination)) { + host_.emplace(dest->host); + endpoint_.emplace(dest->endpoint); + method_.emplace(dest->method); + + // Remove the '://' from the protocol if it was given + size_t pos = dest->protocol.find("://"); + if (pos != std::string::npos) + protocol_.emplace(dest->protocol.substr(0, pos)); + else + protocol_.emplace(dest->protocol); + + if (dest->port) + port_.emplace(*dest->port); + + if (dest->headers) + headers_.emplace(*dest->headers); + } else + throw std::invalid_argument{"Invalid destination type."}; +} + +void Builder::set_destination_pubkey(session::onionreq::x25519_pubkey x25519_pubkey) { + destination_x25519_public_key.reset(); + destination_x25519_public_key.emplace(x25519_pubkey); +} + +void Builder::generate(network::request_info& info) { + info.body = build(_generate_payload(info.original_body)); +} + +ustring 
Builder::_generate_payload(std::optional body) const { + // If we don't have the data required for a server request, then assume it's targeting a + // service node and, therefore, the `body` is the payload + if (!host_ || !endpoint_ || !protocol_ || !method_ || !destination_x25519_public_key) + return body.value_or(ustring{}); + + // Otherwise generate the payload for a server request + auto headers_json = nlohmann::json::object(); + + if (headers_) + for (const auto& [key, value] : *headers_) { + // Some platforms might automatically add this header, but we don't want to include it + if (key != "User-Agent") + headers_json[key] = value; + } + + if (body && !headers_json.contains("Content-Type")) + headers_json["Content-Type"] = "application/json"; + + // Structure the request information + nlohmann::json request_info{ + {"method", *method_}, {"endpoint", *endpoint_}, {"headers", headers_json}}; + std::vector payload{request_info.dump()}; + + // If we were given a body, add it to the payload + if (body.has_value()) + payload.emplace_back(from_unsigned_sv(*body)); + + auto result = oxenc::bt_serialize(payload); + return {to_unsigned(result.data()), result.size()}; +} + ustring Builder::build(ustring payload) { ustring blob; @@ -94,19 +205,19 @@ ustring Builder::build(ustring payload) { // The data we send to the destination differs depending on whether the destination is a // server or a service node - if (host_ && target_ && protocol_ && destination_x25519_public_key) { + if (host_ && protocol_ && destination_x25519_public_key) { final_route = { - {"host", host_.value()}, - {"target", target_.value()}, + {"host", *host_}, + {"target", "/oxen/v4/lsrpc"}, // All servers support V4 onion requests {"method", "POST"}, - {"protocol", protocol_.value()}, - {"port", port_.value_or(protocol_.value() == "https" ? 443 : 80)}, + {"protocol", *protocol_}, + {"port", port_.value_or(*protocol_ == "https" ? 
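// A sketch of the v4 onion-request payload produced by Builder::_generate_payload() above:
// a bencoded list whose first element is a JSON blob describing the request (method,
// endpoint, headers) and whose optional second element is the raw body.  This standalone
// version uses std::string rather than the library's ustring and omits the User-Agent
// header filtering.
#include <nlohmann/json.hpp>
#include <oxenc/bt_serialize.h>
#include <optional>
#include <string>
#include <vector>

inline std::string make_v4_payload(
        const std::string& method,
        const std::string& endpoint,
        nlohmann::json headers,
        std::optional<std::string> body) {
    // A request with a body gets a default Content-Type if the caller didn't set one.
    if (body && !headers.contains("Content-Type"))
        headers["Content-Type"] = "application/json";

    nlohmann::json request_info{
            {"method", method}, {"endpoint", endpoint}, {"headers", headers}};

    std::vector<std::string> parts{request_info.dump()};
    if (body)
        parts.push_back(*body);

    // Serialized form: l<len>:<json>[<len>:<body>]e
    return oxenc::bt_serialize(parts);
}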
443 : 80)}, {"ephemeral_key", A.hex()}, // The x25519 ephemeral_key here is the key for the // *next* hop to use {"enc_type", to_string(enc_type)}, }; - blob = e.encrypt(enc_type, payload.data(), *destination_x25519_public_key); + blob = e.encrypt(enc_type, payload, *destination_x25519_public_key); } else if (ed25519_public_key_ && destination_x25519_public_key) { nlohmann::json control{{"headers", ""}}; final_route = { @@ -121,7 +232,13 @@ ustring Builder::build(ustring payload) { data += to_unsigned_sv(control.dump()); blob = e.encrypt(enc_type, data, *destination_x25519_public_key); } else { - throw std::runtime_error{"Destination not set"}; + if (!destination_x25519_public_key.has_value()) + throw std::runtime_error{"Destination not set: No destination x25519 public key"}; + if (!ed25519_public_key_.has_value()) + throw std::runtime_error{"Destination not set: No destination ed25519 public key"}; + throw std::runtime_error{ + "Destination not set: " + host_.value_or("N/A") + ", " + + protocol_.value_or("N/A")}; } // Save these because we need them again to decrypt the final response: @@ -208,29 +325,47 @@ LIBSESSION_C_API void onion_request_builder_set_enc_type( LIBSESSION_C_API void onion_request_builder_set_snode_destination( onion_request_builder_object* builder, - const char* ed25519_pubkey, - const char* x25519_pubkey) { - assert(builder && ed25519_pubkey && x25519_pubkey); - - unbox(builder).set_snode_destination( - session::onionreq::ed25519_pubkey::from_hex({ed25519_pubkey, 64}), - session::onionreq::x25519_pubkey::from_hex({x25519_pubkey, 64})); + const uint8_t ip[4], + const uint16_t quic_port, + const char* ed25519_pubkey) { + assert(builder && ip && ed25519_pubkey); + + std::array target_ip; + std::memcpy(target_ip.data(), ip, target_ip.size()); + + unbox(builder).set_destination(session::network::service_node( + oxenc::from_hex({ed25519_pubkey, 64}), + {0}, + session::network::INVALID_SWARM_ID, + "{}"_format(fmt::join(target_ip, ".")), + quic_port)); } LIBSESSION_C_API void onion_request_builder_set_server_destination( onion_request_builder_object* builder, - const char* host, - const char* target, const char* protocol, + const char* host, + const char* endpoint, + const char* method, uint16_t port, const char* x25519_pubkey) { - assert(builder && host && target && protocol && x25519_pubkey); + assert(builder && protocol && host && endpoint && protocol && x25519_pubkey); - unbox(builder).set_server_destination( - host, - target, + unbox(builder).set_destination(session::onionreq::ServerDestination{ protocol, + host, + endpoint, + session::onionreq::x25519_pubkey::from_hex({x25519_pubkey, 64}), port, + std::nullopt, + method}); +} + +LIBSESSION_C_API void onion_request_builder_set_destination_pubkey( + onion_request_builder_object* builder, const char* x25519_pubkey) { + assert(builder && x25519_pubkey); + + unbox(builder).set_destination_pubkey( session::onionreq::x25519_pubkey::from_hex({x25519_pubkey, 64})); } diff --git a/src/onionreq/hop_encryption.cpp b/src/onionreq/hop_encryption.cpp index d15ef641..deb48c14 100644 --- a/src/onionreq/hop_encryption.cpp +++ b/src/onionreq/hop_encryption.cpp @@ -80,6 +80,15 @@ namespace { } // namespace +bool HopEncryption::response_long_enough(EncryptType type, size_t response_size) { + switch (type) { + case EncryptType::xchacha20: + return (response_size >= crypto_aead_xchacha20poly1305_ietf_ABYTES); + case EncryptType::aes_gcm: return (response_size >= GCM_IV_SIZE + GCM_DIGEST_SIZE); + } + return false; +} + ustring 
HopEncryption::encrypt( EncryptType type, ustring plaintext, const x25519_pubkey& pubkey) const { switch (type) { @@ -131,8 +140,9 @@ ustring HopEncryption::encrypt_aesgcm(ustring plaintext, const x25519_pubkey& pu ustring HopEncryption::decrypt_aesgcm(ustring ciphertext_, const x25519_pubkey& pubKey) const { ustring_view ciphertext = {ciphertext_.data(), ciphertext_.size()}; - if (ciphertext.size() < GCM_IV_SIZE + GCM_DIGEST_SIZE) - throw std::runtime_error{"ciphertext data is too short"}; + if (!response_long_enough(EncryptType::aes_gcm, ciphertext_.size())) + throw std::invalid_argument{ + "Ciphertext data is too short: " + std::string(from_unsigned(ciphertext_.data()))}; auto key = derive_symmetric_key(private_key_, pubKey); @@ -198,8 +208,10 @@ ustring HopEncryption::decrypt_xchacha20(ustring ciphertext_, const x25519_pubke // Extract nonce from the beginning of the ciphertext: auto nonce = ciphertext.substr(0, crypto_aead_xchacha20poly1305_ietf_NPUBBYTES); ciphertext.remove_prefix(nonce.size()); - if (ciphertext.size() < crypto_aead_xchacha20poly1305_ietf_ABYTES) - throw std::runtime_error{"Invalid ciphertext: too short"}; + + if (!response_long_enough(EncryptType::xchacha20, ciphertext_.size())) + throw std::invalid_argument{ + "Ciphertext data is too short: " + std::string(from_unsigned(ciphertext_.data()))}; const auto key = xchacha20_shared_key(public_key_, private_key_, pubKey, !server_); diff --git a/src/onionreq/key_types.cpp b/src/onionreq/key_types.cpp index f234f16e..15c8aef2 100644 --- a/src/onionreq/key_types.cpp +++ b/src/onionreq/key_types.cpp @@ -81,5 +81,13 @@ ed25519_pubkey parse_ed25519_pubkey(std::string_view pubkey_in) { x25519_pubkey parse_x25519_pubkey(std::string_view pubkey_in) { return parse_pubkey(pubkey_in); } +x25519_pubkey compute_x25519_pubkey(ustring_view ed25519_pk) { + std::array xpk; + if (0 != crypto_sign_ed25519_pk_to_curve25519(xpk.data(), ed25519_pk.data())) + throw std::runtime_error{ + "An error occured while attempting to convert Ed25519 pubkey to X25519; " + "is the pubkey valid?"}; + return x25519_pubkey::from_bytes({xpk.data(), 32}); +} } // namespace session::onionreq diff --git a/src/onionreq/response_parser.cpp b/src/onionreq/response_parser.cpp index 2ace3237..bbddcc6c 100644 --- a/src/onionreq/response_parser.cpp +++ b/src/onionreq/response_parser.cpp @@ -23,6 +23,10 @@ ResponseParser::ResponseParser(session::onionreq::Builder builder) { x25519_keypair_ = builder.final_hop_x25519_keypair.value(); } +bool ResponseParser::response_long_enough(EncryptType enc_type, size_t response_size) { + return HopEncryption::response_long_enough(enc_type, response_size); +} + ustring ResponseParser::decrypt(ustring ciphertext) const { HopEncryption d{x25519_keypair_.second, x25519_keypair_.first, false}; @@ -32,13 +36,17 @@ ustring ResponseParser::decrypt(ustring ciphertext) const { try { return d.decrypt(enc_type_, ciphertext, destination_x25519_public_key_); } catch (const std::exception& e) { - if (enc_type_ == session::onionreq::EncryptType::xchacha20) - return d.decrypt( - session::onionreq::EncryptType::aes_gcm, - ciphertext, - destination_x25519_public_key_); - else - throw e; + if (enc_type_ == session::onionreq::EncryptType::xchacha20) { + try { + return d.decrypt( + session::onionreq::EncryptType::aes_gcm, + ciphertext, + destination_x25519_public_key_); + } catch (...) 
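// A sketch of the ed25519 -> x25519 conversion added in compute_x25519_pubkey() above:
// libsodium's crypto_sign_ed25519_pk_to_curve25519 derives the X25519 (encryption) key
// corresponding to a node's Ed25519 (signing) key, which is what the onion builder needs
// for each hop.  Input and output are raw 32-byte keys.
#include <sodium.h>
#include <array>
#include <stdexcept>

inline std::array<unsigned char, 32> ed25519_to_x25519(
        const std::array<unsigned char, 32>& ed25519_pk) {
    std::array<unsigned char, 32> x25519_pk;
    if (crypto_sign_ed25519_pk_to_curve25519(x25519_pk.data(), ed25519_pk.data()) != 0)
        throw std::runtime_error{
                "Failed to convert Ed25519 pubkey to X25519; is the pubkey valid?"};
    return x25519_pk;
}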
{ + throw std::runtime_error{std::string(decryption_failed_error)}; + } + } else + throw; } } diff --git a/src/random.cpp b/src/random.cpp index 693dd489..caeff02b 100644 --- a/src/random.cpp +++ b/src/random.cpp @@ -2,11 +2,20 @@ #include +#include + #include "session/export.h" #include "session/util.hpp" +namespace session { +// make this once, and only once, and use it where needed +CSRNG csrng = CSRNG{}; +} // namespace session + namespace session::random { +constexpr char base32_charset[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"; + ustring random(size_t size) { ustring result; result.resize(size); @@ -15,6 +24,17 @@ ustring random(size_t size) { return result; } +std::string random_base32(size_t size) { + std::string result; + result.reserve(size); + auto n_chars = sizeof(base32_charset) - 1; + + for (size_t i = 0; i < size; ++i) + result.push_back(base32_charset[csrng() % n_chars]); + + return result; +} + } // namespace session::random extern "C" { diff --git a/src/session_encrypt.cpp b/src/session_encrypt.cpp index 922ea93b..1da385d5 100644 --- a/src/session_encrypt.cpp +++ b/src/session_encrypt.cpp @@ -1,19 +1,24 @@ #include "session/session_encrypt.hpp" +#include #include #include #include #include #include +#include #include +#include #include #include +#include #include #include #include #include #include +#include #include #include "session/blinding.hpp" @@ -23,6 +28,39 @@ using namespace std::literals; namespace session { +namespace detail { + inline int64_t to_epoch_ms(std::chrono::system_clock::time_point t) { + return std::chrono::duration_cast(t.time_since_epoch()).count(); + } + + // detail::to_hashable takes either an integral type, system_clock::time_point, or a string + // type and converts it to a string_view by writing an integer value (using std::to_chars) + // into the buffer space (which should be at least 20 bytes), and returning a string_view + // into the written buffer space. For strings/string_views the string_view is returned + // directly from the argument. system_clock::time_points are converted into integral + // milliseconds since epoch then treated as an integer value. + template , int> = 0> + std::string_view to_hashable(const T& val, char*& buffer) { + std::ostringstream ss; + ss << val; + + std::string str = ss.str(); + std::copy(str.begin(), str.end(), buffer); + std::string_view s(buffer, str.length()); + buffer += str.length(); + return s; + } + inline std::string_view to_hashable( + const std::chrono::system_clock::time_point& val, char*& buffer) { + return to_hashable(to_epoch_ms(val), buffer); + } + template , int> = 0> + std::string_view to_hashable(const T& value, char*&) { + return value; + } + +} // namespace detail + // Version tag we prepend to encrypted-for-blinded-user messages. This is here so we can detect if // some future version changes the format (and if not even try to load it). 
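// A sketch of the random_base32() helper added above and of how the networking code uses it
// for request ids (e.g. "R-" + random_base32(4)): pick `size` characters uniformly from a
// 32-character alphabet.  std::mt19937_64 stands in for the library's shared CSPRNG
// (session::csrng); since 2^64 is a multiple of 32 the modulo below introduces no bias.
#include <cstddef>
#include <random>
#include <string>

inline std::string random_base32_sketch(size_t size) {
    static constexpr char charset[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
    constexpr size_t n_chars = sizeof(charset) - 1;  // 32, excluding the trailing NUL

    std::mt19937_64 rng{std::random_device{}()};  // stand-in for session::csrng
    std::string result;
    result.reserve(size);
    for (size_t i = 0; i < size; ++i)
        result.push_back(charset[rng() % n_chars]);
    return result;
}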
inline constexpr unsigned char BLINDED_ENCRYPT_VERSION = 0; @@ -508,10 +546,44 @@ std::pair decrypt_from_blinded_recipient( } std::string decrypt_ons_response( - std::string_view lowercase_name, ustring_view ciphertext, ustring_view nonce) { + std::string_view lowercase_name, + ustring_view ciphertext, + std::optional nonce) { + // Handle old Argon2-based encryption used before HF16 + if (!nonce) { + if (ciphertext.size() < crypto_secretbox_MACBYTES) + throw std::invalid_argument{"Invalid ciphertext: expected to be greater than 16 bytes"}; + + uc32 key; + std::array salt = {0}; + + if (0 != crypto_pwhash( + key.data(), + key.size(), + lowercase_name.data(), + lowercase_name.size(), + salt.data(), + crypto_pwhash_OPSLIMIT_MODERATE, + crypto_pwhash_MEMLIMIT_MODERATE, + crypto_pwhash_ALG_ARGON2ID13)) + throw std::runtime_error{"Failed to generate key"}; + + ustring msg; + msg.resize(ciphertext.size() - crypto_secretbox_MACBYTES); + std::array nonce = {0}; + + if (0 != + crypto_secretbox_open_easy( + msg.data(), ciphertext.data(), ciphertext.size(), nonce.data(), key.data())) + throw std::runtime_error{"Failed to decrypt"}; + + std::string session_id = oxenc::to_hex(msg.begin(), msg.end()); + return session_id; + } + if (ciphertext.size() < crypto_aead_xchacha20poly1305_ietf_ABYTES) throw std::invalid_argument{"Invalid ciphertext: expected to be greater than 16 bytes"}; - if (nonce.size() != crypto_aead_xchacha20poly1305_ietf_NPUBBYTES) + if (nonce->size() != crypto_aead_xchacha20poly1305_ietf_NPUBBYTES) throw std::invalid_argument{"Invalid nonce: expected to be 24 bytes"}; // Hash the ONS name using BLAKE2b @@ -543,7 +615,7 @@ std::string decrypt_ons_response( ciphertext.size(), nullptr, 0, - nonce.data(), + nonce->data(), key.data())) throw std::runtime_error{"Failed to decrypt"}; @@ -590,6 +662,66 @@ ustring decrypt_push_notification(ustring_view payload, ustring_view enc_key) { return buf; } +template +std::string compute_hash(Func hasher, const T&... args) { + // Allocate a buffer of 20 bytes per integral value (which is the largest the any integral + // value can be when stringified). + std::array< + char, + (0 + ... + + (std::is_integral_v || std::is_same_v + ? 
20 + : 0))> + buffer; + auto* b = buffer.data(); + return hasher({detail::to_hashable(args, b)...}); +} + +std::string compute_hash_blake2b_b64(std::vector parts) { + constexpr size_t HASH_SIZE = 32; + crypto_generichash_state state; + crypto_generichash_init(&state, nullptr, 0, HASH_SIZE); + for (const auto& s : parts) + crypto_generichash_update( + &state, reinterpret_cast(s.data()), s.size()); + std::array hash; + crypto_generichash_final(&state, hash.data(), HASH_SIZE); + + std::string b64hash = oxenc::to_base64(hash.begin(), hash.end()); + // Trim padding: + while (!b64hash.empty() && b64hash.back() == '=') + b64hash.pop_back(); + return b64hash; +} + +std::string compute_message_hash( + const std::string_view pubkey_hex, int16_t ns, std::string_view data) { + if (pubkey_hex.size() != 66) + throw std::invalid_argument{ + "Invalid pubkey_hex: Expecting 66 character hex-encoded pubkey"}; + + // This function is based on the `computeMessageHash` function on the storage-server used to + // generate a message hash: + // https://github.com/oxen-io/oxen-storage-server/blob/dev/oxenss/rpc/request_handler.cpp + auto pubkey = oxenc::from_hex(pubkey_hex.substr(2)); + uint8_t netid_raw; + oxenc::from_hex(pubkey_hex.begin(), pubkey_hex.begin() + 2, &netid_raw); + char netid = static_cast(netid_raw); + + std::array ns_buf; + char* ns_buf_ptr = ns_buf.data(); + std::string_view ns_for_hash = ns != 0 ? detail::to_hashable(ns, ns_buf_ptr) : ""sv; + + auto decoded_data = oxenc::from_base64(data); + + return compute_hash( + compute_hash_blake2b_b64, + std::string_view{&netid, 1}, + pubkey, + ns_for_hash, + decoded_data); +} + } // namespace session using namespace session; @@ -718,16 +850,17 @@ LIBSESSION_C_API bool session_decrypt_for_blinded_recipient( LIBSESSION_C_API bool session_decrypt_ons_response( const char* name_in, - size_t name_len, const unsigned char* ciphertext_in, size_t ciphertext_len, const unsigned char* nonce_in, char* session_id_out) { try { + std::optional nonce; + if (nonce_in) + nonce = ustring{nonce_in, crypto_aead_xchacha20poly1305_ietf_NPUBBYTES}; + auto session_id = session::decrypt_ons_response( - std::string_view{name_in, name_len}, - ustring_view{ciphertext_in, ciphertext_len}, - ustring_view{nonce_in, crypto_aead_xchacha20poly1305_ietf_NPUBBYTES}); + name_in, ustring_view{ciphertext_in, ciphertext_len}, nonce); std::memcpy(session_id_out, session_id.c_str(), session_id.size() + 1); return true; @@ -754,3 +887,15 @@ LIBSESSION_C_API bool session_decrypt_push_notification( return false; } } + +LIBSESSION_C_API bool session_compute_message_hash( + const char* pubkey_hex_in, int16_t ns, const char* base64_data_in, char* hash_out) { + try { + auto hash = session::compute_message_hash(pubkey_hex_in, ns, base64_data_in); + + std::memcpy(hash_out, hash.c_str(), hash.size() + 1); + return true; + } catch (...) 
{ + return false; + } +} diff --git a/src/util.cpp b/src/util.cpp index 123aacfe..7669d0e1 100644 --- a/src/util.cpp +++ b/src/util.cpp @@ -1,7 +1,90 @@ -#include - +#include #include namespace session { +std::vector split(std::string_view str, const std::string_view delim, bool trim) { + std::vector results; + // Special case for empty delimiter: splits on each character boundary: + if (delim.empty()) { + results.reserve(str.size()); + for (size_t i = 0; i < str.size(); i++) + results.emplace_back(str.data() + i, 1); + return results; + } + + for (size_t pos = str.find(delim); pos != std::string_view::npos; pos = str.find(delim)) { + if (!trim || !results.empty() || pos > 0) + results.push_back(str.substr(0, pos)); + str.remove_prefix(pos + delim.size()); + } + if (!trim || str.size()) + results.push_back(str); + else + while (!results.empty() && results.back().empty()) + results.pop_back(); + return results; +} + +std::tuple, std::optional> parse_url( + std::string_view url) { + std::tuple, std::optional> + result{}; + auto& [proto, host, port, path] = result; + if (auto pos = url.find("://"); pos != std::string::npos) { + auto proto_name = url.substr(0, pos); + url.remove_prefix(proto_name.size() + 3); + if (string_iequal(proto_name, "http")) + proto = "http://"; + else if (string_iequal(proto_name, "https")) + proto = "https://"; + } + if (proto.empty()) + throw std::invalid_argument{"Invalid URL: invalid/missing protocol://"}; + + bool next_allow_dot = false; + bool has_dot = false; + while (!url.empty()) { + auto c = url.front(); + if ((c >= '0' && c <= '9') || (c >= 'a' && c <= 'z') || c == '-') { + host += c; + next_allow_dot = true; + } else if (c >= 'A' && c <= 'Z') { + host += c + ('a' - 'A'); + next_allow_dot = true; + } else if (next_allow_dot && c == '.') { + host += '.'; + has_dot = true; + next_allow_dot = false; + } else { + break; + } + url.remove_prefix(1); + } + if (host.size() < 4 || !has_dot || host.back() == '.') + throw std::invalid_argument{"Invalid URL: invalid hostname"}; + + if (!url.empty() && url.front() == ':') { + url.remove_prefix(1); + uint16_t target_port; + if (auto [p, ec] = std::from_chars(url.data(), url.data() + url.size(), target_port); + ec == std::errc{}) + url.remove_prefix(p - url.data()); + else + throw std::invalid_argument{"Invalid URL: invalid port"}; + if (!(target_port == 80 && proto == "http://") && !(target_port == 443 && proto == "https:/" + "/")) + port = target_port; + } + + if (url.size() > 1 && url.front() == '/') + path = url; + else if (!url.empty() && url.front() == '/') { + url.remove_prefix(1); + path = std::nullopt; + } + + return result; +} + } // namespace session diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index bffd82f2..d359c9ce 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,6 +1,10 @@ add_subdirectory(Catch2) -add_executable(testAll +if(CMAKE_BUILD_TYPE STREQUAL "Release") + add_definitions(-DRELEASE_BUILD) +endif() + +set(LIB_SESSION_UTESTS_SOURCES test_blinding.cpp test_bt_merge.cpp test_bugs.cpp @@ -17,26 +21,44 @@ add_executable(testAll test_group_info.cpp test_group_members.cpp test_hash.cpp + test_logging.cpp test_multi_encrypt.cpp - test_onionreq.cpp test_proto.cpp test_random.cpp test_session_encrypt.cpp test_xed25519.cpp ) +if (ENABLE_ONIONREQ) + list(APPEND LIB_SESSION_UTESTS_SOURCES test_network.cpp) + list(APPEND LIB_SESSION_UTESTS_SOURCES test_onionreq.cpp) +endif() + +add_executable(testAll ${LIB_SESSION_UTESTS_SOURCES}) + +add_library(Catch2Wrapper INTERFACE) 
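Two usage sketches for the session_encrypt.cpp helpers added above, decrypt_ons_response() with its now-optional nonce and the new compute_message_hash(). The exact optional type of the nonce parameter and the header path are assumptions inferred from the diff:

    #include <optional>
    #include <string>
    #include <string_view>
    #include <oxenc/base64.h>
    #include "session/session_encrypt.hpp"  // assumed header

    // ONS lookup responses: omit the nonce for records encrypted with the legacy (pre-HF16)
    // Argon2id + secretbox scheme; pass the 24-byte nonce for current XChaCha20-Poly1305
    // records. Either way a hex-encoded Session ID string comes back.
    std::string ons_session_id(
            std::string_view lowercase_name,
            session::ustring_view ciphertext,
            std::optional<session::ustring_view> nonce = std::nullopt) {
        return session::decrypt_ons_response(lowercase_name, ciphertext, nonce);
    }

    // Storage-server style message hash: base64 (with '=' padding stripped) of a 32-byte
    // BLAKE2b over netid || pubkey || namespace || data, as built by compute_message_hash().
    std::string message_hash(std::string_view session_id_hex, int16_t ns, std::string_view body) {
        // session_id_hex must be the 66-character hex id ("05" + 64 hex chars); the function
        // expects the message body base64-encoded.
        return session::compute_message_hash(session_id_hex, ns, oxenc::to_base64(body));
    }
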
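And a sketch of the behaviour encoded by the new util.cpp helpers above (header path assumed; the commented results follow directly from the implementation and are also asserted by the parse_url test further below):

    #include "session/util.hpp"  // assumed header

    void url_and_split_example() {
        // parse_url(): the scheme is matched case-insensitively, the host is lower-cased, a
        // port equal to the scheme default is dropped, and everything from '/' on (query
        // string included) becomes the path.
        auto [proto, host, port, path] = session::parse_url("HTTPS://Example.COM:8443/r/room?x=1");
        // proto == "https://", host == "example.com", port == 8443, path == "/r/room?x=1"

        // split(): with trim=true, leading/trailing empty pieces are dropped.
        auto parts = session::split("/a/b/c/", "/", /*trim=*/true);
        // parts == {"a", "b", "c"}
    }
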
+target_link_libraries(Catch2Wrapper INTERFACE Catch2::Catch2WithMain) + +if(MSVC) + target_compile_options(Catch2Wrapper INTERFACE /W0) +else() + target_compile_options(Catch2Wrapper INTERFACE -w) +endif() target_link_libraries(testAll PRIVATE libsession::config libsodium::sodium-internal - Catch2::Catch2WithMain -) + nlohmann_json::nlohmann_json + oxen::logging + Catch2::Catch2WithMain) if (ENABLE_ONIONREQ) target_link_libraries(testAll PRIVATE libsession::onionreq) endif() -add_custom_target(check COMMAND testAll) +if(NOT TARGET check) + add_custom_target(check COMMAND testAll) +endif() add_executable(swarm-auth-test EXCLUDE_FROM_ALL swarm-auth-test.cpp) target_link_libraries(swarm-auth-test PRIVATE config) @@ -44,6 +66,6 @@ target_link_libraries(swarm-auth-test PRIVATE config) if(STATIC_BUNDLE) add_executable(static-bundle-test static_bundle.cpp) target_include_directories(static-bundle-test PUBLIC ../include) - target_link_libraries(static-bundle-test PRIVATE "${PROJECT_BINARY_DIR}/libsession-util.a" oxenc::oxenc) + target_link_libraries(static-bundle-test PRIVATE "${PROJECT_BINARY_DIR}/libsession-util.a" oxenc::oxenc quic) add_dependencies(static-bundle-test session-util) endif() diff --git a/tests/catch2_bt_format.hpp b/tests/catch2_bt_format.hpp index fa87bae1..ef6942c2 100644 --- a/tests/catch2_bt_format.hpp +++ b/tests/catch2_bt_format.hpp @@ -24,7 +24,7 @@ struct StringMaker { inline std::string StringMaker::convert(const oxenc::bt_value& value) { return var::visit( [](const auto& x) { - return StringMaker>{}.convert(x); + return StringMaker>{}.convert(x); }, static_cast(value)); } diff --git a/tests/test_blinding.cpp b/tests/test_blinding.cpp index 5804bbcb..07d033f8 100644 --- a/tests/test_blinding.cpp +++ b/tests/test_blinding.cpp @@ -26,9 +26,6 @@ constexpr std::array seed2{ 0x45, 0x44, 0xc1, 0xc5, 0x08, 0x9c, 0x40, 0x41, 0x4b, 0xbd, 0xa1, 0xff, 0xdd, 0xe8, 0xaa, 0xb2, 0x61, 0x7f, 0xe9, 0x37, 0xee, 0x74, 0xa5, 0xee, 0x81}; -constexpr ustring_view pub1{seed1.data() + 32, 32}; -constexpr ustring_view pub2{seed2.data() + 32, 32}; - constexpr std::array xpub1{ 0xfe, 0x94, 0xb7, 0xad, 0x4b, 0x7f, 0x1c, 0xc1, 0xbb, 0x92, 0x67, 0x1f, 0x1f, 0x0d, 0x24, 0x3f, 0x22, 0x6e, 0x11, 0x5b, 0x33, 0x77, @@ -40,12 +37,6 @@ constexpr std::array xpub2{ 0x78, 0x81, 0x96, 0x2c, 0x72, 0x36, 0x99, 0x15, 0x20, 0x73, }; -constexpr std::array pub2_abs{ - 0x35, 0x70, 0xb6, 0x9a, 0x47, 0xdc, 0x09, 0x45, 0x44, 0xc1, 0xc5, - 0x08, 0x9c, 0x40, 0x41, 0x4b, 0xbd, 0xa1, 0xff, 0xdd, 0xe8, 0xaa, - 0xb2, 0x61, 0x7f, 0xe9, 0x37, 0xee, 0x74, 0xa5, 0xee, 0x01, -}; - const std::string session_id1 = "05" + oxenc::to_hex(xpub1.begin(), xpub1.end()); const std::string session_id2 = "05" + oxenc::to_hex(xpub2.begin(), xpub2.end()); @@ -359,8 +350,8 @@ TEST_CASE("Communities session id blinded id matching", "[blinding][matching]") CHECK(session_id_matches_blinded_id(session_id2, b25_5, server_pks[4])); CHECK(session_id_matches_blinded_id(session_id1, b25_6, server_pks[5])); - auto invalid_session_id = "9" + session_id1.substr(1, 65); - auto invalid_blinded_id = "9" + b15_1.substr(1, 65); + auto invalid_session_id = "9"s + session_id1.substr(1); + auto invalid_blinded_id = "9"s + b15_1.substr(1); auto invalid_server_pk = server_pks[0].substr(0, 60); CHECK_THROWS(session_id_matches_blinded_id(invalid_session_id, b15_1, server_pks[0])); CHECK_THROWS(session_id_matches_blinded_id(session_id1, invalid_blinded_id, server_pks[0])); diff --git a/tests/test_config_convo_info_volatile.cpp b/tests/test_config_convo_info_volatile.cpp 
index 85395a5a..bdac72e4 100644 --- a/tests/test_config_convo_info_volatile.cpp +++ b/tests/test_config_convo_info_volatile.cpp @@ -295,7 +295,7 @@ TEST_CASE("Conversations (C API)", "[config][conversations][c]") { "bad-url", "room", "0000000000000000000000000000000000000000000000000000000000000000"_hexbytes.data())); - CHECK(conf->last_error == "Invalid community URL: invalid/missing protocol://"sv); + CHECK(conf->last_error == "Invalid URL: invalid/missing protocol://"sv); CHECK_FALSE(convo_info_volatile_get_or_construct_community( conf, &og, @@ -492,9 +492,6 @@ TEST_CASE("Conversation pruning", "[config][conversations][pruning]") { auto pk = some_pubkey(x); return "05" + oxenc::to_hex(pk.begin(), pk.end()); }; - auto some_og_url = [&](unsigned char x) -> std::string { - return "https://example.com/r/room"s + std::to_string(x); - }; const auto now = std::chrono::system_clock::now() - 1ms; auto unix_timestamp = [&now](int days_ago) -> int64_t { return std::chrono::duration_cast( diff --git a/tests/test_config_user_groups.cpp b/tests/test_config_user_groups.cpp index 18145aa4..39d39f25 100644 --- a/tests/test_config_user_groups.cpp +++ b/tests/test_config_user_groups.cpp @@ -445,10 +445,6 @@ TEST_CASE("User Groups -- (non-legacy) groups", "[config][groups][new]") { constexpr auto definitely_real_id = "035000000000000000000000000000000000000000000000000000000000000000"sv; - int64_t now = std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(); - CHECK_FALSE(groups.get_group(definitely_real_id)); CHECK(groups.empty()); diff --git a/tests/test_config_userprofile.cpp b/tests/test_config_userprofile.cpp index 2efb60f9..6ae899dd 100644 --- a/tests/test_config_userprofile.cpp +++ b/tests/test_config_userprofile.cpp @@ -13,14 +13,6 @@ using namespace std::literals; using namespace oxenc::literals; -void log_msg(config_log_level lvl, const char* msg, void*) { - INFO((lvl == LOG_LEVEL_ERROR ? "ERROR" - : lvl == LOG_LEVEL_WARNING ? "Warning" - : lvl == LOG_LEVEL_INFO ? 
"Info" - : "debug") - << ": " << msg); -} - auto empty_extra_data = "1:+de"; TEST_CASE("UserProfile", "[config][user_profile]") { @@ -91,8 +83,6 @@ TEST_CASE("user profile C API", "[config][user_profile][c]") { rc = user_profile_init(&conf, ed_sk.data(), NULL, 0, err); REQUIRE(rc == 0); - config_set_logger(conf, log_msg, NULL); - // We don't need to push anything, since this is an empty config CHECK_FALSE(config_needs_push(conf)); // And we haven't changed anything so don't need to dump to db @@ -148,8 +138,8 @@ TEST_CASE("user profile C API", "[config][user_profile][c]") { CHECK(name == "Kallie"sv); pic = user_profile_get_pic(conf); - REQUIRE(pic.url); - REQUIRE(pic.key); + REQUIRE(pic.url != ""s); + REQUIRE(pic.key != to_usv(""s)); CHECK(pic.url == "http://example.org/omg-pic-123.bmp"sv); CHECK(ustring_view{pic.key, 32} == "secret78901234567890123456789012"_bytes); @@ -255,7 +245,6 @@ TEST_CASE("user profile C API", "[config][user_profile][c]") { // Start with an empty config, as above: config_object* conf2; REQUIRE(user_profile_init(&conf2, ed_sk.data(), NULL, 0, err) == 0); - config_set_logger(conf2, log_msg, NULL); CHECK_FALSE(config_needs_dump(conf2)); // Now imagine we just pulled down the encrypted string from the swarm; we merge it into conf2: @@ -368,15 +357,31 @@ TEST_CASE("user profile C API", "[config][user_profile][c]") { // Since only one of them set a profile pic there should be no conflict there: pic = user_profile_get_pic(conf); +#if defined(__APPLE__) || defined(__clang__) || defined(__llvm__) REQUIRE(pic.url); +#else + REQUIRE(pic.url != nullptr); +#endif CHECK(pic.url == "http://new.example.com/pic"sv); +#if defined(__APPLE__) || defined(__clang__) || defined(__llvm__) REQUIRE(pic.key); +#else + REQUIRE(pic.key != nullptr); +#endif CHECK(to_hex(ustring_view{pic.key, 32}) == "7177657274007975696f31323334353637383930313233343536373839303132"); pic = user_profile_get_pic(conf2); +#if defined(__APPLE__) || defined(__clang__) || defined(__llvm__) REQUIRE(pic.url); +#else + REQUIRE(pic.url != nullptr); +#endif CHECK(pic.url == "http://new.example.com/pic"sv); +#if defined(__APPLE__) || defined(__clang__) || defined(__llvm__) REQUIRE(pic.key); +#else + REQUIRE(pic.key != nullptr); +#endif CHECK(to_hex(ustring_view{pic.key, 32}) == "7177657274007975696f31323334353637383930313233343536373839303132"); diff --git a/tests/test_configdata.cpp b/tests/test_configdata.cpp index 0a2b1eed..433be62f 100644 --- a/tests/test_configdata.cpp +++ b/tests/test_configdata.cpp @@ -53,7 +53,7 @@ TEST_CASE("config data dict encoding", "[config][data][dict]") { d["D"] = config::dict{{"x", 1}, {"y", 2}}; d["d"] = config::dict{{"e", config::dict{{"f", config::dict{{"g", ""}}}}}}; - static_assert(oxenc::detail::is_bt_input_dict_container); + static_assert(oxenc::bt_input_dict_container); CHECK(oxenc::bt_serialize(d) == "d1:B1:x1:Dd1:xi1e1:yi2ee1:ai23e1:cli-3ei4e1:11:2e1:dd1:ed1:fd1:g0:eeee"); @@ -202,12 +202,12 @@ TEST_CASE("config message serialization", "[config][serialization]") { "e")); // clang-format on - const auto hash0 = "d65738bba88b0f3455cef20fe09a7b4b10f25f9db82be24a6ce1bd06da197526"_hex; + const std::string hash0{"d65738bba88b0f3455cef20fe09a7b4b10f25f9db82be24a6ce1bd06da197526"_hex}; CHECK(view_hex(m.hash()) == oxenc::to_hex(hash0)); auto m1 = m.increment(); m1.data().erase("foo"); - const auto hash1 = "5b30b4abf4cba71db25dbc0d977cc25df1d0a8a87cad7f561cdec2b8caf65f5e"_hex; + const std::string hash1{"5b30b4abf4cba71db25dbc0d977cc25df1d0a8a87cad7f561cdec2b8caf65f5e"_hex}; 
CHECK(view_hex(m1.hash()) == oxenc::to_hex(hash1)); auto m2 = m1.increment(); @@ -220,7 +220,7 @@ TEST_CASE("config message serialization", "[config][serialization]") { s(d(m2.data()["bar"])[""]).erase("b"); s(d(m2.data()["bar"])[""]).insert(42); // already present - const auto hash2 = "027552203cf669070d3ecbeecfa65c65497d59aa4da490e0f68f8131ce081320"_hex; + const std::string hash2{"027552203cf669070d3ecbeecfa65c65497d59aa4da490e0f68f8131ce081320"_hex}; CHECK(view_hex(m2.hash()) == oxenc::to_hex(hash2)); // clang-format off @@ -266,16 +266,18 @@ TEST_CASE("config message serialization", "[config][serialization]") { "3:foo" "0:" "e" "e")); + // clang-format on auto m5 = m2.increment().increment().increment(); - const auto hash3 = "b83871ea06587f9254cdf2b2af8daff19bd7fb550fb90d5f8f9f546464c08bc5"_hex; - const auto hash4 = "c30e2cfa7ec93c64a1ab6420c9bccfb63da8e4c2940ed6509ffb64f3f0131860"_hex; - const auto hash5 = "3234eb7da8cf4b79b9eec2a144247279d10f6f118184f82429a42c5996bea60c"_hex; + const std::string hash3{"b83871ea06587f9254cdf2b2af8daff19bd7fb550fb90d5f8f9f546464c08bc5"_hex}, + hash4{"c30e2cfa7ec93c64a1ab6420c9bccfb63da8e4c2940ed6509ffb64f3f0131860"_hex}, + hash5{"3234eb7da8cf4b79b9eec2a144247279d10f6f118184f82429a42c5996bea60c"_hex}; CHECK(view_hex(m2.increment().hash()) == oxenc::to_hex(hash3)); CHECK(view_hex(m2.increment().increment().hash()) == oxenc::to_hex(hash4)); CHECK(view_hex(m5.hash()) == oxenc::to_hex(hash5)); + // clang-format off CHECK(printable(m5.serialize()) == printable( "d" "1:#" "i15e" diff --git a/tests/test_curve25519.cpp b/tests/test_curve25519.cpp index 3acf6aa4..275e8c7b 100644 --- a/tests/test_curve25519.cpp +++ b/tests/test_curve25519.cpp @@ -12,7 +12,7 @@ TEST_CASE("X25519 key pair generation", "[curve25519][keypair]") { auto kp2 = session::curve25519::curve25519_key_pair(); CHECK(kp1.first.size() == 32); - CHECK(kp1.second.size() == 64); + CHECK(kp1.second.size() == 32); CHECK(kp1.first != kp2.first); CHECK(kp1.second != kp2.second); } diff --git a/tests/test_group_info.cpp b/tests/test_group_info.cpp index df70900a..9d9482db 100644 --- a/tests/test_group_info.cpp +++ b/tests/test_group_info.cpp @@ -13,8 +13,6 @@ using namespace std::literals; using namespace oxenc::literals; -static constexpr int64_t created_ts = 1680064059; - using namespace session::config; TEST_CASE("Group Info settings", "[config][groups][info]") { diff --git a/tests/test_group_keys.cpp b/tests/test_group_keys.cpp index 8a49e964..580772dc 100644 --- a/tests/test_group_keys.cpp +++ b/tests/test_group_keys.cpp @@ -23,8 +23,6 @@ using namespace std::literals; using namespace oxenc::literals; -static constexpr int64_t created_ts = 1680064059; - using namespace session::config; static std::array sk_from_seed(ustring_view seed) { @@ -198,7 +196,7 @@ TEST_CASE("Group Keys - C++ API", "[config][groups][keys][cpp]") { mem_configs.clear(); // add non-admin members, re-key, distribute - for (int i = 0; i < members.size(); ++i) { + for (size_t i = 0; i < members.size(); ++i) { auto m = admin1.members.get_or_construct(members[i].session_id); m.admin = false; m.name = "Member" + std::to_string(i); @@ -314,7 +312,7 @@ TEST_CASE("Group Keys - C++ API", "[config][groups][keys][cpp]") { std::unordered_set{{"keyhash1"s, "keyhash2"s, "keyhash3"s, "keyhash4"s}}); } - for (int i = 0; i < members.size(); i++) { + for (size_t i = 0; i < members.size(); i++) { auto& m = members[i]; bool found_key = m.keys.load_key_message( "keyhash4", new_keys_config2, get_timestamp_ms(), m.info, m.members); @@ -512,7 +510,7 
@@ TEST_CASE("Group Keys - C++ API", "[config][groups][keys][cpp]") { CHECK(a.keys.current_hashes() == std::unordered_set{{"keyhash6"s, "keyhash7"s}}); } - for (int i = 0; i < members.size(); i++) { + for (size_t i = 0; i < members.size(); i++) { auto& m = members[i]; CHECK(m.keys.load_key_message( "keyhash6", @@ -729,11 +727,13 @@ TEST_CASE("Group Keys - C API", "[config][groups][keys][c]") { get_timestamp_ms(), m.info, m.members)); - config_string_list* hashes; - REQUIRE_THROWS( - hashes = config_merge(m.info, merge_hash1, &merge_data1[0], &merge_size1[0], 1)); - REQUIRE_THROWS( - hashes = config_merge(m.members, merge_hash1, &merge_data1[1], &merge_size1[1], 1)); + [[maybe_unused]] config_string_list* hashes; + hashes = config_merge(m.info, merge_hash1, &merge_data1[0], &merge_size1[0], 1); + REQUIRE(m.info->last_error == "Cannot merge configs without any decryption keys"sv); + m.info->last_error = nullptr; + hashes = config_merge(m.members, merge_hash1, &merge_data1[1], &merge_size1[1], 1); + REQUIRE(m.members->last_error == "Cannot merge configs without any decryption keys"sv); + m.members->last_error = nullptr; REQUIRE(groups_members_size(m.members) == 0); } @@ -741,7 +741,7 @@ TEST_CASE("Group Keys - C API", "[config][groups][keys][c]") { free(new_info_config1); free(new_mem_config1); - for (int i = 0; i < members.size(); ++i) { + for (size_t i = 0; i < members.size(); ++i) { config_group_member new_mem; REQUIRE(groups_members_get_or_construct( @@ -881,7 +881,7 @@ TEST_CASE("Group Keys - swarm authentication", "[config][groups][keys][swarm]") member.info.id, to_usv(member.secret_key), auth_data)); // Try flipping a bit in each position of the auth data and make sure it fails to validate: - for (int i = 0; i < auth_data.size(); i++) { + for (size_t i = 0; i < auth_data.size(); i++) { for (int b = 0; b < 8; b++) { if (i == 35 && b == 7) // This is the sign bit of k, which can be flipped but gets // flipped back when dealing with the missing X->Ed conversion diff --git a/tests/test_group_members.cpp b/tests/test_group_members.cpp index 60a6e678..d5b566ff 100644 --- a/tests/test_group_members.cpp +++ b/tests/test_group_members.cpp @@ -13,8 +13,6 @@ using namespace std::literals; using namespace oxenc::literals; -static constexpr int64_t created_ts = 1680064059; - using namespace session::config; constexpr bool is_prime100(int i) { diff --git a/tests/test_logging.cpp b/tests/test_logging.cpp new file mode 100644 index 00000000..ba7b1638 --- /dev/null +++ b/tests/test_logging.cpp @@ -0,0 +1,116 @@ +#include + +#include +#include +#include +#include +#include +#include + +using namespace session; +using namespace oxen; +using namespace oxen::log::literals; + +std::regex timestamp_re{R"(\[\d{4}-\d\d-\d\d \d\d:\d\d:\d\d\] \[\+[\d.hms]+\])"}; +// Clears timestamps out of a log statement for testing reproducibility +std::string fixup_log(std::string_view log) { + std::string fixed; + std::regex_replace( + std::back_inserter(fixed), + log.begin(), + log.end(), + timestamp_re, + "[] []", + std::regex_constants::format_first_only); + return fixed; +} + +std::vector simple_logs; +std::vector full_logs; // "cat|level|msg" + +TEST_CASE("Logging callbacks", "[logging]") { + oxen::log::clear_sinks(); + simple_logs.clear(); + full_logs.clear(); + session::logger_reset_level(LogLevel::info); + + SECTION("C++ lambdas") { + session::add_logger([&](std::string_view msg) { simple_logs.emplace_back(msg); }); + session::add_logger([&](auto msg, auto cat, auto level) { + 
full_logs.push_back("{}|{}|{}"_format(cat, level.to_string(), msg)); + }); + } + SECTION("C function pointers") { + session_add_logger_simple( + [](const char* msg, size_t msglen) { simple_logs.emplace_back(msg, msglen); }); + session_add_logger_full([](const char* msg, + size_t msglen, + const char* cat, + size_t cat_len, + LOG_LEVEL level) { + full_logs.push_back("{}|{}|{}"_format( + std::string{cat, cat_len}, + oxen::log::to_string(static_cast(level)), + std::string{msg, msglen})); + }); + } + + log::critical(log::Cat("test.a"), "abc {}", 21 * 2); +#if defined(__APPLE__) && defined(__clang__) && (__clang_major__ <= 15) +#else + int line0 = __LINE__ - 3; +#endif + log::info(log::Cat("test.b"), "hi"); +#if defined(__APPLE__) && defined(__clang__) && (__clang_major__ <= 15) +#else + int line1 = __LINE__ - 3; +#endif + + oxen::log::clear_sinks(); + + REQUIRE(simple_logs.size() == 2); + REQUIRE(full_logs.size() == 2); + +#if defined(__APPLE__) && defined(__clang__) && (__clang_major__ <= 15) + CHECK(fixup_log(simple_logs[0]) == + "[] [] [test.a:critical|log.hpp:177] abc 42\n"); + CHECK(fixup_log(simple_logs[1]) == "[] [] [test.b:info|log.hpp:98] hi\n"); + CHECK(fixup_log(full_logs[0]) == + "test.a|critical|[] [] [test.a:critical|log.hpp:177] abc 42\n"); + CHECK(fixup_log(full_logs[1]) == + "test.b|info|[] [] [test.b:info|log.hpp:98] hi\n"); +#else + CHECK(fixup_log(simple_logs[0]) == + "[] [] [test.a:critical|tests/test_logging.cpp:{}] abc 42\n"_format( + line0)); + CHECK(fixup_log(simple_logs[1]) == + "[] [] [test.b:info|tests/test_logging.cpp:{}] hi\n"_format(line1)); + CHECK(fixup_log(full_logs[0]) == + "test.a|critical|[] [] [test.a:critical|tests/test_logging.cpp:{}] abc 42\n"_format( + line0)); + CHECK(fixup_log(full_logs[1]) == + "test.b|info|[] [] [test.b:info|tests/test_logging.cpp:{}] hi\n"_format( + line1)); +#endif +} + +TEST_CASE("Logging callbacks with quic::Network", "[logging][network]") { + oxen::log::clear_sinks(); + simple_logs.clear(); + session::logger_set_level("quic", LogLevel::debug); + + session::add_logger([&](std::string_view msg) { simple_logs.emplace_back(msg); }); + + { quic::Network net; } + + oxen::log::clear_sinks(); + + CHECK(simple_logs.size() >= 2); + // CHECK(simple_logs == std::vector{"uncomment me to fail showing all log lines"}); +#if defined(__APPLE__) && defined(__clang__) && defined(RELEASE_BUILD) + CHECK(simple_logs.front().find("Started libevent") != std::string::npos); +#else + CHECK(simple_logs.front().find("Starting libevent") != std::string::npos); +#endif + CHECK(simple_logs.back().find("Loop shutdown complete") != std::string::npos); +} diff --git a/tests/test_multi_encrypt.cpp b/tests/test_multi_encrypt.cpp index 64e72683..c84cb1dc 100644 --- a/tests/test_multi_encrypt.cpp +++ b/tests/test_multi_encrypt.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -35,7 +36,7 @@ TEST_CASE("Multi-recipient encryption", "[encrypt][multi]") { "0123456789abcdef333333333333333300000000000000000000000000000000"_hexbytes}; std::array x_keys; - for (int i = 0; i < seeds.size(); i++) + for (size_t i = 0; i < seeds.size(); i++) x_keys[i] = to_x_keys(seeds[i]); CHECK(oxenc::to_hex(to_usv(x_keys[0].second)) == @@ -199,7 +200,7 @@ TEST_CASE("Multi-recipient encryption, simpler interface", "[encrypt][multi][sim "0123456789abcdef333333333333333300000000000000000000000000000000"_hexbytes}; std::array x_keys; - for (int i = 0; i < seeds.size(); i++) + for (size_t i = 0; i < seeds.size(); i++) x_keys[i] = to_x_keys(seeds[i]); 
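Stepping back briefly: a minimal sketch of how a client would wire up the logging hooks exercised by test_logging.cpp above (the header name is assumed; the calls themselves mirror what the test uses):

    #include <iostream>
    #include <string_view>
    #include "session/logging.hpp"  // assumed header

    void hook_libsession_logging() {
        using namespace session;
        logger_reset_level(LogLevel::info);          // default level for all categories
        logger_set_level("quic", LogLevel::debug);   // per-category override, as in the test
        add_logger([](std::string_view msg) {        // "simple" callback: pre-formatted line
            std::cout << msg;
        });
    }
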
CHECK(oxenc::to_hex(to_usv(x_keys[0].second)) == @@ -292,12 +293,15 @@ TEST_CASE("Multi-recipient encryption, simpler interface", "[encrypt][multi][sim "test suite", nonce); - CHECK(printable(encrypted) == - printable( - "d1:#24:" + "32ab4bb45d6df5cc14e1c330fb1a8b68ea3826a8c2213a49"_hex + "1:el" + - "21:" + "e64937e5ea201b84f4e88a976dad900d91caaf6a17"_hex + - "21:" + "bcb642c49c6da03f70cdaab2ed6666721318afd631"_hex + - "21:" + "1ecee2215d226817edfdb097f05037eb799309103a"_hex + "ee")); + CHECK(printable(encrypted) == printable(fmt::format( + "d" + "1:#24:{}" + "1:el21:{}21:{}21:{}e" + "e", + "32ab4bb45d6df5cc14e1c330fb1a8b68ea3826a8c2213a49"_hex, + "e64937e5ea201b84f4e88a976dad900d91caaf6a17"_hex, + "bcb642c49c6da03f70cdaab2ed6666721318afd631"_hex, + "1ecee2215d226817edfdb097f05037eb799309103a"_hex))); m1 = session::decrypt_for_multiple_simple( encrypted, diff --git a/tests/test_network.cpp b/tests/test_network.cpp new file mode 100644 index 00000000..f038fb4a --- /dev/null +++ b/tests/test_network.cpp @@ -0,0 +1,1467 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "utils.hpp" + +using namespace session; +using namespace session::onionreq; +using namespace session::network; + +namespace { +struct Result { + bool success; + bool timeout; + int16_t status_code; + std::vector> headers; + std::optional response; +}; + +service_node test_node(const ustring ed_pk, const uint16_t index, const bool unique_ip = true) { + return service_node{ + ed_pk, + {2, 8, 0}, + INVALID_SWARM_ID, + (unique_ip ? fmt::format("0.0.0.{}", index) : "1.1.1.1"), + index}; +} + +std::optional node_for_destination(network_destination destination) { + if (auto* dest = std::get_if(&destination)) + return *dest; + + return std::nullopt; +} + +} // namespace + +namespace session::network { +class TestNetwork : public Network { + public: + std::unordered_map call_counts; + std::vector calls_to_ignore; + std::chrono::milliseconds retry_delay_value = 0ms; + std::optional> find_valid_path_response; + std::optional last_request_info; + + TestNetwork( + std::optional cache_path, + bool use_testnet, + bool single_path_mode, + bool pre_build_paths) : + Network{cache_path, use_testnet, single_path_mode, pre_build_paths} { + paths_changed = [this](std::vector>) { + call_counts["paths_changed"]++; + }; + } + + void set_suspended(bool suspended_) { suspended = suspended_; } + + bool get_suspended() { return suspended; } + + ConnectionStatus get_status() { return status; } + + void set_snode_cache(std::vector cache) { + // Need to set the `last_snode_cache_update` to `10s` ago because otherwise it'll be + // considered invalid when checking the cache validity + snode_cache = cache; + last_snode_cache_update = (std::chrono::system_clock::now() - 10s); + } + + void set_unused_connections(std::deque unused_connections_) { + unused_connections = unused_connections_; + } + + void set_in_progress_connections( + std::unordered_map in_progress_connections_) { + in_progress_connections = in_progress_connections_; + } + + void add_path(PathType path_type, std::vector nodes) { + paths[path_type].emplace_back( + onion_path{"Test", {nodes[0], nullptr, nullptr, nullptr}, nodes, 0}); + } + + void set_paths(PathType path_type, std::vector paths_) { + paths[path_type] = paths_; + } + + std::vector get_paths(PathType path_type) { return paths[path_type]; } + + void set_all_swarms(std::vector>> all_swarms_) { + all_swarms = all_swarms_; + } + + void set_swarm( + session::onionreq::x25519_pubkey swarm_pubkey, 
+ swarm_id_t swarm_id, + std::vector swarm) { + swarm_cache[swarm_pubkey.hex()] = {swarm_id, swarm}; + } + + std::pair> get_cached_swarm( + session::onionreq::x25519_pubkey swarm_pubkey) { + return swarm_cache[swarm_pubkey.hex()]; + } + + swarm_id_t get_swarm_id(std::string swarm_pubkey_hex) { + if (swarm_pubkey_hex.size() == 66) + swarm_pubkey_hex = swarm_pubkey_hex.substr(2); + + auto pk = x25519_pubkey::from_hex(swarm_pubkey_hex); + std::promise prom; + get_swarm(pk, [&prom](swarm_id_t result, std::vector) { + prom.set_value(result); + }); + return prom.get_future().get(); + } + + void set_failure_count(service_node node, uint8_t failure_count) { + snode_failure_counts[node.to_string()] = failure_count; + } + + uint8_t get_failure_count(service_node node) { + return snode_failure_counts.try_emplace(node.to_string(), 0).first->second; + } + + uint8_t get_failure_count(PathType path_type, onion_path path) { + auto current_paths = paths[path_type]; + auto target_path = std::find_if( + current_paths.begin(), current_paths.end(), [&path](const auto& path_it) { + return path_it.nodes[0] == path.nodes[0]; + }); + + if (target_path != current_paths.end()) + return target_path->failure_count; + + return 0; + } + + void set_path_build_queue(std::deque path_build_queue_) { + path_build_queue = path_build_queue_; + } + + std::deque get_path_build_queue() { return path_build_queue; } + + void set_path_build_failures(int path_build_failures_) { + path_build_failures = path_build_failures_; + } + + int get_path_build_failures() { return path_build_failures; } + + void set_unused_nodes(std::vector unused_nodes_) { unused_nodes = unused_nodes_; } + + std::vector get_unused_nodes() { return Network::get_unused_nodes(); } + + std::vector get_unused_nodes_value() { return unused_nodes; } + + void add_pending_request(PathType path_type, request_info info) { + request_queue[path_type].emplace_back( + std::move(info), + [](bool, + bool, + int16_t, + std::vector>, + std::optional) {}); + } + + // Overridden Functions + + std::chrono::milliseconds retry_delay(int, std::chrono::milliseconds) override { + return retry_delay_value; + } + + void update_disk_cache_throttled(bool force_immediate_write) override { + const auto func_name = "update_disk_cache_throttled"; + + if (check_should_ignore_and_log_call(func_name)) + return; + + Network::update_disk_cache_throttled(force_immediate_write); + } + + void establish_and_store_connection(std::string request_id) override { + const auto func_name = "establish_and_store_connection"; + + if (check_should_ignore_and_log_call(func_name)) + return; + + Network::establish_and_store_connection(request_id); + } + + void refresh_snode_cache(std::optional existing_request_id) override { + const auto func_name = "refresh_snode_cache"; + + if (check_should_ignore_and_log_call(func_name)) + return; + + Network::refresh_snode_cache(existing_request_id); + } + + void build_path(std::string path_id, PathType path_type) override { + const auto func_name = "build_path"; + + if (check_should_ignore_and_log_call(func_name)) + return; + + Network::build_path(path_id, path_type); + } + + std::optional find_valid_path( + request_info info, std::vector paths) override { + const auto func_name = "find_valid_path"; + + if (check_should_ignore_and_log_call(func_name)) + return std::nullopt; + + if (find_valid_path_response) + return *find_valid_path_response; + + return Network::find_valid_path(info, paths); + } + + void check_request_queue_timeouts(std::optional request_timeout_id) override { + 
const auto func_name = "check_request_queue_timeouts"; + + if (check_should_ignore_and_log_call(func_name)) + return; + + Network::check_request_queue_timeouts(request_timeout_id); + } + + void _send_onion_request( + request_info info, network_response_callback_t handle_response) override { + const auto func_name = "_send_onion_request"; + last_request_info = info; + + if (check_should_ignore_and_log_call(func_name)) + return; + + Network::_send_onion_request(std::move(info), std::move(handle_response)); + } + + // Exposing Private Functions + + void establish_connection( + std::string request_id, + service_node target, + std::optional timeout, + std::function error)> callback) { + Network::establish_connection(request_id, target, timeout, std::move(callback)); + } + + void build_path_if_needed(PathType path_type, bool found_valid_path) override { + return Network::build_path_if_needed(path_type, found_valid_path); + } + + void send_request( + request_info info, connection_info conn, network_response_callback_t handle_response) { + Network::send_request(info, conn, std::move(handle_response)); + } + + void handle_errors( + request_info info, + connection_info conn_info, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response, + std::optional handle_response) override { + call_counts["handle_errors"]++; + Network::handle_errors( + info, + conn_info, + timeout, + status_code, + headers, + response, + std::move(handle_response)); + } + + // Mocking Functions + + template + void ignore_calls_to(Strings&&... __args) { + (calls_to_ignore.emplace_back(std::forward(__args)), ...); + } + + bool check_should_ignore_and_log_call(std::string func_name) { + call_counts[func_name]++; + + return std::find(calls_to_ignore.begin(), calls_to_ignore.end(), func_name) != + calls_to_ignore.end(); + } + + void reset_calls() { return call_counts.clear(); } + bool called(std::string func_name, int times = 1) { return (call_counts[func_name] >= times); } + + bool did_not_call(std::string func_name) { return !call_counts.contains(func_name); } +}; +} // namespace session::network + +TEST_CASE("Network Url Parsing", "[network][parse_url]") { + auto [proto1, host1, port1, path1] = parse_url("HTTPS://example.com/test"); + auto [proto2, host2, port2, path2] = parse_url("http://example2.com:1234/test/123456"); + auto [proto3, host3, port3, path3] = parse_url("https://example3.com"); + auto [proto4, host4, port4, path4] = parse_url("https://example4.com/test?value=test"); + + CHECK(proto1 == "https://"); + CHECK(proto2 == "http://"); + CHECK(proto3 == "https://"); + CHECK(proto4 == "https://"); + CHECK(host1 == "example.com"); + CHECK(host2 == "example2.com"); + CHECK(host3 == "example3.com"); + CHECK(host4 == "example4.com"); + CHECK(port1.value_or(9999) == 9999); + CHECK(port2.value_or(9999) == 1234); + CHECK(port3.value_or(9999) == 9999); + CHECK(port4.value_or(9999) == 9999); + CHECK(path1.value_or("NULL") == "/test"); + CHECK(path2.value_or("NULL") == "/test/123456"); + CHECK(path3.value_or("NULL") == "NULL"); + CHECK(path4.value_or("NULL") == "/test?value=test"); +} + +TEST_CASE("Network error handling", "[network]") { + auto ed_pk = "4cb76fdc6d32278e3f83dbf608360ecc6b65727934b85d2fb86862ff98c46ab7"_hexbytes; + auto ed_pk2 = "5ea34e72bb044654a6a23675690ef5ffaaf1656b02f93fb76655f9cbdbe89876"_hexbytes; + auto ed_sk = + "4cb76fdc6d32278e3f83dbf608360ecc6b65727934b85d2fb86862ff98c46ab78862834829a" + "87e0afadfed763fa8785e893dbde7f2c001ff1071aa55005c347f"_hexbytes; + auto x_pk_hex = 
"d2ad010eeb72d72e561d9de7bd7b6989af77dcabffa03a5111a6c859ae5c3a72"; + auto target = test_node(ed_pk, 0); + auto target2 = test_node(ed_pk2, 1); + auto target3 = test_node(ed_pk2, 2); + auto target4 = test_node(ed_pk2, 3); + auto path = + onion_path{"Test", {target, nullptr, nullptr, nullptr}, {target, target2, target3}, 0}; + auto mock_request = request_info{ + "AAAA", + target, + "test", + std::nullopt, + std::nullopt, + std::nullopt, + PathType::standard, + 0ms, + std::nullopt, + std::chrono::system_clock::now(), + std::nullopt, + true}; + Result result; + std::optional network; + + // Check the handling of the codes which make no changes + auto codes_with_no_changes = {400, 404, 406, 425}; + + for (auto code : codes_with_no_changes) { + network.emplace(std::nullopt, true, true, false); + network->set_suspended(true); // Make no requests in this test + network->ignore_calls_to("_send_onion_request", "update_disk_cache_throttled"); + network->set_paths(PathType::standard, {path}); + network->handle_errors( + mock_request, + {target, nullptr, nullptr, nullptr}, + false, + code, + {}, + std::nullopt, + [&result]( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + result = {success, timeout, status_code, headers, response}; + }); + + CHECK_FALSE(result.success); + CHECK_FALSE(result.timeout); + CHECK(result.status_code == code); + CHECK_FALSE(result.response.has_value()); + CHECK(network->get_failure_count(target) == 0); + CHECK(network->get_failure_count(target2) == 0); + CHECK(network->get_failure_count(target3) == 0); + CHECK(network->get_failure_count(PathType::standard, path) == 0); + } + + // Check general error handling (first failure) + network.emplace(std::nullopt, true, true, false); + network->set_suspended(true); // Make no requests in this test + network->ignore_calls_to("_send_onion_request", "update_disk_cache_throttled"); + network->set_paths(PathType::standard, {path}); + network->handle_errors( + mock_request, + {target, nullptr, nullptr, nullptr}, + false, + 500, + {}, + std::nullopt, + [&result]( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + result = {success, timeout, status_code, headers, response}; + }); + CHECK_FALSE(result.success); + CHECK_FALSE(result.timeout); + CHECK(result.status_code == 500); + CHECK_FALSE(result.response.has_value()); + CHECK(network->get_failure_count(target) == 0); + CHECK(network->get_failure_count(target2) == 0); + CHECK(network->get_failure_count(target3) == 0); + CHECK(network->get_failure_count(PathType::standard, path) == 1); + + // // Check general error handling with no response (too many path failures) + path = onion_path{"Test", {target, nullptr, nullptr, nullptr}, {target, target2, target3}, 9}; + network.emplace(std::nullopt, true, true, false); + network->set_suspended(true); // Make no requests in this test + network->ignore_calls_to("_send_onion_request", "update_disk_cache_throttled"); + network->set_paths(PathType::standard, {path}); + network->handle_errors( + mock_request, + {target, nullptr, nullptr, nullptr}, + false, + 500, + {}, + std::nullopt, + [&result]( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + result = {success, timeout, status_code, headers, response}; + }); + + CHECK_FALSE(result.success); + CHECK_FALSE(result.timeout); + CHECK(result.status_code == 500); + CHECK_FALSE(result.response.has_value()); + 
CHECK(network->get_failure_count(target) == 3); // Guard node dropped + CHECK(network->get_failure_count(target2) == 1); // Other nodes incremented + CHECK(network->get_failure_count(target3) == 1); // Other nodes incremented + CHECK(network->get_failure_count(PathType::standard, path) == 0); // Path dropped and reset + + // // Check general error handling with a path and specific node failure + path = onion_path{"Test", {target, nullptr, nullptr, nullptr}, {target, target2, target3}, 0}; + auto response = std::string{"Next node not found: "} + ed25519_pubkey::from_bytes(ed_pk2).hex(); + network.emplace(std::nullopt, true, true, false); + network->set_suspended(true); // Make no requests in this test + network->ignore_calls_to("_send_onion_request", "update_disk_cache_throttled"); + network->set_snode_cache({target, target2, target3, target4}); + network->set_unused_nodes({target4}); + network->set_paths(PathType::standard, {path}); + network->handle_errors( + mock_request, + {target, nullptr, nullptr, nullptr}, + false, + 500, + {}, + response, + [&result]( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + result = {success, timeout, status_code, headers, response}; + }); + + CHECK_FALSE(result.success); + CHECK_FALSE(result.timeout); + CHECK(result.status_code == 500); + CHECK(result.response == response); + CHECK(network->get_failure_count(target) == 0); + CHECK(network->get_failure_count(target2) == 3); // Node will have been dropped + CHECK(network->get_failure_count(target3) == 0); + CHECK(network->get_paths(PathType::standard).front().nodes[1] != target2); + CHECK(network->get_failure_count(PathType::standard, path) == + 1); // Incremented because conn_info is invalid + + // Check a 421 with no swarm data throws (no good way to handle this case) + network.emplace(std::nullopt, true, true, false); + network->set_suspended(true); // Make no requests in this test + network->ignore_calls_to("_send_onion_request", "update_disk_cache_throttled"); + network->set_paths(PathType::standard, {path}); + network->handle_errors( + mock_request, + {target, nullptr, nullptr, nullptr}, + false, + 421, + {}, + std::nullopt, + [&result]( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + result = {success, timeout, status_code, headers, response}; + }); + CHECK_FALSE(result.success); + CHECK_FALSE(result.timeout); + CHECK(result.status_code == 421); + CHECK(network->get_failure_count(target) == 0); + CHECK(network->get_failure_count(target2) == 0); + CHECK(network->get_failure_count(target3) == 0); + CHECK(network->get_failure_count(PathType::standard, path) == 1); + + // Check a non redirect 421 triggers a retry using a different node + auto mock_request2 = request_info{ + "BBBB", + target, + "test", + std::nullopt, + std::nullopt, + x25519_pubkey::from_hex(x_pk_hex), + PathType::standard, + 0ms, + std::nullopt, + std::chrono::system_clock::now(), + std::nullopt, + true}; + network.emplace(std::nullopt, true, true, false); + network->set_suspended(true); // Make no requests in this test + network->ignore_calls_to("_send_onion_request", "update_disk_cache_throttled"); + network->set_swarm(x25519_pubkey::from_hex(x_pk_hex), 1, {target, target2, target3}); + network->set_paths(PathType::standard, {path}); + network->reset_calls(); + network->handle_errors( + mock_request2, + {target, nullptr, nullptr, nullptr}, + false, + 421, + {}, + std::nullopt, + [](bool, + bool, + int16_t, + 
std::vector>, + std::optional) {}); + CHECK(EVENTUALLY(10ms, network->called("_send_onion_request"))); + REQUIRE(network->last_request_info.has_value()); + CHECK(node_for_destination(network->last_request_info->destination) != + node_for_destination(mock_request2.destination)); + + // Check that when a retry request of a 421 receives it's own 421 that it tries + // to update the snode cache + auto mock_request3 = request_info{ + "BBBB", + target, + "test", + std::nullopt, + std::nullopt, + x25519_pubkey::from_hex(x_pk_hex), + PathType::standard, + 0ms, + std::nullopt, + std::chrono::system_clock::now(), + request_info::RetryReason::redirect, + true}; + network.emplace(std::nullopt, true, true, false); + network->set_suspended(true); // Make no requests in this test + network->ignore_calls_to( + "_send_onion_request", "update_disk_cache_throttled", "refresh_snode_cache"); + network->set_paths(PathType::standard, {path}); + network->handle_errors( + mock_request3, + {target, nullptr, nullptr, nullptr}, + false, + 421, + {}, + std::nullopt, + [](bool, + bool, + int16_t, + std::vector>, + std::optional) {}); + CHECK(EVENTUALLY(10ms, network->called("refresh_snode_cache"))); + + // Check when the retry after refreshing the snode cache due to a 421 receives it's own 421 it + // is handled like any other error + auto mock_request4 = request_info{ + "BBBB", + target, + "test", + std::nullopt, + std::nullopt, + x25519_pubkey::from_hex(x_pk_hex), + PathType::standard, + 0ms, + std::nullopt, + std::chrono::system_clock::now(), + request_info::RetryReason::redirect_swarm_refresh, + true}; + network.emplace(std::nullopt, true, true, false); + network->set_suspended(true); // Make no requests in this test + network->ignore_calls_to("_send_onion_request", "update_disk_cache_throttled"); + network->set_paths(PathType::standard, {path}); + network->handle_errors( + mock_request4, + {target, nullptr, nullptr, nullptr}, + false, + 421, + {}, + std::nullopt, + [&result]( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + result = {success, timeout, status_code, headers, response}; + }); + CHECK_FALSE(result.success); + CHECK_FALSE(result.timeout); + CHECK(result.status_code == 421); + CHECK(network->get_failure_count(target) == 0); + CHECK(network->get_failure_count(target2) == 0); + CHECK(network->get_failure_count(target3) == 0); + CHECK(network->get_failure_count(PathType::standard, path) == 1); + + // Check a timeout with a sever destination doesn't impact the failure counts + auto server = ServerDestination{ + "https", + "open.getsession.org", + "/rooms", + x25519_pubkey::from_hex("a03c383cf63c3c4efe67acc52112a6dd734b3a946b9545f488aaa93da79912" + "38"), + 443, + std::nullopt, + "GET"}; + auto mock_request5 = request_info{ + "CCCC", + server, + "test", + std::nullopt, + std::nullopt, + x25519_pubkey::from_hex(x_pk_hex), + PathType::standard, + 0ms, + std::nullopt, + std::chrono::system_clock::now(), + std::nullopt, + false}; + network.emplace(std::nullopt, true, true, false); + network->set_suspended(true); // Make no requests in this test + network->ignore_calls_to("_send_onion_request", "update_disk_cache_throttled"); + network->handle_errors( + mock_request5, + {target, nullptr, nullptr, nullptr}, + true, + -1, + {}, + "Test", + [&result]( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + result = {success, timeout, status_code, headers, response}; + }); + CHECK_FALSE(result.success); + 
CHECK(result.timeout); + CHECK(result.status_code == -1); + CHECK(network->get_failure_count(target) == 0); + CHECK(network->get_failure_count(target2) == 0); + CHECK(network->get_failure_count(target3) == 0); + CHECK(network->get_failure_count(PathType::standard, path) == 0); + + // Check a server response starting with '500 Internal Server Error' is reported as a `500` + // error and doesn't affect the failure count + network.emplace(std::nullopt, true, true, false); + network->set_suspended(true); // Make no requests in this test + network->ignore_calls_to("_send_onion_request", "update_disk_cache_throttled"); + network->handle_errors( + mock_request4, + {target, nullptr, nullptr, nullptr}, + false, + -1, + {}, + "500 Internal Server Error", + [&result]( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + result = {success, timeout, status_code, headers, response}; + }); + CHECK_FALSE(result.success); + CHECK_FALSE(result.timeout); + CHECK(result.status_code == 500); + CHECK(network->get_failure_count(target) == 0); + CHECK(network->get_failure_count(target2) == 0); + CHECK(network->get_failure_count(target3) == 0); + CHECK(network->get_failure_count(PathType::standard, path) == 0); +} + +TEST_CASE("Network Path Building", "[network][get_unused_nodes]") { + const auto ed_pk = "4cb76fdc6d32278e3f83dbf608360ecc6b65727934b85d2fb86862ff98c46ab7"_hexbytes; + std::optional network; + std::vector snode_cache; + std::vector unused_nodes; + for (uint16_t i = 0; i < 12; ++i) + snode_cache.emplace_back(test_node(ed_pk, i)); + auto invalid_info = connection_info{snode_cache[0], nullptr, nullptr, nullptr}; + auto path = + onion_path{"Test", invalid_info, {snode_cache[0], snode_cache[1], snode_cache[2]}, 0}; + + // Should shuffle the result + network.emplace(std::nullopt, true, false, false); + network->set_snode_cache(snode_cache); + CHECK(network->get_unused_nodes() != network->get_unused_nodes()); + + // Should contain the entire snode cache initially + network.emplace(std::nullopt, true, false, false); + network->set_snode_cache(snode_cache); + unused_nodes = network->get_unused_nodes(); + std::stable_sort(unused_nodes.begin(), unused_nodes.end()); + CHECK(unused_nodes == snode_cache); + + // Should exclude nodes used in paths + network.emplace(std::nullopt, true, false, false); + network->set_snode_cache(snode_cache); + network->set_paths(PathType::standard, {path}); + unused_nodes = network->get_unused_nodes(); + std::stable_sort(unused_nodes.begin(), unused_nodes.end()); + CHECK(unused_nodes == std::vector{snode_cache.begin() + 3, snode_cache.end()}); + + // Should exclude nodes in unused connections + network.emplace(std::nullopt, true, false, false); + network->set_snode_cache(snode_cache); + network->set_unused_connections({invalid_info}); + unused_nodes = network->get_unused_nodes(); + std::stable_sort(unused_nodes.begin(), unused_nodes.end()); + CHECK(unused_nodes == std::vector{snode_cache.begin() + 1, snode_cache.end()}); + + // Should exclude nodes in in-progress connections + network.emplace(std::nullopt, true, false, false); + network->set_snode_cache(snode_cache); + network->set_in_progress_connections({{"Test", snode_cache.front()}}); + unused_nodes = network->get_unused_nodes(); + std::stable_sort(unused_nodes.begin(), unused_nodes.end()); + CHECK(unused_nodes == std::vector{snode_cache.begin() + 1, snode_cache.end()}); + + // Should exclude nodes destinations in pending requests + network.emplace(std::nullopt, true, false, 
false); + network->set_snode_cache(snode_cache); + network->add_pending_request( + PathType::standard, + request_info::make( + snode_cache.front(), + std::nullopt, + std::nullopt, + 1s, + std::nullopt, + PathType::standard)); + unused_nodes = network->get_unused_nodes(); + std::stable_sort(unused_nodes.begin(), unused_nodes.end()); + CHECK(unused_nodes == std::vector{snode_cache.begin() + 1, snode_cache.end()}); + + // Should exclude nodes which have passed the failure threshold + network.emplace(std::nullopt, true, false, false); + network->set_snode_cache(snode_cache); + network->set_failure_count(snode_cache.front(), 10); + unused_nodes = network->get_unused_nodes(); + std::stable_sort(unused_nodes.begin(), unused_nodes.end()); + CHECK(unused_nodes == std::vector{snode_cache.begin() + 1, snode_cache.end()}); + + // Should exclude nodes which have the same IP if one was excluded + std::vector same_ip_snode_cache; + auto unique_node = service_node{ed_pk, {2, 8, 0}, INVALID_SWARM_ID, "0.0.0.20", uint16_t{20}}; + for (uint16_t i = 0; i < 11; ++i) + same_ip_snode_cache.emplace_back(test_node(ed_pk, i, false)); + same_ip_snode_cache.emplace_back(unique_node); + network.emplace(std::nullopt, true, false, false); + network->set_snode_cache(same_ip_snode_cache); + network->set_failure_count(same_ip_snode_cache.front(), 10); + unused_nodes = network->get_unused_nodes(); + REQUIRE(unused_nodes.size() == 1); + CHECK(unused_nodes.front() == unique_node); +} + +TEST_CASE("Network Path Building", "[network][build_path]") { + const auto ed_pk = "4cb76fdc6d32278e3f83dbf608360ecc6b65727934b85d2fb86862ff98c46ab7"_hexbytes; + std::optional network; + std::vector snode_cache; + for (uint16_t i = 0; i < 12; ++i) + snode_cache.emplace_back(test_node(ed_pk, i)); + auto invalid_info = connection_info{snode_cache[0], nullptr, nullptr, nullptr}; + + // Nothing should happen if the network is suspended + network.emplace(std::nullopt, true, false, false); + network->set_suspended(true); + network->build_path("Test1", PathType::standard); + CHECK(ALWAYS(10ms, network->did_not_call("establish_and_store_connection"))); + + // If there are no unused connections it puts the path build in the queue and calls + // establish_and_store_connection + network.emplace(std::nullopt, true, false, false); + network->ignore_calls_to("establish_and_store_connection"); + network->build_path("Test1", PathType::standard); + CHECK(network->get_path_build_queue() == std::deque{PathType::standard}); + CHECK(EVENTUALLY(10ms, network->called("establish_and_store_connection"))); + + // If the unused nodes are empty it refreshes them + network.emplace(std::nullopt, true, false, false); + network->set_snode_cache(snode_cache); + network->set_unused_connections({invalid_info}); + network->set_in_progress_connections({{"TestInProgress", snode_cache.front()}}); + network->build_path("Test1", PathType::standard); + CHECK(network->get_unused_nodes_value().size() == snode_cache.size() - 3); + CHECK(network->get_path_build_queue().empty()); + + // It should exclude nodes that are already in existing paths + network.emplace(std::nullopt, true, false, false); + network->set_snode_cache(snode_cache); + network->set_unused_connections({invalid_info}); + network->set_in_progress_connections({{"TestInProgress", snode_cache.front()}}); + network->add_path(PathType::standard, {snode_cache.begin() + 1, snode_cache.begin() + 1 + 3}); + network->build_path("Test1", PathType::standard); + CHECK(network->get_unused_nodes_value().size() == (snode_cache.size() - 3 
- 3)); + CHECK(network->get_path_build_queue().empty()); + + // If there aren't enough unused nodes it resets the failure count, re-queues the path build and + // triggers a snode cache refresh + network.emplace(std::nullopt, true, false, false); + network->ignore_calls_to("refresh_snode_cache"); + network->set_snode_cache(snode_cache); + network->set_unused_connections({invalid_info}); + network->set_path_build_failures(10); + network->add_path(PathType::standard, snode_cache); + network->build_path("Test1", PathType::standard); + CHECK(network->get_path_build_failures() == 0); + CHECK(network->get_path_build_queue() == std::deque{PathType::standard}); + CHECK(EVENTUALLY(10ms, network->called("refresh_snode_cache"))); + + // If it can't build a path after excluding nodes with the same IP it increments the + // failure count and re-tries the path build after a small delay + network.emplace(std::nullopt, true, false, false); + network->set_snode_cache(snode_cache); + network->set_unused_connections({invalid_info}); + network->set_unused_nodes(std::vector{ + snode_cache[0], snode_cache[0], snode_cache[0], snode_cache[0]}); + network->build_path("Test1", PathType::standard); + network->ignore_calls_to("build_path"); // Ignore the 2nd loop + CHECK(network->get_path_build_failures() == 1); + CHECK(network->get_path_build_queue().empty()); + CHECK(EVENTUALLY(10ms, network->called("build_path", 2))); + + // It stores a successful non-standard path and kicks of queued requests but doesn't update the + // status or call the 'paths_changed' hook + network.emplace(std::nullopt, true, false, false); + network->find_valid_path_response = + onion_path{"Test", invalid_info, {snode_cache.begin(), snode_cache.begin() + 3}, 0}; + network->ignore_calls_to("_send_onion_request"); + network->set_snode_cache(snode_cache); + network->set_unused_connections({invalid_info}); + network->add_pending_request( + PathType::download, + request_info::make( + snode_cache.back(), + std::nullopt, + std::nullopt, + 1s, + std::nullopt, + PathType::download)); + network->build_path("Test1", PathType::download); + CHECK(EVENTUALLY(10ms, network->called("_send_onion_request"))); + CHECK(network->get_paths(PathType::download).size() == 1); + + // It stores a successful 'standard' path, updates the status, calls the 'paths_changed' hook + // and kicks of queued requests + network.emplace(std::nullopt, true, false, false); + network->find_valid_path_response = + onion_path{"Test", invalid_info, {snode_cache.begin(), snode_cache.begin() + 3}, 0}; + network->ignore_calls_to("_send_onion_request"); + network->set_snode_cache(snode_cache); + network->set_unused_connections({invalid_info}); + network->add_pending_request( + PathType::standard, + request_info::make( + snode_cache.back(), + std::nullopt, + std::nullopt, + 1s, + std::nullopt, + PathType::standard)); + network->build_path("Test1", PathType::standard); + CHECK(EVENTUALLY(10ms, network->called("_send_onion_request"))); + CHECK(network->get_paths(PathType::standard).size() == 1); + CHECK(network->get_status() == ConnectionStatus::connected); + CHECK(network->called("paths_changed")); +} + +TEST_CASE("Network Find Valid Path", "[network][find_valid_path]") { + auto ed_pk = "4cb76fdc6d32278e3f83dbf608360ecc6b65727934b85d2fb86862ff98c46ab7"_hexbytes; + auto target = test_node(ed_pk, 1); + auto test_service_node = service_node{ + "decaf007f26d3d6f9b845ad031ffdf6d04638c25bb10b8fffbbe99135303c4b9"_hexbytes, + {2, 8, 0}, + INVALID_SWARM_ID, + "144.76.164.202", + uint16_t{35400}}; + 
auto network = TestNetwork(std::nullopt, true, false, false); + auto info = request_info::make(target, std::nullopt, std::nullopt, 0ms); + auto invalid_path = onion_path{ + "Test", + {test_service_node, nullptr, nullptr, nullptr}, + {test_service_node}, + uint8_t{0}}; + + // It returns nothing when given no path options + CHECK_FALSE(network.find_valid_path(info, {}).has_value()); + + // It ignores invalid paths + CHECK_FALSE(network.find_valid_path(info, {invalid_path}).has_value()); + + // Need to get a valid path for subsequent tests + std::promise<std::pair<connection_info, std::optional<std::string>>> prom; + + network.establish_connection( + "Test", + test_service_node, + 3s, + [&prom](connection_info conn_info, std::optional<std::string> error) { + prom.set_value({std::move(conn_info), error}); + }); + + // Wait for the result to be set + auto result = prom.get_future().get(); + REQUIRE(result.first.is_valid()); + auto valid_path = onion_path{ + "Test", + std::move(result.first), + std::vector{test_service_node}, + uint8_t{0}}; + + // It excludes paths which include the IP of the target + auto shared_ip_info = request_info::make(test_service_node, std::nullopt, std::nullopt, 0ms); + CHECK_FALSE(network.find_valid_path(shared_ip_info, {valid_path}).has_value()); + + // It returns a path when there is a valid one + CHECK(network.find_valid_path(info, {valid_path}).has_value()); + + // In 'single_path_mode' it does allow the path to include the IP of the target (so that + // requests can still be made) + auto network_single_path = TestNetwork(std::nullopt, true, true, false); + CHECK(network_single_path.find_valid_path(shared_ip_info, {valid_path}).has_value()); +} + +TEST_CASE("Network Enqueue Path Build", "[network][build_path_if_needed]") { + auto ed_pk = "4cb76fdc6d32278e3f83dbf608360ecc6b65727934b85d2fb86862ff98c46ab7"_hexbytes; + auto target = test_node(ed_pk, 0); + std::optional<TestNetwork> network; + auto invalid_path = onion_path{ + "Test", connection_info{target, nullptr, nullptr, nullptr}, {target}, uint8_t{0}}; + + // It does not add additional path builds if there is already a path and it's in + // 'single_path_mode' + network.emplace(std::nullopt, true, true, false); + network->ignore_calls_to("establish_and_store_connection"); + network->set_paths(PathType::standard, {invalid_path}); + network->build_path_if_needed(PathType::standard, false); + CHECK(ALWAYS(10ms, network->did_not_call("establish_and_store_connection"))); + CHECK(network->get_path_build_queue().empty()); + + // Adds a path build to the queue + network.emplace(std::nullopt, true, false, false); + network->ignore_calls_to("establish_and_store_connection"); + network->set_paths(PathType::standard, {}); + network->build_path_if_needed(PathType::standard, false); + CHECK(EVENTUALLY(10ms, network->called("establish_and_store_connection"))); + CHECK(network->get_path_build_queue() == std::deque{PathType::standard}); + + // Can only add the correct number of 'standard' path builds to the queue + network.emplace(std::nullopt, true, false, false); + network->ignore_calls_to("establish_and_store_connection"); + network->build_path_if_needed(PathType::standard, false); + network->build_path_if_needed(PathType::standard, false); + CHECK(EVENTUALLY(10ms, network->called("establish_and_store_connection", 2))); + network->reset_calls(); // This triggers 'call_soon' so we need to wait until they are enqueued + network->build_path_if_needed(PathType::standard, false); + CHECK(ALWAYS(10ms, network->did_not_call("establish_and_store_connection"))); + CHECK(network->get_path_build_queue() == +
std::deque{PathType::standard, PathType::standard}); + + // Can add additional 'standard' path builds if below the minimum threshold + network.emplace(std::nullopt, true, false, false); + network->ignore_calls_to("establish_and_store_connection"); + network->set_paths(PathType::standard, {invalid_path}); + network->build_path_if_needed(PathType::standard, false); + CHECK(EVENTUALLY(10ms, network->called("establish_and_store_connection"))); + CHECK(network->get_path_build_queue() == std::deque{PathType::standard}); + + // Can add more path builds if there are enough active paths of the same type, no pending paths + // and no `found_path` was provided + network.emplace(std::nullopt, true, false, false); + network->ignore_calls_to("establish_and_store_connection"); + network->set_paths(PathType::standard, {invalid_path, invalid_path}); + network->build_path_if_needed(PathType::standard, false); + CHECK(EVENTUALLY(10ms, network->called("establish_and_store_connection"))); + CHECK(network->get_path_build_queue() == std::deque{PathType::standard}); + + // Cannot add more path builds if there are already enough active paths of the same type and a + // `found_path` was provided + network.emplace(std::nullopt, true, false, false); + network->ignore_calls_to("establish_and_store_connection"); + network->set_paths(PathType::standard, {invalid_path, invalid_path}); + network->build_path_if_needed(PathType::standard, true); + CHECK(ALWAYS(10ms, network->did_not_call("establish_and_store_connection"))); + CHECK(network->get_path_build_queue().empty()); + + // Cannot add more path builds if there is already a build of the same type in the queue and the + // number of active and pending builds of the same type meet the limit + network.emplace(std::nullopt, true, false, false); + network->ignore_calls_to("establish_and_store_connection"); + network->set_paths(PathType::standard, {invalid_path}); + network->set_path_build_queue({PathType::standard}); + network->build_path_if_needed(PathType::standard, false); + CHECK(ALWAYS(10ms, network->did_not_call("establish_and_store_connection"))); + CHECK(network->get_path_build_queue() == std::deque{PathType::standard}); + + // Can only add the correct number of 'download' path builds to the queue + network.emplace(std::nullopt, true, false, false); + network->ignore_calls_to("establish_and_store_connection"); + network->build_path_if_needed(PathType::download, false); + network->build_path_if_needed(PathType::download, false); + CHECK(EVENTUALLY(10ms, network->called("establish_and_store_connection", 2))); + network->reset_calls(); // This triggers 'call_soon' so we need to wait until they are enqueued + network->build_path_if_needed(PathType::download, false); + CHECK(ALWAYS(10ms, network->did_not_call("establish_and_store_connection"))); + CHECK(network->get_path_build_queue() == + std::deque{PathType::download, PathType::download}); + + // Can only add the correct number of 'upload' path builds to the queue + network.emplace(std::nullopt, true, false, false); + network->ignore_calls_to("establish_and_store_connection"); + network->build_path_if_needed(PathType::upload, false); + network->build_path_if_needed(PathType::upload, false); + CHECK(EVENTUALLY(10ms, network->called("establish_and_store_connection", 2))); + network->reset_calls(); // This triggers 'call_soon' so we need to wait until they are enqueued + network->build_path_if_needed(PathType::upload, false); + CHECK(ALWAYS(10ms, network->did_not_call("establish_and_store_connection"))); + 
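+ // As with the 'standard' and 'download' cases above, only two 'upload' builds end up queued,
+ // which suggests (assuming the stubbed limits mirror the real ones) a target of two paths per type.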
CHECK(network->get_path_build_queue() == + std::deque{PathType::upload, PathType::upload}); +} + +TEST_CASE("Network requests", "[network][establish_connection]") { + auto test_service_node = service_node{ + "decaf007f26d3d6f9b845ad031ffdf6d04638c25bb10b8fffbbe99135303c4b9"_hexbytes, + {2, 8, 0}, + INVALID_SWARM_ID, + "144.76.164.202", + uint16_t{35400}}; + auto network = TestNetwork(std::nullopt, true, true, false); + std::promise>> prom; + + network.establish_connection( + "Test", + test_service_node, + 3s, + [&prom](connection_info info, std::optional error) { + prom.set_value({info, error}); + }); + + // Wait for the result to be set + auto result = prom.get_future().get(); + + CHECK(result.first.is_valid()); + CHECK_FALSE(result.second.has_value()); +} + +TEST_CASE("Network requests", "[network][check_request_queue_timeouts]") { + auto test_service_node = service_node{ + "decaf007f26d3d6f9b845ad031ffdf6d04638c25bb10b8fffbbe99135303c4b9"_hexbytes, + {2, 8, 0}, + INVALID_SWARM_ID, + "144.76.164.202", + uint16_t{35400}}; + std::optional network; + std::promise prom; + + // Test that it doesn't start checking for timeouts when the request doesn't have + // a build paths timeout + network.emplace(std::nullopt, true, true, false); + network->send_onion_request( + test_service_node, + ustring{to_usv("{\"method\":\"info\",\"params\":{}}")}, + std::nullopt, + [](bool, + bool, + int16_t, + std::vector>, + std::optional) {}, + oxen::quic::DEFAULT_TIMEOUT, + std::nullopt); + CHECK(ALWAYS(300ms, network->did_not_call("check_request_queue_timeouts"))); + + // Test that it does start checking for timeouts when the request has a + // paths build timeout + network.emplace(std::nullopt, true, true, false); + network->ignore_calls_to("build_path"); + network->send_onion_request( + test_service_node, + ustring{to_usv("{\"method\":\"info\",\"params\":{}}")}, + std::nullopt, + [](bool, + bool, + int16_t, + std::vector>, + std::optional) {}, + oxen::quic::DEFAULT_TIMEOUT, + oxen::quic::DEFAULT_TIMEOUT); + CHECK(EVENTUALLY(300ms, network->called("check_request_queue_timeouts"))); + + // Test that it fails the request with a timeout if it has a build path timeout + // and the path build takes too long + network.emplace(std::nullopt, true, true, false); + network->ignore_calls_to("build_path"); + network->send_onion_request( + test_service_node, + ustring{to_usv("{\"method\":\"info\",\"params\":{}}")}, + std::nullopt, + [&prom](bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + prom.set_value({success, timeout, status_code, headers, response}); + }, + oxen::quic::DEFAULT_TIMEOUT, + 100ms); + + // Wait for the result to be set + auto result = prom.get_future().get(); + + CHECK_FALSE(result.success); + CHECK(result.timeout); +} + +TEST_CASE("Network requests", "[network][send_request]") { + auto test_service_node = service_node{ + "decaf007f26d3d6f9b845ad031ffdf6d04638c25bb10b8fffbbe99135303c4b9"_hexbytes, + {2, 8, 0}, + INVALID_SWARM_ID, + "144.76.164.202", + uint16_t{35400}}; + auto network = TestNetwork(std::nullopt, true, true, false); + std::promise prom; + + network.establish_connection( + "Test", + test_service_node, + 3s, + [&prom, &network, test_service_node]( + connection_info info, std::optional error) { + if (!info.is_valid()) + return prom.set_value({false, false, -1, {}, error.value_or("Unknown Error")}); + + network.send_request( + request_info::make( + test_service_node, + ustring{to_usv("{}")}, + std::nullopt, + 3s, + std::nullopt, + 
PathType::standard, + std::nullopt, + "info"), + std::move(info), + [&prom](bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + prom.set_value({success, timeout, status_code, headers, response}); + }); + }); + + // Wait for the result to be set + auto result = prom.get_future().get(); + + CHECK(result.success); + CHECK_FALSE(result.timeout); + CHECK(result.status_code == 200); + REQUIRE(result.response.has_value()); + INFO("*result.response is: " << *result.response); + REQUIRE_NOTHROW([&] { [[maybe_unused]] auto _ = nlohmann::json::parse(*result.response); }); + + auto response = nlohmann::json::parse(*result.response); + CHECK(response.contains("hf")); + CHECK(response.contains("t")); + CHECK(response.contains("version")); +} + +TEST_CASE("Network onion request", "[network][send_onion_request]") { + auto test_service_node = service_node{ + "decaf007f26d3d6f9b845ad031ffdf6d04638c25bb10b8fffbbe99135303c4b9"_hexbytes, + {2, 8, 0}, + INVALID_SWARM_ID, + "144.76.164.202", + uint16_t{35400}}; + auto network = Network(std::nullopt, true, true, false); + std::promise result_promise; + + network.send_onion_request( + test_service_node, + ustring{to_usv("{\"method\":\"info\",\"params\":{}}")}, + std::nullopt, + [&result_promise]( + bool success, + bool timeout, + int16_t status_code, + std::vector> headers, + std::optional response) { + result_promise.set_value({success, timeout, status_code, headers, response}); + }, + oxen::quic::DEFAULT_TIMEOUT, + oxen::quic::DEFAULT_TIMEOUT); + + // Wait for the result to be set + auto result = result_promise.get_future().get(); + + CHECK(result.success); + CHECK_FALSE(result.timeout); + CHECK(result.status_code == 200); + REQUIRE(result.response.has_value()); + INFO("*result.response is: " << *result.response); + REQUIRE_NOTHROW([&] { [[maybe_unused]] auto _ = nlohmann::json::parse(*result.response); }); + + auto response = nlohmann::json::parse(*result.response); + CHECK(response.contains("hf")); + CHECK(response.contains("t")); + CHECK(response.contains("version")); +} + +TEST_CASE("Network direct request C API", "[network][network_send_request]") { + network_object* network; + REQUIRE(network_init(&network, nullptr, true, true, false, nullptr)); + std::array target_ip = {144, 76, 164, 202}; + auto test_service_node = network_service_node{}; + test_service_node.quic_port = 35400; + std::copy(target_ip.begin(), target_ip.end(), test_service_node.ip); + std::strcpy( + test_service_node.ed25519_pubkey_hex, + "decaf007f26d3d6f9b845ad031ffdf6d04638c25bb10b8fffbbe99135303c4b9"); + auto body = ustring{to_usv("{\"method\":\"info\",\"params\":{}}")}; + auto result_promise = std::make_shared>(); + + network_send_onion_request_to_snode_destination( + network, + test_service_node, + body.data(), + body.size(), + nullptr, + std::chrono::milliseconds{oxen::quic::DEFAULT_TIMEOUT}.count(), + std::chrono::milliseconds{oxen::quic::DEFAULT_TIMEOUT}.count(), + [](bool success, + bool timeout, + int16_t status_code, + const char** headers, + const char** header_values, + size_t headers_size, + const char* c_response, + size_t response_size, + void* ctx) { + auto result_promise = static_cast*>(ctx); + auto response_str = std::string(c_response, response_size); + std::vector> header_pairs; + header_pairs.reserve(headers_size); + + for (size_t i = 0; i < headers_size; ++i) { + if (headers[i] == nullptr) + continue; // Skip null entries + if (header_values[i] == nullptr) + continue; // Skip null entries + + 
header_pairs.emplace_back(headers[i], header_values[i]); + } + + result_promise->set_value( + {success, timeout, status_code, header_pairs, response_str}); + }, + static_cast(result_promise.get())); + + // Wait for the result to be set + auto result = result_promise->get_future().get(); + + CHECK(result.success); + CHECK_FALSE(result.timeout); + CHECK(result.status_code == 200); + REQUIRE(result.response.has_value()); + INFO("*result.response is: " << *result.response); + REQUIRE_NOTHROW([&] { [[maybe_unused]] auto _ = nlohmann::json::parse(*result.response); }); + + auto response = nlohmann::json::parse(*result.response); + CHECK(response.contains("hf")); + CHECK(response.contains("t")); + CHECK(response.contains("version")); + network_free(network); +} + +TEST_CASE("Network swarm", "[network][detail][pubkey_to_swarm_space]") { + x25519_pubkey pk; + + pk = x25519_pubkey::from_hex( + "3506f4a71324b7dd114eddbf4e311f39dde243e1f2cb97c40db1961f70ebaae8"); + CHECK(session::network::detail::pubkey_to_swarm_space(pk) == 17589930838143112648ULL); + pk = x25519_pubkey::from_hex( + "cf27da303a50ac8c4b2d43d27259505c9bcd73fc21cf2a57902c3d050730b604"); + CHECK(session::network::detail::pubkey_to_swarm_space(pk) == 10370619079776428163ULL); + pk = x25519_pubkey::from_hex( + "d3511706b8b34f6e8411bf07bd22ba6b2435ca56846fbccf6eb1e166a6cd15cc"); + CHECK(session::network::detail::pubkey_to_swarm_space(pk) == 2144983569669512198ULL); + pk = x25519_pubkey::from_hex( + "0f06693428fca9102a451e3f28d9cc743d8ea60a89ab6aa69eb119470c11cbd3"); + CHECK(session::network::detail::pubkey_to_swarm_space(pk) == 9690840703409570833ULL); + pk = x25519_pubkey::from_hex( + "ffba630924aa1224bb930dde21c0d11bf004608f2812217f8ac812d6c7e3ad48"); + CHECK(session::network::detail::pubkey_to_swarm_space(pk) == 4532060000165252872ULL); + pk = x25519_pubkey::from_hex( + "eeeeeeeeeeeeeeee777777777777777711111111111111118888888888888888"); + CHECK(session::network::detail::pubkey_to_swarm_space(pk) == 0); + pk = x25519_pubkey::from_hex( + "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"); + CHECK(session::network::detail::pubkey_to_swarm_space(pk) == 0); + pk = x25519_pubkey::from_hex( + "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe"); + CHECK(session::network::detail::pubkey_to_swarm_space(pk) == 1); + pk = x25519_pubkey::from_hex( + "ffffffffffffffffffffffffffffffffffffffffffffffff7fffffffffffffff"); + CHECK(session::network::detail::pubkey_to_swarm_space(pk) == 1ULL << 63); + pk = x25519_pubkey::from_hex( + "000000000000000000000000000000000000000000000000ffffffffffffffff"); + CHECK(session::network::detail::pubkey_to_swarm_space(pk) == (uint64_t)-1); + pk = x25519_pubkey::from_hex( + "0000000000000000000000000000000000000000000000000123456789abcdef"); + CHECK(session::network::detail::pubkey_to_swarm_space(pk) == 0x0123456789abcdefULL); +} + +TEST_CASE("Network swarm", "[network][get_swarm]") { + auto ed_pk = "4cb76fdc6d32278e3f83dbf608360ecc6b65727934b85d2fb86862ff98c46ab7"_hexbytes; + std::vector>> swarms = { + {100, {}}, {200, {}}, {300, {}}, {399, {}}, {498, {}}, {596, {}}, {694, {}}}; + auto network = TestNetwork(std::nullopt, true, true, false); + network.set_snode_cache({test_node(ed_pk, 0)}); + network.set_all_swarms(swarms); + + // Exact matches: + // 0x64 = 100, 0xc8 = 200, 0x1f2 = 498 + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000006" + "4") == 100); + 
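+ // (These are exact matches assuming pubkey_to_swarm_space XOR-folds the four big-endian 8-byte
+ // words of the x25519 pubkey, as the vectors above suggest: e.g. eeee… ^ 7777… ^ 1111… ^ 8888… == 0.
+ // With the '05' prefix dropped, only the last word of these IDs is non-zero, so they land exactly
+ // on swarm-space values 0x64, 0xc8 and 0x1f2.)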
CHECK(network.get_swarm_id("0500000000000000000000000000000000000000000000000000000000000000c" + "8") == 200); + CHECK(network.get_swarm_id("0500000000000000000000000000000000000000000000000000000000000001f" + "2") == 498); + + // Nearest + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000000" + "0") == 100); + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000000" + "1") == 100); + + // Nearest, with wraparound + // 0x8000... is closest to the top value + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000800000000000000" + "0") == 694); + + // 0xa000... is closest (via wraparound) to the smallest + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000a00000000000000" + "0") == 100); + + // This is the invalid swarm id for swarms, but should still work for a client + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000fffffffffffffff" + "f") == 100); + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000fffffffffffffff" + "e") == 100); + + // Midpoint tests; we prefer the lower value when exactly in the middle between two swarms. + // 0x96 = 150 + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000009" + "5") == 100); + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000009" + "6") == 100); + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000009" + "7") == 200); + + // 0xfa = 250 + CHECK(network.get_swarm_id("0500000000000000000000000000000000000000000000000000000000000000f" + "9") == 200); + CHECK(network.get_swarm_id("0500000000000000000000000000000000000000000000000000000000000000f" + "a") == 200); + CHECK(network.get_swarm_id("0500000000000000000000000000000000000000000000000000000000000000f" + "b") == 300); + + // 0x15d = 349 + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000015" + "d") == 300); + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000015" + "e") == 399); + + // 0x1c0 = 448 + CHECK(network.get_swarm_id("0500000000000000000000000000000000000000000000000000000000000001c" + "0") == 399); + CHECK(network.get_swarm_id("0500000000000000000000000000000000000000000000000000000000000001c" + "1") == 498); + + // 0x223 = 547 + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000022" + "2") == 498); + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000022" + "3") == 498); + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000022" + "4") == 596); + + // 0x285 = 645 + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000028" + "5") == 596); + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000028" + "6") == 694); + + // 0x800....d is the midpoint between 694 and 100 (the long way). We always round "down" (which + // in this case, means wrapping to the largest swarm). 
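+ // (For reference: going the long way around from 694 up to 100 spans 2^64 - 594 values, so the
+ // halfway point is 694 + 2^63 - 297 = 2^63 + 397 = 0x800000000000018d.)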
+ CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000800000000000018" + "c") == 694); + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000800000000000018" + "d") == 694); + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000800000000000018" + "e") == 100); + + // With a swarm at -20 the midpoint is now 40 (=0x28). When our value is the *low* value we + // prefer the *last* swarm in the case of a tie (while consistent with the general case of + // preferring the left edge, it means we're inconsistent with the other wraparound case, above. + // *sigh*). + swarms.push_back({(uint64_t)-20, {}}); + network.set_all_swarms(swarms); + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000002" + "7") == swarms.back().first); + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000002" + "8") == swarms.back().first); + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000000000000000002" + "9") == swarms.front().first); + + // The code used to have a broken edge case if we have a swarm at zero and a client at max-u64 + // because of an overflow in how the distance is calculated (the first swarm will be calculated + // as max-u64 away, rather than 1 away), and so the id always maps to the highest swarm (even + // though 0xfff...fe maps to the lowest swarm; the first check here, then, would fail. + swarms.insert(swarms.begin(), {0, {}}); + network.set_all_swarms(swarms); + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000fffffffffffffff" + "f") == 0); + CHECK(network.get_swarm_id("05000000000000000000000000000000000000000000000000fffffffffffffff" + "e") == 0); +} diff --git a/tests/test_onionreq.cpp b/tests/test_onionreq.cpp index a4c42dd3..0f892538 100644 --- a/tests/test_onionreq.cpp +++ b/tests/test_onionreq.cpp @@ -1,4 +1,5 @@ #include +#include #include #include diff --git a/tests/test_session_encrypt.cpp b/tests/test_session_encrypt.cpp index 8c286b69..a15ef826 100644 --- a/tests/test_session_encrypt.cpp +++ b/tests/test_session_encrypt.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -418,16 +419,40 @@ TEST_CASE("Session ONS response decryption", "[session-ons][decrypt]") { auto ciphertext = "3575802dd9bfea72672a208840f37ca289ceade5d3ffacabe2d231f109d204329fc33e28c33" "1580d9a8c9b8a64cacfec97"_hexbytes; + auto ciphertext_legacy = + "dbd4bc89bd2c9e5322fd9f4cadcaa66a0c38f15d0c927a86cc36e895fe1f3c532a3958d972563f52ca858e94eec22dc360"_hexbytes; auto nonce = "00112233445566778899aabbccddeeff00ffeeddccbbaa99"_hexbytes; - ustring sid_data = - "05d2ad010eeb72d72e561d9de7bd7b6989af77dcabffa03a5111a6c859ae5c3a72"_hexbytes; CHECK(decrypt_ons_response(name, ciphertext, nonce) == "05d2ad010eeb72d72e561d9de7bd7b6989af77dcabffa03a5111a6c859ae5c3a72"); + CHECK(decrypt_ons_response(name, ciphertext_legacy, std::nullopt) == + "05d2ad010eeb72d72e561d9de7bd7b6989af77dcabffa03a5111a6c859ae5c3a72"); CHECK_THROWS(decrypt_ons_response(name, to_unsigned_sv("invalid"), nonce)); CHECK_THROWS(decrypt_ons_response(name, ciphertext, to_unsigned_sv("invalid"))); } +TEST_CASE("Session ONS response decryption C API", "[session-ons][session_decrypt_ons_response]") { + using namespace session; + + auto name = "test\0"; + auto ciphertext = + "3575802dd9bfea72672a208840f37ca289ceade5d3ffacabe2d231f109d204329fc33e28c33" + "1580d9a8c9b8a64cacfec97"_hexbytes; + auto ciphertext_legacy = + 
"dbd4bc89bd2c9e5322fd9f4cadcaa66a0c38f15d0c927a86cc36e895fe1f3c532a3958d972563f52ca858e94eec22dc360"_hexbytes; + auto nonce = "00112233445566778899aabbccddeeff00ffeeddccbbaa99"_hexbytes; + + char ons1[67]; + CHECK(session_decrypt_ons_response( + name, ciphertext.data(), ciphertext.size(), nonce.data(), ons1)); + CHECK(ons1 == "05d2ad010eeb72d72e561d9de7bd7b6989af77dcabffa03a5111a6c859ae5c3a72"sv); + + char ons2[67]; + CHECK(session_decrypt_ons_response( + name, ciphertext_legacy.data(), ciphertext_legacy.size(), nullptr, ons2)); + CHECK(ons2 == "05d2ad010eeb72d72e561d9de7bd7b6989af77dcabffa03a5111a6c859ae5c3a72"sv); +} + TEST_CASE("Session push notification decryption", "[session-notification][decrypt]") { using namespace session; @@ -444,3 +469,19 @@ TEST_CASE("Session push notification decryption", "[session-notification][decryp CHECK_THROWS(decrypt_push_notification(to_unsigned_sv("invalid"), enc_key)); CHECK_THROWS(decrypt_push_notification(payload, to_unsigned_sv("invalid"))); } + +TEST_CASE("Session message hash", "[session][message-hash]") { + using namespace session; + + auto pubkey_hex = "0518981d2822aabc9ba8dbf83f2feac4c70eb737930bc4f254fa71e01f8464a049"sv; + auto base64_data1 = + "CAESpQMKA1BVVBIPL2FwaS92MS9tZXNzYWdlGoEDCAcSQjA1MTg5ODFkMjgyMmFhYmM5YmE4ZGJmODNmMmZlYWM0YzcwZWI3Mzc5MzBiYzRmMjU0ZmE3MWUwMWY4NDY0YTA0OSirlfaFnDI4AUKvAvYMh0I1qhBp9tDOzZhl7vMFuD7a9k/BLvPHMOkTrYsjGj2ri7T6AoJjVm/dDMsXlEP58VaGFSv+mcctCRstYox+3CchbQoVieBi2NGE1bqCeiZeLOMxQxleSZ94vzi7CoC8/NCLmTBzKvw0GBo77Tz37yPGxNLp2QO1xOuDVqM1/+4Sdj+JzMpfsZA8PDMmG3T1o8DJJ/EmwlxsmKM/eAjqtNpdF1G7wtZW5im9fiW11sQgG0/+5EsqxqEoo0xsi5TL6L9DN6zKhjXC9bu/QAfI5ZIpF5+9IHzKashPAjSswBZmlesjbFbNvNgBq4hSeXIxjtg7xDm/hfXao1WRa3TMHgfZs2bY+cNlDGqArjZT9q5XTVxsQYXq+mz/koh0qxiJktAC3C0ixs7CInORFiD18omD4oqX1/IB"sv; + auto base64_data2 = + "CAESpAMKA1BVVBIPL2FwaS92MS9tZXNzYWdlGoEDCAcSQjA1MTg5ODFkMjgyMmFhYmM5YmE4ZGJmODNmMmZlYWM0YzcwZWI3Mzc5MzBiYzRmMjU0ZmE3MWUwMWY4NDY0YTA0OSi/7peInDI4AUKvAsCTN9WMEkajMbC7EA6QOClzdXK3W6MTEElFotQ6PGNa2IKfYb+iu0MRC6ph+1hE5hzfay00v0UfB5Xen3dBgZ2drwToYhYb1zqRlIeesdwT0Yt6ct+Gn47PBL4oXOv7PJo3ys3jlq1t+xbAN/vum/8ART9xVhNIZ+3dOpS62z8pwSqusWECGw9dJDgFN6g0+2R85dco/HP9Z2SiGBaAJulKFUXKaT+jMHab3nPjoqke/lVG544iJAmNbI+KJr61YgtsbVfO02pje1RXeQtQacAtWpCYlin4fNtr6ANTs8aJDb1H1JFOG/r8PZHkPl1Fl/2cDppngZYJJo6/8IH9FpZS64le+mZy2BjP7UKfEx3ulmJIwpfqcqe9qvoTbGtljSf8wRylUkeo1E7Gg2WP8SDrgdXBwIaZp24="sv; + int16_t ns = -10; + + CHECK(compute_message_hash(pubkey_hex, ns, base64_data1) == + "xREbCx9GRzDiuU8GsEK7rR1InU6peC3vp10cBkTUDPg"); + CHECK(compute_message_hash(pubkey_hex, ns, base64_data2) == + "apKu8OMjrbU+YeVWpMSyrr1wHq51K3uKD8WM0F4E1cE"); +} diff --git a/tests/utils.hpp b/tests/utils.hpp index 76b145c1..9fbeee69 100644 --- a/tests/utils.hpp +++ b/tests/utils.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include "session/config/base.h" @@ -88,3 +89,69 @@ std::vector> view_vec(const std::vector> Validator> +auto eventually_impl(std::chrono::milliseconds timeout, Call&& f, Validator&& isValid) + -> std::invoke_result_t { + using ResultType = std::invoke_result_t; + + // If we already have a value then don't bother with the loop + if (auto result = f(); isValid(result)) + return result; + + auto start = std::chrono::steady_clock::now(); + auto sleep_duration = std::chrono::milliseconds{10}; + while (std::chrono::steady_clock::now() - start < timeout) { + std::this_thread::sleep_for(sleep_duration); + + if (auto result = f(); isValid(result)) + return result; + } + + return ResultType{}; +} + +template > Validator> +bool 
always_impl(std::chrono::milliseconds duration, Call&& f, Validator&& isValid) { + auto start = std::chrono::steady_clock::now(); + auto sleep_duration = std::chrono::milliseconds{10}; + while (std::chrono::steady_clock::now() - start < duration) { + if (auto result = f(); !isValid(result)) + return false; + std::this_thread::sleep_for(sleep_duration); + } + return true; +} + +template + requires std::is_same_v, bool> +bool eventually_impl(std::chrono::milliseconds timeout, Call&& f) { + return eventually_impl(timeout, f, [](bool result) { return result; }); +} + +template + requires std::is_same_v< + std::invoke_result_t, + std::vector::value_type>> +auto eventually_impl(std::chrono::milliseconds timeout, Call&& f) -> std::invoke_result_t { + using ResultType = std::invoke_result_t; + return eventually_impl(timeout, f, [](const ResultType& result) { return !result.empty(); }); +} + +template + requires std::is_same_v, bool> +bool always_impl(std::chrono::milliseconds duration, Call&& f) { + return always_impl(duration, f, [](bool result) { return result; }); +} + +template + requires std::is_same_v< + std::invoke_result_t, + std::vector::value_type>> +bool always_impl(std::chrono::milliseconds duration, Call&& f) { + using ResultType = std::invoke_result_t; + return always_impl(duration, f, [](const ResultType& result) { return !result.empty(); }); +} + +#define EVENTUALLY(timeout, ...) eventually_impl(timeout, [&]() { return (__VA_ARGS__); }) +#define ALWAYS(duration, ...) always_impl(duration, [&]() { return (__VA_ARGS__); }) diff --git a/utils/android.sh b/utils/android.sh index 98ccdfdd..a39a7b64 100755 --- a/utils/android.sh +++ b/utils/android.sh @@ -49,8 +49,6 @@ pkg="${archive%%.tar.xz}" mkdir -p "$pkg"/include cp -rv ../include/session "$pkg"/include/ -mkdir -p "$pkg"/include/oxenc -cp -v ../external/oxen-encoding/oxenc/*.h x86_64/external/oxen-encoding/oxenc/version.h "$pkg"/include/oxenc/ for abi in "${abis[@]}"; do mkdir -p "$pkg"/lib/$abi diff --git a/utils/ci/drone-format-verify.sh b/utils/ci/drone-format-verify.sh index 6a4d3041..829a3a21 100755 --- a/utils/ci/drone-format-verify.sh +++ b/utils/ci/drone-format-verify.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash test "x$IGNORE" != "x" && exit 0 repo=$(readlink -e $(dirname $0)/../../) -clang-format-15 -i $(find $repo/src $repo/include $repo/tests | grep -E '\.[hc](pp)?$') +clang-format-17 -i $(find $repo/src $repo/include $repo/tests | grep -E '\.[hc](pp)?$') jsonnetfmt -i $repo/.drone.jsonnet git --no-pager diff --exit-code --color || (echo -ne '\n\n\e[31;1mLint check failed; please run ./utils/format.sh\e[0m\n\n' ; exit 1) diff --git a/utils/format.sh b/utils/format.sh index 05a783f2..15ca8dcf 100755 --- a/utils/format.sh +++ b/utils/format.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash -CLANG_FORMAT_DESIRED_VERSION=15 +CLANG_FORMAT_DESIRED_VERSION=17 binary=$(command -v clang-format-$CLANG_FORMAT_DESIRED_VERSION 2>/dev/null) if [ $? 
-ne 0 ]; then diff --git a/utils/ios.sh b/utils/ios.sh index 3e00e63c..1ff2eede 100755 --- a/utils/ios.sh +++ b/utils/ios.sh @@ -22,6 +22,7 @@ VALID_DEVICE_ARCH_PLATFORMS=(OS64) OUTPUT_DIR="${TARGET_BUILD_DIR:-build-ios}" IPHONEOS_DEPLOYMENT_TARGET=${IPHONEOS_DEPLOYMENT_TARGET:-13} ENABLE_BITCODE=${ENABLE_BITCODE:-OFF} +CONFIGURATION=${CONFIGURATION:-App_Store_Release} SHOULD_ACHIVE=${2:-true} # Parameter 2 is a flag indicating whether we want to archive the result # We want to customise the env variable so can't just default the value @@ -97,6 +98,14 @@ if [ -z $PLATFORM_NAME ] || [ $PLATFORM_NAME = "iphoneos" ]; then fi # Build the individual architectures +submodule_check=ON +build_type="Release" + +if [[ "$CONFIGURATION" == "Debug" || "$CONFIGURATION" == "Debug_Compile_LibSession" ]]; then + submodule_check=OFF + build_type="Debug" +fi + for i in "${!TARGET_ARCHS[@]}"; do build="${BUILD_DIR}/${TARGET_ARCHS[$i]}" platform="${TARGET_PLATFORMS[$i]}" @@ -107,7 +116,10 @@ for i in "${!TARGET_ARCHS[@]}"; do -DPLATFORM=$platform \ -DDEPLOYMENT_TARGET=$IPHONEOS_DEPLOYMENT_TARGET \ -DENABLE_BITCODE=$ENABLE_BITCODE \ - -DENABLE_ONIONREQ=OFF \ + -DBUILD_STATIC_DEPS=ON \ + -DENABLE_VISIBILITY=ON \ + -DSUBMODULE_CHECK=$submodule_check \ + -DCMAKE_BUILD_TYPE=$build_type \ -DLOCAL_MIRROR=https://oxen.rocks/deps done @@ -142,42 +154,39 @@ rm -rf "${OUTPUT_DIR}/libsession-util.xcframework" if [ "${#TARGET_SIM_ARCHS}" -gt "0" ] && [ "${#TARGET_DEVICE_ARCHS}" -gt "0" ]; then xcodebuild -create-xcframework \ -library "${BUILD_DIR}/ios/libsession-util.a" \ + -headers "include" \ -library "${BUILD_DIR}/sim/libsession-util.a" \ + -headers "include" \ -output "${OUTPUT_DIR}/libsession-util.xcframework" elif [ "${#TARGET_DEVICE_ARCHS}" -gt "0" ]; then xcodebuild -create-xcframework \ -library "${BUILD_DIR}/ios/libsession-util.a" \ + -headers "include" \ -output "${OUTPUT_DIR}/libsession-util.xcframework" else xcodebuild -create-xcframework \ -library "${BUILD_DIR}/sim/libsession-util.a" \ + -headers "include" \ -output "${OUTPUT_DIR}/libsession-util.xcframework" fi -# Copy the headers over -cp -rv include/session "${OUTPUT_DIR}/libsession-util.xcframework" - # The 'module.modulemap' is needed for XCode to be able to find the headers -modmap="${OUTPUT_DIR}/libsession-util.xcframework/module.modulemap" +modmap="${OUTPUT_DIR}/module.modulemap" echo "module SessionUtil {" >"$modmap" echo " module capi {" >>"$modmap" for x in $(cd include && find session -name '*.h'); do echo " header \"$x\"" >>"$modmap" done echo -e " export *\n }" >>"$modmap" -if false; then - # If we include the cpp headers like this then Xcode will try to load them as C headers (which - # of course breaks) and doesn't provide any way to only load the ones you need (because this is - # Apple land, why would anything useful be available?). So we include the headers in the - # archive but can't let xcode discover them because it will do it wrong. - echo -e "\n module cppapi {" >>"$modmap" - for x in $(cd include && find session -name '*.hpp'); do - echo " header \"$x\"" >>"$modmap" - done - echo -e " export *\n }" >>"$modmap" -fi echo "}" >>"$modmap" +# Need to add the module.modulemap into each architecture directory in the xcframework +for dir in "${OUTPUT_DIR}/libsession-util.xcframework"/*/; do + cp "${modmap}" "${dir}/Headers/module.modulemap" +done + +rm -rf "${modmap}" + if [ $SHOULD_ACHIVE = true ]; then (cd "${OUTPUT_DIR}/.."
&& tar cvJf "${UNIQUE_NAME}.tar.xz" "${UNIQUE_NAME}") fi diff --git a/utils/static-bundle.sh b/utils/static-bundle.sh index 7b8ea49d..1b6ad0a9 100755 --- a/utils/static-bundle.sh +++ b/utils/static-bundle.sh @@ -72,8 +72,6 @@ fi mkdir -p "$pkg"/{lib,include} cp -v libsession-util.a "$pkg"/lib cp -rv "$projdir"/include/session "$pkg"/include -mkdir -p "$pkg"/include/oxenc -cp -v "$projdir"/external/oxen-encoding/oxenc/*.h external/oxen-encoding/oxenc/version.h "$pkg"/include/oxenc/ if [ -z "$zip" ]; then tar cvJf "$archive" "$pkg"