Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/rmcp-macros/src/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
if let Some(params_ty) = params_ty {
// if found, use the Parameters schema
syn::parse2::<Expr>(quote! {
rmcp::handler::server::common::cached_schema_for_type::<#params_ty>()
rmcp::handler::server::common::schema_for_type::<#params_ty>()
})?
} else {
// if not found, use a default empty JSON schema object
Expand Down
107 changes: 84 additions & 23 deletions crates/rmcp/src/handler/server/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,8 @@ use crate::{
RoleServer, model::JsonObject, schemars::generate::SchemaSettings, service::RequestContext,
};

/// A shortcut for generating a JSON schema for a type.
pub fn schema_for_type<T: JsonSchema>() -> JsonObject {
// explicitly to align json schema version to official specifications.
// refer to https://github.com/modelcontextprotocol/modelcontextprotocol/pull/655 for details.
let mut settings = SchemaSettings::draft2020_12();
settings.transforms = vec![Box::new(schemars::transform::AddNullable::default())];
let generator = settings.into_generator();
let schema = generator.into_root_schema_for::<T>();
let object = serde_json::to_value(schema).expect("failed to serialize schema");
match object {
serde_json::Value::Object(object) => object,
_ => panic!(
"Schema serialization produced non-object value: expected JSON object but got {:?}",
object
),
}
}

/// Call [`schema_for_type`] with a cache
pub fn cached_schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
/// Generates a JSON schema for a type
pub fn schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
thread_local! {
static CACHE_FOR_TYPE: std::sync::RwLock<HashMap<TypeId, Arc<JsonObject>>> = Default::default();
};
Expand All @@ -39,12 +21,26 @@ pub fn cached_schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject
{
x.clone()
} else {
let schema = schema_for_type::<T>();
let schema = Arc::new(schema);
// explicitly to align json schema version to official specifications.
// refer to https://github.com/modelcontextprotocol/modelcontextprotocol/pull/655 for details.
let mut settings = SchemaSettings::draft2020_12();
settings.transforms = vec![Box::new(schemars::transform::AddNullable::default())];
let generator = settings.into_generator();
let schema = generator.into_root_schema_for::<T>();
let object = serde_json::to_value(schema).expect("failed to serialize schema");
let object = match object {
serde_json::Value::Object(object) => object,
_ => panic!(
"Schema serialization produced non-object value: expected JSON object but got {:?}",
object
),
};
let schema = Arc::new(object);
cache
.write()
.expect("schema cache lock poisoned")
.insert(TypeId::of::<T>(), schema.clone());

schema
}
})
Expand All @@ -69,7 +65,7 @@ pub fn schema_for_output<T: JsonSchema + std::any::Any>() -> Result<Arc<JsonObje
// Generate and validate schema
let schema = schema_for_type::<T>();
let result = match schema.get("type") {
Some(serde_json::Value::String(t)) if t == "object" => Ok(Arc::new(schema)),
Some(serde_json::Value::String(t)) if t == "object" => Ok(schema.clone()),
Some(serde_json::Value::String(t)) => Err(format!(
"MCP specification requires tool outputSchema to have root type 'object', but found '{}'.",
t
Expand Down Expand Up @@ -196,6 +192,71 @@ mod tests {
value: i32,
}

#[derive(serde::Serialize, serde::Deserialize, JsonSchema)]
struct AnotherTestObject {
value: i32,
}

#[test]
fn test_schema_for_type_handles_primitive() {
let schema = schema_for_type::<i32>();

assert_eq!(schema.get("type"), Some(&serde_json::json!("integer")));
}

#[test]
fn test_schema_for_type_handles_array() {
let schema = schema_for_type::<Vec<i32>>();

assert_eq!(schema.get("type"), Some(&serde_json::json!("array")));
let items = schema.get("items").and_then(|v| v.as_object());
assert_eq!(
items.unwrap().get("type"),
Some(&serde_json::json!("integer"))
);
}

#[test]
fn test_schema_for_type_handles_struct() {
let schema = schema_for_type::<TestObject>();

assert_eq!(schema.get("type"), Some(&serde_json::json!("object")));
let properties = schema.get("properties").and_then(|v| v.as_object());
assert!(properties.unwrap().contains_key("value"));
}

#[test]
fn test_schema_for_type_caches_primitive_types() {
let schema1 = schema_for_type::<i32>();
let schema2 = schema_for_type::<i32>();

assert!(Arc::ptr_eq(&schema1, &schema2));
}

#[test]
fn test_schema_for_type_caches_struct_types() {
let schema1 = schema_for_type::<TestObject>();
let schema2 = schema_for_type::<TestObject>();

assert!(Arc::ptr_eq(&schema1, &schema2));
}

#[test]
fn test_schema_for_type_different_types_different_schemas() {
let schema1 = schema_for_type::<TestObject>();
let schema2 = schema_for_type::<AnotherTestObject>();

assert!(!Arc::ptr_eq(&schema1, &schema2));
}

#[test]
fn test_schema_for_type_arc_can_be_shared() {
let schema = schema_for_type::<TestObject>();
let cloned = schema.clone();

assert!(Arc::ptr_eq(&schema, &cloned));
}

#[test]
fn test_schema_for_output_rejects_primitive() {
let result = schema_for_output::<i32>();
Expand Down
2 changes: 1 addition & 1 deletion crates/rmcp/src/handler/server/prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ impl_prompt_handler_for!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15);
/// as PromptArgument entries with name, description, and required status
pub fn cached_arguments_from_schema<T: schemars::JsonSchema + std::any::Any>()
-> Option<Vec<crate::model::PromptArgument>> {
let schema = super::common::cached_schema_for_type::<T>();
let schema = super::common::schema_for_type::<T>();
let schema_value = serde_json::Value::Object((*schema).clone());

let properties = schema_value.get("properties").and_then(|p| p.as_object());
Expand Down
4 changes: 2 additions & 2 deletions crates/rmcp/src/handler/server/router/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ where
self.attr.description = Some(description.into());
self
}
pub fn parameters<T: JsonSchema>(mut self) -> Self {
self.attr.input_schema = schema_for_type::<T>().into();
pub fn parameters<T: JsonSchema + 'static>(mut self) -> Self {
self.attr.input_schema = schema_for_type::<T>();
self
}
pub fn parameters_value(mut self, schema: serde_json::Value) -> Self {
Expand Down
2 changes: 1 addition & 1 deletion crates/rmcp/src/handler/server/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use serde::de::DeserializeOwned;

use super::common::{AsRequestContext, FromContextPart};
pub use super::{
common::{Extension, RequestId, cached_schema_for_type, schema_for_output, schema_for_type},
common::{Extension, RequestId, schema_for_output, schema_for_type},
router::tool::{ToolRoute, ToolRouter},
};
use crate::{
Expand Down
2 changes: 1 addition & 1 deletion crates/rmcp/src/model/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ impl Tool {

/// Set the input schema using a type that implements JsonSchema
pub fn with_input_schema<T: JsonSchema + 'static>(mut self) -> Self {
self.input_schema = crate::handler::server::tool::cached_schema_for_type::<T>();
self.input_schema = crate::handler::server::tool::schema_for_type::<T>();
self
}

Expand Down
2 changes: 1 addition & 1 deletion crates/rmcp/tests/test_json_schema_detection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl TestServer {
}

/// Tool with explicit output_schema attribute - should have output schema
#[tool(name = "explicit-schema", output_schema = rmcp::handler::server::tool::cached_schema_for_type::<TestData>())]
#[tool(name = "explicit-schema", output_schema = rmcp::handler::server::tool::schema_for_type::<TestData>())]
pub async fn explicit_schema(&self) -> Result<String, String> {
Ok("test".to_string())
}
Expand Down