Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions crates/rmcp/src/handler/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ impl<H: ClientHandler> Service<RoleClient> for H {
.create_elicitation(request.params, context)
.await
.map(ClientResult::CreateElicitationResult),
ServerRequest::CustomRequest(request) => self
.on_custom_request(request, context)
.await
.map(ClientResult::CustomResult),
}
}

Expand Down Expand Up @@ -123,6 +127,20 @@ pub trait ClientHandler: Sized + Send + Sync + 'static {
}))
}

fn on_custom_request(
&self,
request: CustomRequest,
context: RequestContext<RoleClient>,
) -> impl Future<Output = Result<CustomResult, McpError>> + Send + '_ {
let CustomRequest { method, .. } = request;
let _ = context;
std::future::ready(Err(McpError::new(
ErrorCode::METHOD_NOT_FOUND,
method,
None,
)))
}

fn on_cancelled(
&self,
params: CancelledNotificationParam,
Expand Down
17 changes: 17 additions & 0 deletions crates/rmcp/src/handler/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ impl<H: ServerHandler> Service<RoleServer> for H {
.list_tools(request.params, context)
.await
.map(ServerResult::ListToolsResult),
ClientRequest::CustomRequest(request) => self
.on_custom_request(request, context)
.await
.map(ServerResult::CustomResult),
}
}

Expand Down Expand Up @@ -200,6 +204,19 @@ pub trait ServerHandler: Sized + Send + Sync + 'static {
) -> impl Future<Output = Result<ListToolsResult, McpError>> + Send + '_ {
std::future::ready(Ok(ListToolsResult::default()))
}
fn on_custom_request(
&self,
request: CustomRequest,
context: RequestContext<RoleServer>,
) -> impl Future<Output = Result<CustomResult, McpError>> + Send + '_ {
let CustomRequest { method, .. } = request;
let _ = context;
std::future::ready(Err(McpError::new(
ErrorCode::METHOD_NOT_FOUND,
method,
None,
)))
}

