Skip to content

[mlir-lsp] Support outgoing requests #90078

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions mlir/include/mlir/Tools/lsp-server-support/Transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#ifndef MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H
#define MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H

#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Tools/lsp-server-support/Logging.h"
Expand Down Expand Up @@ -100,6 +101,18 @@ using Callback = llvm::unique_function<void(llvm::Expected<T>)>;
template <typename T>
using OutgoingNotification = llvm::unique_function<void(const T &)>;

/// An OutgoingRequest<T> is a function used for outgoing requests to send to
/// the client.
template <typename T>
using OutgoingRequest =
llvm::unique_function<void(const T &, llvm::json::Value id)>;

/// An `OutgoingRequestCallback` is invoked when an outgoing request to the
/// client receives a response in turn. It is passed the original request's ID,
/// as well as the result JSON.
using OutgoingRequestCallback =
std::function<void(llvm::json::Value, llvm::Expected<llvm::json::Value>)>;

/// A handler used to process the incoming transport messages.
class MessageHandler {
public:
Expand Down Expand Up @@ -170,6 +183,26 @@ class MessageHandler {
};
}

/// Create an OutgoingRequest function that, when called, sends a request with
/// the given method via the transport. Should the outgoing request be
/// met with a response, the response callback is invoked to handle that
/// response.
template <typename T>
OutgoingRequest<T> outgoingRequest(llvm::StringLiteral method,
OutgoingRequestCallback callback) {
return [&, method, callback](const T &params, llvm::json::Value id) {
{
std::lock_guard<std::mutex> lock(responseHandlersMutex);
responseHandlers.insert(
{debugString(id), std::make_pair(method.str(), callback)});
}

std::lock_guard<std::mutex> transportLock(transportOutputMutex);
Logger::info("--> {0}({1})", method, id);
transport.call(method, llvm::json::Value(params), id);
};
}

private:
template <typename HandlerT>
using HandlerMap = llvm::StringMap<llvm::unique_function<HandlerT>>;
Expand All @@ -178,6 +211,14 @@ class MessageHandler {
HandlerMap<void(llvm::json::Value, Callback<llvm::json::Value>)>
methodHandlers;

/// A pair of (1) the original request's method name, and (2) the callback
/// function to be invoked for responses.
using ResponseHandlerTy = std::pair<std::string, OutgoingRequestCallback>;
/// A mapping from request/response ID to response handler.
llvm::StringMap<ResponseHandlerTy> responseHandlers;
/// Mutex to guard insertion into the response handler map.
std::mutex responseHandlersMutex;

JSONTransport &transport;

/// Mutex to guard sending output messages to the transport.
Expand Down
38 changes: 23 additions & 15 deletions mlir/lib/Tools/lsp-server-support/Transport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,21 +117,29 @@ bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params,

bool MessageHandler::onReply(llvm::json::Value id,
llvm::Expected<llvm::json::Value> result) {
// TODO: Add support for reply callbacks when support for outgoing messages is
// added. For now, we just log an error on any replies received.
Callback<llvm::json::Value> replyHandler =
[&id](llvm::Expected<llvm::json::Value> result) {
Logger::error(
"received a reply with ID {0}, but there was no such call", id);
if (!result)
llvm::consumeError(result.takeError());
};

// Log and run the reply handler.
if (result)
replyHandler(std::move(result));
else
replyHandler(result.takeError());
// Find the response handler in the mapping. If it exists, move it out of the
// mapping and erase it.
ResponseHandlerTy responseHandler;
{
std::lock_guard<std::mutex> responseHandlersLock(responseHandlersMutex);
auto it = responseHandlers.find(debugString(id));
if (it != responseHandlers.end()) {
responseHandler = std::move(it->second);
responseHandlers.erase(it);
}
}

// If we found a response handler, invoke it. Otherwise, log an error.
if (responseHandler.second) {
Logger::info("--> reply:{0}({1})", responseHandler.first, id);
responseHandler.second(std::move(id), std::move(result));
} else {
Logger::error(
"received a reply with ID {0}, but there was no such outgoing request",
id);
if (!result)
llvm::consumeError(result.takeError());
}
return true;
}

Expand Down
39 changes: 39 additions & 0 deletions mlir/unittests/Tools/lsp-server-support/Transport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,43 @@ TEST_F(TransportInputTest, OutgoingNotification) {
notifyFn(CompletionList{});
EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-notification\""));
}

TEST_F(TransportInputTest, ResponseHandlerNotFound) {
// Unhandled responses are only reported via error logging. As a result, this
// test can't make any expectations -- but it prints the output anyway, by way
// of demonstration.
Logger::setLogLevel(Logger::Level::Error);
writeInput("{\"jsonrpc\":\"2.0\",\"id\":81,\"result\":null}\n");
runTransport();
}

TEST_F(TransportInputTest, OutgoingRequest) {
// Make some outgoing requests.
int responseCallbackInvoked = 0;
auto callFn = getMessageHandler().outgoingRequest<CompletionList>(
"outgoing-request",
[&responseCallbackInvoked](llvm::json::Value id,
llvm::Expected<llvm::json::Value> value) {
// Make expectations on the expected response.
EXPECT_EQ(id, 83);
ASSERT_TRUE((bool)value);
EXPECT_EQ(debugString(*value), "{\"foo\":6}");
responseCallbackInvoked += 1;
llvm::outs() << "here!!!\n";
});
callFn({}, 82);
callFn({}, 83);
callFn({}, 84);
EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-request\""));
EXPECT_EQ(responseCallbackInvoked, 0);

// One of the requests receives a response. The message handler handles this
// response by invoking the callback from above. Subsequent responses with the
// same ID are ignored.
writeInput("{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"foo\":6}}\n"
"// -----\n"
"{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"bar\":8}}\n");
runTransport();
EXPECT_EQ(responseCallbackInvoked, 1);
}
} // namespace
Loading