diff --git a/mlir/include/mlir/Tools/lsp-server-support/Transport.h b/mlir/include/mlir/Tools/lsp-server-support/Transport.h index ce742be7a941c..44c71058cf717 100644 --- a/mlir/include/mlir/Tools/lsp-server-support/Transport.h +++ b/mlir/include/mlir/Tools/lsp-server-support/Transport.h @@ -147,9 +147,15 @@ class MessageHandler { void (ThisT::*handler)(const Param &)) { notificationHandlers[method] = [method, handler, thisPtr](llvm::json::Value rawParams) { - llvm::Expected param = parse(rawParams, method, "request"); - if (!param) - return llvm::consumeError(param.takeError()); + llvm::Expected param = + parse(rawParams, method, "notification"); + if (!param) { + return llvm::consumeError( + llvm::handleErrors(param.takeError(), [](const LSPError &lspError) { + Logger::error("JSON parsing error: {0}", + lspError.message.c_str()); + })); + } (thisPtr->*handler)(*param); }; } diff --git a/mlir/lib/Tools/lsp-server-support/Transport.cpp b/mlir/lib/Tools/lsp-server-support/Transport.cpp index 64dea35614c07..339c5f3825165 100644 --- a/mlir/lib/Tools/lsp-server-support/Transport.cpp +++ b/mlir/lib/Tools/lsp-server-support/Transport.cpp @@ -51,12 +51,12 @@ class Reply { Reply::Reply(const llvm::json::Value &id, llvm::StringRef method, JSONTransport &transport, std::mutex &transportOutputMutex) - : id(id), transport(&transport), + : method(method), id(id), transport(&transport), transportOutputMutex(transportOutputMutex) {} Reply::Reply(Reply &&other) - : replied(other.replied.load()), id(std::move(other.id)), - transport(other.transport), + : method(other.method), replied(other.replied.load()), + id(std::move(other.id)), transport(other.transport), transportOutputMutex(other.transportOutputMutex) { other.transport = nullptr; } diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt index 6fad249a0b2fb..6d8aa290e82f2 100644 --- a/mlir/unittests/CMakeLists.txt +++ b/mlir/unittests/CMakeLists.txt @@ -20,6 +20,7 @@ add_subdirectory(Support) add_subdirectory(Rewrite) add_subdirectory(TableGen) add_subdirectory(Target) +add_subdirectory(Tools) add_subdirectory(Transforms) if(MLIR_ENABLE_EXECUTION_ENGINE) diff --git a/mlir/unittests/Tools/CMakeLists.txt b/mlir/unittests/Tools/CMakeLists.txt new file mode 100644 index 0000000000000..a97588d928668 --- /dev/null +++ b/mlir/unittests/Tools/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(lsp-server-support) diff --git a/mlir/unittests/Tools/lsp-server-support/CMakeLists.txt b/mlir/unittests/Tools/lsp-server-support/CMakeLists.txt new file mode 100644 index 0000000000000..3aa8b9c4bc773 --- /dev/null +++ b/mlir/unittests/Tools/lsp-server-support/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_unittest(MLIRLspServerSupportTests + Transport.cpp +) +target_link_libraries(MLIRLspServerSupportTests + PRIVATE + MLIRLspServerSupportLib) diff --git a/mlir/unittests/Tools/lsp-server-support/Transport.cpp b/mlir/unittests/Tools/lsp-server-support/Transport.cpp new file mode 100644 index 0000000000000..a086964cd3660 --- /dev/null +++ b/mlir/unittests/Tools/lsp-server-support/Transport.cpp @@ -0,0 +1,134 @@ +//===- Transport.cpp - LSP JSON transport unit tests ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Tools/lsp-server-support/Transport.h" +#include "mlir/Tools/lsp-server-support/Logging.h" +#include "mlir/Tools/lsp-server-support/Protocol.h" +#include "llvm/Support/FileSystem.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace mlir::lsp; +using namespace testing; + +namespace { + +TEST(TransportTest, SendReply) { + std::string out; + llvm::raw_string_ostream os(out); + JSONTransport transport(nullptr, os); + MessageHandler handler(transport); + + transport.reply(1989, nullptr); + EXPECT_THAT(out, HasSubstr("\"id\":1989")); + EXPECT_THAT(out, HasSubstr("\"result\":null")); +} + +class TransportInputTest : public Test { + llvm::SmallVector inputPath; + std::FILE *in = nullptr; + std::string output = ""; + llvm::raw_string_ostream os; + std::optional transport = std::nullopt; + std::optional messageHandler = std::nullopt; + +protected: + TransportInputTest() : os(output) {} + + void SetUp() override { + std::error_code ec = + llvm::sys::fs::createTemporaryFile("lsp-unittest", "json", inputPath); + ASSERT_FALSE(ec) << "Could not create temporary file: " << ec.message(); + + in = std::fopen(inputPath.data(), "r"); + ASSERT_TRUE(in) << "Could not open temporary file: " + << std::strerror(errno); + transport.emplace(in, os, JSONStreamStyle::Delimited); + messageHandler.emplace(*transport); + } + + void TearDown() override { + EXPECT_EQ(std::fclose(in), 0) + << "Could not close temporary file FD: " << std::strerror(errno); + std::error_code ec = + llvm::sys::fs::remove(inputPath, /*IgnoreNonExisting=*/false); + EXPECT_FALSE(ec) << "Could not remove temporary file '" << inputPath.data() + << "': " << ec.message(); + } + + void writeInput(StringRef buffer) { + std::error_code ec; + llvm::raw_fd_ostream os(inputPath.data(), ec); + ASSERT_FALSE(ec) << "Could not write to '" << inputPath.data() + << "': " << ec.message(); + os << buffer; + os.close(); + } + + StringRef getOutput() const { return output; } + MessageHandler &getMessageHandler() { return *messageHandler; } + + void runTransport() { + bool gotEOF = false; + llvm::Error err = llvm::handleErrors( + transport->run(*messageHandler), [&](const llvm::ECError &ecErr) { + gotEOF = ecErr.convertToErrorCode() == std::errc::io_error; + }); + llvm::consumeError(std::move(err)); + EXPECT_TRUE(gotEOF); + } +}; + +TEST_F(TransportInputTest, RequestWithInvalidParams) { + struct Handler { + void onMethod(const TextDocumentItem ¶ms, + mlir::lsp::Callback callback) {} + } handler; + getMessageHandler().method("invalid-params-request", &handler, + &Handler::onMethod); + + writeInput("{\"jsonrpc\":\"2.0\",\"id\":92," + "\"method\":\"invalid-params-request\",\"params\":{}}\n"); + runTransport(); + EXPECT_THAT(getOutput(), HasSubstr("error")); + EXPECT_THAT(getOutput(), HasSubstr("missing value at (root).uri")); +} + +TEST_F(TransportInputTest, NotificationWithInvalidParams) { + // JSON parsing errors are only reported via error logging. As a result, this + // test can't make any expectations -- but it prints the output anyway, by way + // of demonstration. + Logger::setLogLevel(Logger::Level::Error); + + struct Handler { + void onNotification(const TextDocumentItem ¶ms) {} + } handler; + getMessageHandler().notification("invalid-params-notification", &handler, + &Handler::onNotification); + + writeInput("{\"jsonrpc\":\"2.0\",\"method\":\"invalid-params-notification\"," + "\"params\":{}}\n"); + runTransport(); +} + +TEST_F(TransportInputTest, MethodNotFound) { + writeInput("{\"jsonrpc\":\"2.0\",\"id\":29,\"method\":\"ack\"}\n"); + runTransport(); + EXPECT_THAT(getOutput(), HasSubstr("\"id\":29")); + EXPECT_THAT(getOutput(), HasSubstr("\"error\"")); + EXPECT_THAT(getOutput(), HasSubstr("\"message\":\"method not found: ack\"")); +} + +TEST_F(TransportInputTest, OutgoingNotification) { + auto notifyFn = getMessageHandler().outgoingNotification( + "outgoing-notification"); + notifyFn(CompletionList{}); + EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-notification\"")); +} +} // namespace