fn on_cancelled(
&self,
Expand Down
103 changes: 99 additions & 4 deletions crates/rmcp/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ macro_rules! object {
///
/// without returning any specific data.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Copy, Eq)]
#[serde(deny_unknown_fields)]
#[cfg_attr(feature = "server", derive(schemars::JsonSchema))]
pub struct EmptyObject {}

Expand Down Expand Up @@ -606,6 +607,23 @@ impl From<EmptyResult> for () {
fn from(_value: EmptyResult) {}
}

/// A catch-all response either side can use for custom requests.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(transparent)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct CustomResult(pub Value);

impl CustomResult {
pub fn new(result: Value) -> Self {
Self(result)
}

/// Deserialize the result into a strongly-typed structure.
pub fn result_as<T: DeserializeOwned>(&self) -> Result<T, serde_json::Error> {
serde_json::from_value(self.0.clone())
}
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
Expand Down Expand Up @@ -661,6 +679,40 @@ impl CustomNotification {
}
}

/// A catch-all request either side can use to send custom messages to its peer.
///
/// This preserves the raw `method` name and `params` payload so handlers can
/// deserialize them into domain-specific types.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct CustomRequest {
pub method: String,
pub params: Option<Value>,
/// extensions will carry anything possible in the context, including [`Meta`]
///
/// this is similar with the Extensions in `http` crate
#[cfg_attr(feature = "schemars", schemars(skip))]
pub extensions: Extensions,
}

impl CustomRequest {
pub fn new(method: impl Into<String>, params: Option<Value>) -> Self {
Self {
method: method.into(),
params,
extensions: Extensions::default(),
}
}

/// Deserialize `params` into a strongly-typed structure.
pub fn params_as<T: DeserializeOwned>(&self) -> Result<Option<T>, serde_json::Error> {
self.params
.as_ref()
.map(|params| serde_json::from_value(params.clone()))
.transpose()
}
}

const_string!(InitializeResultMethod = "initialize");
/// # Initialization
/// This request is sent from the client to the server when it first connects, asking it to begin initialization.
Expand Down Expand Up @@ -1757,11 +1809,12 @@ ts_union!(
| SubscribeRequest
| UnsubscribeRequest
| CallToolRequest
| ListToolsRequest;
| ListToolsRequest
| CustomRequest;
);

impl ClientRequest {
pub fn method(&self) -> &'static str {
pub fn method(&self) -> &str {
match &self {
ClientRequest::PingRequest(r) => r.method.as_str(),
ClientRequest::InitializeRequest(r) => r.method.as_str(),
Expand All @@ -1776,6 +1829,7 @@ impl ClientRequest {
ClientRequest::UnsubscribeRequest(r) => r.method.as_str(),
ClientRequest::CallToolRequest(r) => r.method.as_str(),
ClientRequest::ListToolsRequest(r) => r.method.as_str(),
ClientRequest::CustomRequest(r) => r.method.as_str(),
}
}
}
Expand All @@ -1790,7 +1844,12 @@ ts_union!(
);

ts_union!(
export type ClientResult = box CreateMessageResult | ListRootsResult | CreateElicitationResult | EmptyResult;
export type ClientResult =
box CreateMessageResult
| ListRootsResult
| CreateElicitationResult
| EmptyResult
| CustomResult;
);

impl ClientResult {
Expand All @@ -1806,7 +1865,8 @@ ts_union!(
| PingRequest
| CreateMessageRequest
| ListRootsRequest
| CreateElicitationRequest;
| CreateElicitationRequest
| CustomRequest;
);

ts_union!(
Expand Down Expand Up @@ -1834,6 +1894,7 @@ ts_union!(
| ListToolsResult
| CreateElicitationResult
| EmptyResult
| CustomResult
;
);

Expand Down Expand Up @@ -1960,6 +2021,40 @@ mod tests {
assert_eq!(json, raw);
}

#[test]
fn test_custom_request_roundtrip() {
let raw = json!( {
"jsonrpc": JsonRpcVersion2_0,
"id": 42,
"method": "requests/custom",
"params": {"foo": "bar"},
});

let message: ClientJsonRpcMessage =
serde_json::from_value(raw.clone()).expect("invalid request");
match &message {
ClientJsonRpcMessage::Request(JsonRpcRequest { id, request, .. }) => {
assert_eq!(id, &RequestId::Number(42));
match request {
ClientRequest::CustomRequest(custom) => {
let expected_request = json!({
"method": "requests/custom",
"params": {"foo": "bar"},
});
let actual_request =
serde_json::to_value(custom).expect("serialize custom request");
assert_eq!(actual_request, expected_request);
}
other => panic!("Expected custom request, got: {other:?}"),
}
}
other => panic!("Expected request, got: {other:?}"),
}

let json = serde_json::to_value(message).expect("valid json");
assert_eq!(json, raw);
}

#[test]
fn test_request_conversion() {
let raw = json!( {
Expand Down
26 changes: 24 additions & 2 deletions crates/rmcp/src/model/meta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;

use super::{
ClientNotification, ClientRequest, CustomNotification, Extensions, JsonObject, JsonRpcMessage,
NumberOrString, ProgressToken, ServerNotification, ServerRequest,
ClientNotification, ClientRequest, CustomNotification, CustomRequest, Extensions, JsonObject,
JsonRpcMessage, NumberOrString, ProgressToken, ServerNotification, ServerRequest,
};

pub trait GetMeta {
Expand Down Expand Up @@ -38,6 +38,26 @@ impl GetMeta for CustomNotification {
}
}

impl GetExtensions for CustomRequest {
fn extensions(&self) -> &Extensions {
&self.extensions
}
fn extensions_mut(&mut self) -> &mut Extensions {
&mut self.extensions
}
}

impl GetMeta for CustomRequest {
fn get_meta_mut(&mut self) -> &mut Meta {
self.extensions_mut().get_or_insert_default()
}
fn get_meta(&self) -> &Meta {
self.extensions()
.get::<Meta>()
.unwrap_or(Meta::static_empty())
}
}

macro_rules! variant_extension {
(
$Enum: ident {
Expand Down Expand Up @@ -86,6 +106,7 @@ variant_extension! {
UnsubscribeRequest
CallToolRequest
ListToolsRequest
CustomRequest
}
}

Expand All @@ -95,6 +116,7 @@ variant_extension! {
CreateMessageRequest
ListRootsRequest
CreateElicitationRequest
CustomRequest
}
}

Expand Down
57 changes: 55 additions & 2 deletions crates/rmcp/src/model/serde_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::borrow::Cow;
use serde::{Deserialize, Serialize};

use super::{
CustomNotification, Extensions, Meta, Notification, NotificationNoParam, Request,
RequestNoParam, RequestOptionalParam,
CustomNotification, CustomRequest, Extensions, Meta, Notification, NotificationNoParam,
Request, RequestNoParam, RequestOptionalParam,
};
#[derive(Serialize, Deserialize)]
struct WithMeta<'a, P> {
Expand Down Expand Up @@ -249,6 +249,59 @@ where
}
}

impl Serialize for CustomRequest {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let extensions = &self.extensions;
let _meta = extensions.get::<Meta>().map(Cow::Borrowed);
let params = self.params.as_ref();

let params = if _meta.is_some() || params.is_some() {
Some(WithMeta {
_meta,
_rest: &self.params,
})
} else {
None
};

ProxyOptionalParam::serialize(
&ProxyOptionalParam {
method: &self.method,
params,
},
serializer,
)
}
}

impl<'de> Deserialize<'de> for CustomRequest {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let body =
ProxyOptionalParam::<'_, _, Option<serde_json::Value>>::deserialize(deserializer)?;
let mut params = None;
let mut _meta = None;
if let Some(body_params) = body.params {
params = body_params._rest;
_meta = body_params._meta.map(|m| m.into_owned());
}
let mut extensions = Extensions::new();
if let Some(meta) = _meta {
extensions.insert(meta);
}
Ok(CustomRequest {
extensions,
method: body.method,
params,
})
}
}

impl Serialize for CustomNotification {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
Expand Down
Loading