Skip to content

Commit 1ad6332

Browse files
committed
[mlir-lsp] Support outgoing requests
Add support for outgoing requests to `lsp::MessageHandler`. Much like `MessageHandler::outgoingNotification`, this allows for the message handler to send outgoing messages via its JSON transport, but in this case, those messages are requests, not notifications. Requests receive responses (also referred to as "replies" in `MLIRLspServerSupportLib`). These were previously unsupported, and `lsp::MessageHandler` would log an error each time it processed a JSON message that appeared to be a response (something with an "id" field, but no "method" field). However, the `outgoingRequest` method now handles response callbacks: an outgoing request with a given ID is set up such that a callback function is invoked when a response with that ID is received.
1 parent b77416e commit 1ad6332

File tree

3 files changed

+87
-15
lines changed

3 files changed

+87
-15
lines changed

mlir/include/mlir/Tools/lsp-server-support/Transport.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#ifndef MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H
1616
#define MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H
1717

18+
#include "mlir/Support/DebugStringHelper.h"
1819
#include "mlir/Support/LLVM.h"
1920
#include "mlir/Support/LogicalResult.h"
2021
#include "mlir/Tools/lsp-server-support/Logging.h"
@@ -100,6 +101,18 @@ using Callback = llvm::unique_function<void(llvm::Expected<T>)>;
100101
template <typename T>
101102
using OutgoingMessage = llvm::unique_function<void(const T &)>;
102103

104+
/// An OutgoingRequest<T> is a function used for outgoing requests to send to
105+
/// the client.
106+
template <typename T>
107+
using OutgoingRequest =
108+
llvm::unique_function<void(const T &, llvm::json::Value id)>;
109+
110+
/// An `OutgoingRequestCallback` is invoked when an outgoing request to the
111+
/// client receives a response in turn. It is passed the original request's ID,
112+
/// as well as the result JSON.
113+
using OutgoingRequestCallback =
114+
std::function<void(llvm::json::Value, llvm::Expected<llvm::json::Value>)>;
115+
103116
/// A handler used to process the incoming transport messages.
104117
class MessageHandler {
105118
public:
@@ -171,6 +184,26 @@ class MessageHandler {
171184
};
172185
}
173186

187+
/// Create an OutgoingRequest function that, when called, sends a request with
188+
/// the given method via the transport. Should the outgoing request be
189+
/// met with a response, the response callback is invoked to handle that
190+
/// response.
191+
template <typename T>
192+
OutgoingRequest<T> outgoingRequest(llvm::StringLiteral method,
193+
OutgoingRequestCallback callback) {
194+
return [&, method](const T &params, llvm::json::Value id) {
195+
{
196+
std::lock_guard<std::mutex> lock(responseHandlersMutex);
197+
responseHandlers.insert(
198+
{debugString(id), std::make_pair(method.str(), callback)});
199+
}
200+
201+
std::lock_guard<std::mutex> transportLock(transportOutputMutex);
202+
Logger::info("--> {0}({1})", method, id);
203+
transport.call(method, llvm::json::Value(params), id);
204+
};
205+
}
206+
174207
private:
175208
template <typename HandlerT>
176209
using HandlerMap = llvm::StringMap<llvm::unique_function<HandlerT>>;
@@ -179,6 +212,14 @@ class MessageHandler {
179212
HandlerMap<void(llvm::json::Value, Callback<llvm::json::Value>)>
180213
methodHandlers;
181214

215+
/// A pair of (1) the original request's method name, and (2) the callback
216+
/// function to be invoked for responses.
217+
using ResponseHandlerTy = std::pair<std::string, OutgoingRequestCallback>;
218+
/// A mapping from request/response ID to response handler.
219+
llvm::StringMap<ResponseHandlerTy> responseHandlers;
220+
/// Mutex to guard insertion into the response handler map.
221+
std::mutex responseHandlersMutex;
222+
182223
JSONTransport &transport;
183224

184225
/// Mutex to guard sending output messages to the transport.

mlir/lib/Tools/lsp-server-support/Transport.cpp

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -117,21 +117,17 @@ bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params,
117117

118118
bool MessageHandler::onReply(llvm::json::Value id,
119119
llvm::Expected<llvm::json::Value> result) {
120-
// TODO: Add support for reply callbacks when support for outgoing messages is
121-
// added. For now, we just log an error on any replies received.
122-
Callback<llvm::json::Value> replyHandler =
123-
[&id](llvm::Expected<llvm::json::Value> result) {
124-
Logger::error(
125-
"received a reply with ID {0}, but there was no such call", id);
126-
if (!result)
127-
llvm::consumeError(result.takeError());
128-
};
129-
130-
// Log and run the reply handler.
131-
if (result)
132-
replyHandler(std::move(result));
133-
else
134-
replyHandler(result.takeError());
120+
auto it = responseHandlers.find(debugString(id));
121+
if (it != responseHandlers.end()) {
122+
Logger::info("--> reply:{0}({1})", it->second.first, id);
123+
it->second.second(std::move(id), std::move(result));
124+
} else {
125+
Logger::error(
126+
"received a reply with ID {0}, but there was no such outgoing request",
127+
id);
128+
if (!result)
129+
llvm::consumeError(result.takeError());
130+
}
135131
return true;
136132
}
137133

mlir/unittests/Tools/lsp-server-support/Transport.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,4 +125,39 @@ TEST_F(TransportInputTest, OutgoingNotification) {
125125
notifyFn(CompletionList{});
126126
EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-notification\""));
127127
}
128+
129+
TEST_F(TransportInputTest, ResponseHandlerNotFound) {
130+
// Unhandled responses are only reported via error logging. As a result, this
131+
// test can't make any expectations -- but it prints the output anyway, by way
132+
// of demonstration.
133+
Logger::setLogLevel(Logger::Level::Error);
134+
writeInput("{\"jsonrpc\":\"2.0\",\"id\":81,\"result\":null}\n");
135+
runTransport();
136+
}
137+
138+
TEST_F(TransportInputTest, OutgoingRequest) {
139+
// Make some outgoing requests.
140+
bool responseCallbackInvoked = false;
141+
auto callFn = getMessageHandler().outgoingRequest<CompletionList>(
142+
"outgoing-request",
143+
[&responseCallbackInvoked](llvm::json::Value id,
144+
llvm::Expected<llvm::json::Value> value) {
145+
// Make expectations on the expected response.
146+
EXPECT_EQ(id, 83);
147+
ASSERT_TRUE((bool)value);
148+
EXPECT_EQ(debugString(*value), "{\"foo\":6}");
149+
responseCallbackInvoked = true;
150+
});
151+
callFn({}, 82);
152+
callFn({}, 83);
153+
callFn({}, 84);
154+
EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-request\""));
155+
EXPECT_FALSE(responseCallbackInvoked);
156+
157+
// One of the requests receives a response. The message handler handles this
158+
// response by invoking the callback from above.
159+
writeInput("{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"foo\":6}}\n");
160+
runTransport();
161+
EXPECT_TRUE(responseCallbackInvoked);
162+
}
128163
} // namespace

0 commit comments

Comments
 (0)