diff --git a/mlir/include/mlir/Tools/lsp-server-support/Transport.h b/mlir/include/mlir/Tools/lsp-server-support/Transport.h
index ce742be7a941c..44c71058cf717 100644
--- a/mlir/include/mlir/Tools/lsp-server-support/Transport.h
+++ b/mlir/include/mlir/Tools/lsp-server-support/Transport.h
@@ -147,9 +147,15 @@ class MessageHandler {
void (ThisT::*handler)(const Param &)) {
notificationHandlers[method] = [method, handler,
thisPtr](llvm::json::Value rawParams) {
- llvm::Expected param = parse(rawParams, method, "request");
- if (!param)
- return llvm::consumeError(param.takeError());
+ llvm::Expected param =
+ parse(rawParams, method, "notification");
+ if (!param) {
+ return llvm::consumeError(
+ llvm::handleErrors(param.takeError(), [](const LSPError &lspError) {
+ Logger::error("JSON parsing error: {0}",
+ lspError.message.c_str());
+ }));
+ }
(thisPtr->*handler)(*param);
};
}
diff --git a/mlir/lib/Tools/lsp-server-support/Transport.cpp b/mlir/lib/Tools/lsp-server-support/Transport.cpp
index 64dea35614c07..339c5f3825165 100644
--- a/mlir/lib/Tools/lsp-server-support/Transport.cpp
+++ b/mlir/lib/Tools/lsp-server-support/Transport.cpp
@@ -51,12 +51,12 @@ class Reply {
Reply::Reply(const llvm::json::Value &id, llvm::StringRef method,
JSONTransport &transport, std::mutex &transportOutputMutex)
- : id(id), transport(&transport),
+ : method(method), id(id), transport(&transport),
transportOutputMutex(transportOutputMutex) {}
Reply::Reply(Reply &&other)
- : replied(other.replied.load()), id(std::move(other.id)),
- transport(other.transport),
+ : method(other.method), replied(other.replied.load()),
+ id(std::move(other.id)), transport(other.transport),
transportOutputMutex(other.transportOutputMutex) {
other.transport = nullptr;
}
diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt
index 6fad249a0b2fb..6d8aa290e82f2 100644
--- a/mlir/unittests/CMakeLists.txt
+++ b/mlir/unittests/CMakeLists.txt
@@ -20,6 +20,7 @@ add_subdirectory(Support)
add_subdirectory(Rewrite)
add_subdirectory(TableGen)
add_subdirectory(Target)
+add_subdirectory(Tools)
add_subdirectory(Transforms)
if(MLIR_ENABLE_EXECUTION_ENGINE)
diff --git a/mlir/unittests/Tools/CMakeLists.txt b/mlir/unittests/Tools/CMakeLists.txt
new file mode 100644
index 0000000000000..a97588d928668
--- /dev/null
+++ b/mlir/unittests/Tools/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(lsp-server-support)
diff --git a/mlir/unittests/Tools/lsp-server-support/CMakeLists.txt b/mlir/unittests/Tools/lsp-server-support/CMakeLists.txt
new file mode 100644
index 0000000000000..3aa8b9c4bc773
--- /dev/null
+++ b/mlir/unittests/Tools/lsp-server-support/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_unittest(MLIRLspServerSupportTests
+ Transport.cpp
+)
+target_link_libraries(MLIRLspServerSupportTests
+ PRIVATE
+ MLIRLspServerSupportLib)
diff --git a/mlir/unittests/Tools/lsp-server-support/Transport.cpp b/mlir/unittests/Tools/lsp-server-support/Transport.cpp
new file mode 100644
index 0000000000000..a086964cd3660
--- /dev/null
+++ b/mlir/unittests/Tools/lsp-server-support/Transport.cpp
@@ -0,0 +1,134 @@
+//===- Transport.cpp - LSP JSON transport unit tests ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Tools/lsp-server-support/Transport.h"
+#include "mlir/Tools/lsp-server-support/Logging.h"
+#include "mlir/Tools/lsp-server-support/Protocol.h"
+#include "llvm/Support/FileSystem.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::lsp;
+using namespace testing;
+
+namespace {
+
+TEST(TransportTest, SendReply) {
+ std::string out;
+ llvm::raw_string_ostream os(out);
+ JSONTransport transport(nullptr, os);
+ MessageHandler handler(transport);
+
+ transport.reply(1989, nullptr);
+ EXPECT_THAT(out, HasSubstr("\"id\":1989"));
+ EXPECT_THAT(out, HasSubstr("\"result\":null"));
+}
+
+class TransportInputTest : public Test {
+ llvm::SmallVector inputPath;
+ std::FILE *in = nullptr;
+ std::string output = "";
+ llvm::raw_string_ostream os;
+ std::optional transport = std::nullopt;
+ std::optional messageHandler = std::nullopt;
+
+protected:
+ TransportInputTest() : os(output) {}
+
+ void SetUp() override {
+ std::error_code ec =
+ llvm::sys::fs::createTemporaryFile("lsp-unittest", "json", inputPath);
+ ASSERT_FALSE(ec) << "Could not create temporary file: " << ec.message();
+
+ in = std::fopen(inputPath.data(), "r");
+ ASSERT_TRUE(in) << "Could not open temporary file: "
+ << std::strerror(errno);
+ transport.emplace(in, os, JSONStreamStyle::Delimited);
+ messageHandler.emplace(*transport);
+ }
+
+ void TearDown() override {
+ EXPECT_EQ(std::fclose(in), 0)
+ << "Could not close temporary file FD: " << std::strerror(errno);
+ std::error_code ec =
+ llvm::sys::fs::remove(inputPath, /*IgnoreNonExisting=*/false);
+ EXPECT_FALSE(ec) << "Could not remove temporary file '" << inputPath.data()
+ << "': " << ec.message();
+ }
+
+ void writeInput(StringRef buffer) {
+ std::error_code ec;
+ llvm::raw_fd_ostream os(inputPath.data(), ec);
+ ASSERT_FALSE(ec) << "Could not write to '" << inputPath.data()
+ << "': " << ec.message();
+ os << buffer;
+ os.close();
+ }
+
+ StringRef getOutput() const { return output; }
+ MessageHandler &getMessageHandler() { return *messageHandler; }
+
+ void runTransport() {
+ bool gotEOF = false;
+ llvm::Error err = llvm::handleErrors(
+ transport->run(*messageHandler), [&](const llvm::ECError &ecErr) {
+ gotEOF = ecErr.convertToErrorCode() == std::errc::io_error;
+ });
+ llvm::consumeError(std::move(err));
+ EXPECT_TRUE(gotEOF);
+ }
+};
+
+TEST_F(TransportInputTest, RequestWithInvalidParams) {
+ struct Handler {
+ void onMethod(const TextDocumentItem ¶ms,
+ mlir::lsp::Callback callback) {}
+ } handler;
+ getMessageHandler().method("invalid-params-request", &handler,
+ &Handler::onMethod);
+
+ writeInput("{\"jsonrpc\":\"2.0\",\"id\":92,"
+ "\"method\":\"invalid-params-request\",\"params\":{}}\n");
+ runTransport();
+ EXPECT_THAT(getOutput(), HasSubstr("error"));
+ EXPECT_THAT(getOutput(), HasSubstr("missing value at (root).uri"));
+}
+
+TEST_F(TransportInputTest, NotificationWithInvalidParams) {
+ // JSON parsing errors are only reported via error logging. As a result, this
+ // test can't make any expectations -- but it prints the output anyway, by way
+ // of demonstration.
+ Logger::setLogLevel(Logger::Level::Error);
+
+ struct Handler {
+ void onNotification(const TextDocumentItem ¶ms) {}
+ } handler;
+ getMessageHandler().notification("invalid-params-notification", &handler,
+ &Handler::onNotification);
+
+ writeInput("{\"jsonrpc\":\"2.0\",\"method\":\"invalid-params-notification\","
+ "\"params\":{}}\n");
+ runTransport();
+}
+
+TEST_F(TransportInputTest, MethodNotFound) {
+ writeInput("{\"jsonrpc\":\"2.0\",\"id\":29,\"method\":\"ack\"}\n");
+ runTransport();
+ EXPECT_THAT(getOutput(), HasSubstr("\"id\":29"));
+ EXPECT_THAT(getOutput(), HasSubstr("\"error\""));
+ EXPECT_THAT(getOutput(), HasSubstr("\"message\":\"method not found: ack\""));
+}
+
+TEST_F(TransportInputTest, OutgoingNotification) {
+ auto notifyFn = getMessageHandler().outgoingNotification(
+ "outgoing-notification");
+ notifyFn(CompletionList{});
+ EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-notification\""));
+}
+} // namespace