Skip to content

Commit 79df4ea

Browse files
committed
feat: add support for custom server notifications
#556 introduced support for custom client notifications, so this PR makes the complementary change, adding support for custom server notifications. MCP clients, particularly ones that offer "experimental" capabilities, may wish to handle custom server notifications that are not part of the standard MCP specification. This change introduces a new `CustomServerNotification` type that allows a client to process such custom notifications. - introduces `CustomServerNotification` to carry arbitrary methods/params while still preserving meta/extensions; wires it into the `ServerNotification` union and `serde` so `params` can be decoded with `params_as` - allows client handlers to receive custom notifications via a new `on_custom_notification` hook - adds integration coverage that sends a custom server notification end-to-end and asserts the client sees the method and payload Test: ```shell cargo test -p rmcp --features client test_custom_server_notification_reaches_client ```
1 parent f20ed20 commit 79df4ea

File tree

7 files changed

+266
-7
lines changed

7 files changed

+266
-7
lines changed

crates/rmcp/src/handler/client.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ impl<H: ClientHandler> Service<RoleClient> for H {
5656
ServerNotification::PromptListChangedNotification(_notification_no_param) => {
5757
self.on_prompt_list_changed(context).await
5858
}
59+
ServerNotification::CustomServerNotification(notification) => {
60+
self.on_custom_notification(notification, context).await
61+
}
5962
};
6063
Ok(())
6164
}
@@ -166,6 +169,14 @@ pub trait ClientHandler: Sized + Send + Sync + 'static {
166169
) -> impl Future<Output = ()> + Send + '_ {
167170
std::future::ready(())
168171
}
172+
fn on_custom_notification(
173+
&self,
174+
notification: CustomServerNotification,
175+
context: NotificationContext<RoleClient>,
176+
) -> impl Future<Output = ()> + Send + '_ {
177+
let _ = (notification, context);
178+
std::future::ready(())
179+
}
169180

170181
fn get_info(&self) -> ClientInfo {
171182
ClientInfo::default()

crates/rmcp/src/model.rs

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,40 @@ impl CustomClientNotification {
661661
}
662662
}
663663

664+
/// A catch-all notification the server can use to send custom messages to a client.
665+
///
666+
/// This preserves the raw `method` name and `params` payload so handlers can
667+
/// deserialize them into domain-specific types.
668+
#[derive(Debug, Clone)]
669+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
670+
pub struct CustomServerNotification {
671+
pub method: String,
672+
pub params: Option<Value>,
673+
/// extensions will carry anything possible in the context, including [`Meta`]
674+
///
675+
/// this is similar with the Extensions in `http` crate
676+
#[cfg_attr(feature = "schemars", schemars(skip))]
677+
pub extensions: Extensions,
678+
}
679+
680+
impl CustomServerNotification {
681+
pub fn new(method: impl Into<String>, params: Option<Value>) -> Self {
682+
Self {
683+
method: method.into(),
684+
params,
685+
extensions: Extensions::default(),
686+
}
687+
}
688+
689+
/// Deserialize `params` into a strongly-typed structure.
690+
pub fn params_as<T: DeserializeOwned>(&self) -> Result<Option<T>, serde_json::Error> {
691+
self.params
692+
.as_ref()
693+
.map(|params| serde_json::from_value(params.clone()))
694+
.transpose()
695+
}
696+
}
697+
664698
const_string!(InitializeResultMethod = "initialize");
665699
/// # Initialization
666700
/// This request is sent from the client to the server when it first connects, asking it to begin initialization.
@@ -1817,7 +1851,8 @@ ts_union!(
18171851
| ResourceUpdatedNotification
18181852
| ResourceListChangedNotification
18191853
| ToolListChangedNotification
1820-
| PromptListChangedNotification;
1854+
| PromptListChangedNotification
1855+
| CustomServerNotification;
18211856
);
18221857

18231858
ts_union!(
@@ -1927,6 +1962,38 @@ mod tests {
19271962
assert_eq!(json, raw);
19281963
}
19291964

1965+
#[test]
1966+
fn test_custom_server_notification_roundtrip() {
1967+
let raw = json!( {
1968+
"jsonrpc": JsonRpcVersion2_0,
1969+
"method": "notifications/custom-server",
1970+
"params": {"hello": "world"},
1971+
});
1972+
1973+
let message: ServerJsonRpcMessage =
1974+
serde_json::from_value(raw.clone()).expect("invalid notification");
1975+
match &message {
1976+
ServerJsonRpcMessage::Notification(JsonRpcNotification {
1977+
notification: ServerNotification::CustomServerNotification(notification),
1978+
..
1979+
}) => {
1980+
assert_eq!(notification.method, "notifications/custom-server");
1981+
assert_eq!(
1982+
notification
1983+
.params
1984+
.as_ref()
1985+
.and_then(|p| p.get("hello"))
1986+
.expect("hello present"),
1987+
"world"
1988+
);
1989+
}
1990+
_ => panic!("Expected custom server notification"),
1991+
}
1992+
1993+
let json = serde_json::to_value(message).expect("valid json");
1994+
assert_eq!(json, raw);
1995+
}
1996+
19301997
#[test]
19311998
fn test_request_conversion() {
19321999
let raw = json!( {

crates/rmcp/src/model/meta.rs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ use serde::{Deserialize, Serialize};
44
use serde_json::Value;
55

66
use super::{
7-
ClientNotification, ClientRequest, CustomClientNotification, Extensions, JsonObject,
8-
JsonRpcMessage, NumberOrString, ProgressToken, ServerNotification, ServerRequest,
7+
ClientNotification, ClientRequest, CustomClientNotification, CustomServerNotification,
8+
Extensions, JsonObject, JsonRpcMessage, NumberOrString, ProgressToken, ServerNotification,
9+
ServerRequest,
910
};
1011

1112
pub trait GetMeta {
@@ -38,6 +39,26 @@ impl GetMeta for CustomClientNotification {
3839
}
3940
}
4041

42+
impl GetExtensions for CustomServerNotification {
43+
fn extensions(&self) -> &Extensions {
44+
&self.extensions
45+
}
46+
fn extensions_mut(&mut self) -> &mut Extensions {
47+
&mut self.extensions
48+
}
49+
}
50+
51+
impl GetMeta for CustomServerNotification {
52+
fn get_meta_mut(&mut self) -> &mut Meta {
53+
self.extensions_mut().get_or_insert_default()
54+
}
55+
fn get_meta(&self) -> &Meta {
56+
self.extensions()
57+
.get::<Meta>()
58+
.unwrap_or(Meta::static_empty())
59+
}
60+
}
61+
4162
macro_rules! variant_extension {
4263
(
4364
$Enum: ident {
@@ -117,6 +138,7 @@ variant_extension! {
117138
ResourceListChangedNotification
118139
ToolListChangedNotification
119140
PromptListChangedNotification
141+
CustomServerNotification
120142
}
121143
}
122144
#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]

crates/rmcp/src/model/serde_impl.rs

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ use std::borrow::Cow;
33
use serde::{Deserialize, Serialize};
44

55
use super::{
6-
CustomClientNotification, Extensions, Meta, Notification, NotificationNoParam, Request,
7-
RequestNoParam, RequestOptionalParam,
6+
CustomClientNotification, CustomServerNotification, Extensions, Meta, Notification,
7+
NotificationNoParam, Request, RequestNoParam, RequestOptionalParam,
88
};
99
#[derive(Serialize, Deserialize)]
1010
struct WithMeta<'a, P> {
@@ -302,6 +302,59 @@ impl<'de> Deserialize<'de> for CustomClientNotification {
302302
}
303303
}
304304

305+
impl Serialize for CustomServerNotification {
306+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
307+
where
308+
S: serde::Serializer,
309+
{
310+
let extensions = &self.extensions;
311+
let _meta = extensions.get::<Meta>().map(Cow::Borrowed);
312+
let params = self.params.as_ref();
313+
314+
let params = if _meta.is_some() || params.is_some() {
315+
Some(WithMeta {
316+
_meta,
317+
_rest: &self.params,
318+
})
319+
} else {
320+
None
321+
};
322+
323+
ProxyOptionalParam::serialize(
324+
&ProxyOptionalParam {
325+
method: &self.method,
326+
params,
327+
},
328+
serializer,
329+
)
330+
}
331+
}
332+
333+
impl<'de> Deserialize<'de> for CustomServerNotification {
334+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
335+
where
336+
D: serde::Deserializer<'de>,
337+
{
338+
let body =
339+
ProxyOptionalParam::<'_, _, Option<serde_json::Value>>::deserialize(deserializer)?;
340+
let mut params = None;
341+
let mut _meta = None;
342+
if let Some(body_params) = body.params {
343+
params = body_params._rest;
344+
_meta = body_params._meta.map(|m| m.into_owned());
345+
}
346+
let mut extensions = Extensions::new();
347+
if let Some(meta) = _meta {
348+
extensions.insert(meta);
349+
}
350+
Ok(CustomServerNotification {
351+
extensions,
352+
method: body.method,
353+
params,
354+
})
355+
}
356+
}
357+
305358
#[cfg(test)]
306359
mod test {
307360
use serde_json::json;

crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,19 @@
392392
"content"
393393
]
394394
},
395+
"CustomServerNotification": {
396+
"description": "A catch-all notification the server can use to send custom messages to a client.\n\nThis preserves the raw `method` name and `params` payload so handlers can\ndeserialize them into domain-specific types.",
397+
"type": "object",
398+
"properties": {
399+
"method": {
400+
"type": "string"
401+
},
402+
"params": true
403+
},
404+
"required": [
405+
"method"
406+
]
407+
},
395408
"CancelledNotificationMethod": {
396409
"type": "string",
397410
"format": "const",
@@ -977,6 +990,9 @@
977990
},
978991
{
979992
"$ref": "#/definitions/NotificationNoParam3"
993+
},
994+
{
995+
"$ref": "#/definitions/CustomServerNotification"
980996
}
981997
],
982998
"required": [

crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema_current.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,19 @@
392392
"content"
393393
]
394394
},
395+
"CustomServerNotification": {
396+
"description": "A catch-all notification the server can use to send custom messages to a client.\n\nThis preserves the raw `method` name and `params` payload so handlers can\ndeserialize them into domain-specific types.",
397+
"type": "object",
398+
"properties": {
399+
"method": {
400+
"type": "string"
401+
},
402+
"params": true
403+
},
404+
"required": [
405+
"method"
406+
]
407+
},
395408
"CancelledNotificationMethod": {
396409
"type": "string",
397410
"format": "const",
@@ -977,6 +990,9 @@
977990
},
978991
{
979992
"$ref": "#/definitions/NotificationNoParam3"
993+
},
994+
{
995+
"$ref": "#/definitions/CustomServerNotification"
980996
}
981997
],
982998
"required": [

crates/rmcp/tests/test_notification.rs

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ use std::sync::Arc;
33
use rmcp::{
44
ClientHandler, ServerHandler, ServiceExt,
55
model::{
6-
ClientNotification, CustomClientNotification, ResourceUpdatedNotificationParam,
7-
ServerCapabilities, ServerInfo, SubscribeRequestParam,
6+
ClientNotification, CustomClientNotification, CustomServerNotification,
7+
ResourceUpdatedNotificationParam, ServerCapabilities, ServerInfo, ServerNotification,
8+
SubscribeRequestParam,
89
},
910
};
1011
use serde_json::json;
@@ -165,3 +166,76 @@ async fn test_custom_client_notification_reaches_server() -> anyhow::Result<()>
165166
client.cancel().await?;
166167
Ok(())
167168
}
169+
170+
struct CustomServerNotifier;
171+
172+
impl ServerHandler for CustomServerNotifier {
173+
async fn on_initialized(&self, context: rmcp::service::NotificationContext<rmcp::RoleServer>) {
174+
let peer = context.peer.clone();
175+
tokio::spawn(async move {
176+
peer.send_notification(ServerNotification::CustomServerNotification(
177+
CustomServerNotification::new(
178+
"notifications/custom-test",
179+
Some(json!({ "hello": "world" })),
180+
),
181+
))
182+
.await
183+
.expect("send custom notification");
184+
});
185+
}
186+
}
187+
188+
struct CustomClient {
189+
receive_signal: Arc<Notify>,
190+
payload: Arc<Mutex<Option<CustomNotificationPayload>>>,
191+
}
192+
193+
impl ClientHandler for CustomClient {
194+
async fn on_custom_notification(
195+
&self,
196+
notification: CustomServerNotification,
197+
_context: rmcp::service::NotificationContext<rmcp::RoleClient>,
198+
) {
199+
let CustomServerNotification { method, params, .. } = notification;
200+
let mut payload = self.payload.lock().await;
201+
*payload = Some((method, params));
202+
self.receive_signal.notify_one();
203+
}
204+
}
205+
206+
#[tokio::test]
207+
async fn test_custom_server_notification_reaches_client() -> anyhow::Result<()> {
208+
let _ = tracing_subscriber::registry()
209+
.with(
210+
tracing_subscriber::EnvFilter::try_from_default_env()
211+
.unwrap_or_else(|_| "debug".to_string().into()),
212+
)
213+
.with(tracing_subscriber::fmt::layer())
214+
.try_init();
215+
216+
let (server_transport, client_transport) = tokio::io::duplex(4096);
217+
tokio::spawn(async move {
218+
let server = CustomServerNotifier {}.serve(server_transport).await?;
219+
server.waiting().await?;
220+
anyhow::Ok(())
221+
});
222+
223+
let receive_signal = Arc::new(Notify::new());
224+
let payload = Arc::new(Mutex::new(None));
225+
226+
let client = CustomClient {
227+
receive_signal: receive_signal.clone(),
228+
payload: payload.clone(),
229+
}
230+
.serve(client_transport)
231+
.await?;
232+
233+
tokio::time::timeout(std::time::Duration::from_secs(5), receive_signal.notified()).await?;
234+
235+
let (method, params) = payload.lock().await.clone().expect("payload set");
236+
assert_eq!("notifications/custom-test", method);
237+
assert_eq!(Some(json!({ "hello": "world" })), params);
238+
239+
client.cancel().await?;
240+
Ok(())
241+
}

0 commit comments

Comments
 (0)