diff --git a/crates/rmcp-macros/src/tool.rs b/crates/rmcp-macros/src/tool.rs index ab343f54..abfb3c36 100644 --- a/crates/rmcp-macros/src/tool.rs +++ b/crates/rmcp-macros/src/tool.rs @@ -220,7 +220,7 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result { if let Some(params_ty) = params_ty { // if found, use the Parameters schema syn::parse2::(quote! { - rmcp::handler::server::common::cached_schema_for_type::<#params_ty>() + rmcp::handler::server::common::schema_for_type::<#params_ty>() })? } else { // if not found, use a default empty JSON schema object diff --git a/crates/rmcp/src/handler/server/common.rs b/crates/rmcp/src/handler/server/common.rs index 2d6deaf4..4f344a86 100644 --- a/crates/rmcp/src/handler/server/common.rs +++ b/crates/rmcp/src/handler/server/common.rs @@ -8,26 +8,8 @@ use crate::{ RoleServer, model::JsonObject, schemars::generate::SchemaSettings, service::RequestContext, }; -/// A shortcut for generating a JSON schema for a type. -pub fn schema_for_type() -> JsonObject { - // explicitly to align json schema version to official specifications. - // refer to https://github.com/modelcontextprotocol/modelcontextprotocol/pull/655 for details. - let mut settings = SchemaSettings::draft2020_12(); - settings.transforms = vec![Box::new(schemars::transform::AddNullable::default())]; - let generator = settings.into_generator(); - let schema = generator.into_root_schema_for::(); - let object = serde_json::to_value(schema).expect("failed to serialize schema"); - match object { - serde_json::Value::Object(object) => object, - _ => panic!( - "Schema serialization produced non-object value: expected JSON object but got {:?}", - object - ), - } -} - -/// Call [`schema_for_type`] with a cache -pub fn cached_schema_for_type() -> Arc { +/// Generates a JSON schema for a type +pub fn schema_for_type() -> Arc { thread_local! { static CACHE_FOR_TYPE: std::sync::RwLock>> = Default::default(); }; @@ -39,12 +21,26 @@ pub fn cached_schema_for_type() -> Arc(); - let schema = Arc::new(schema); + // explicitly to align json schema version to official specifications. + // refer to https://github.com/modelcontextprotocol/modelcontextprotocol/pull/655 for details. + let mut settings = SchemaSettings::draft2020_12(); + settings.transforms = vec![Box::new(schemars::transform::AddNullable::default())]; + let generator = settings.into_generator(); + let schema = generator.into_root_schema_for::(); + let object = serde_json::to_value(schema).expect("failed to serialize schema"); + let object = match object { + serde_json::Value::Object(object) => object, + _ => panic!( + "Schema serialization produced non-object value: expected JSON object but got {:?}", + object + ), + }; + let schema = Arc::new(object); cache .write() .expect("schema cache lock poisoned") .insert(TypeId::of::(), schema.clone()); + schema } }) @@ -69,7 +65,7 @@ pub fn schema_for_output() -> Result(); let result = match schema.get("type") { - Some(serde_json::Value::String(t)) if t == "object" => Ok(Arc::new(schema)), + Some(serde_json::Value::String(t)) if t == "object" => Ok(schema.clone()), Some(serde_json::Value::String(t)) => Err(format!( "MCP specification requires tool outputSchema to have root type 'object', but found '{}'.", t @@ -196,6 +192,71 @@ mod tests { value: i32, } + #[derive(serde::Serialize, serde::Deserialize, JsonSchema)] + struct AnotherTestObject { + value: i32, + } + + #[test] + fn test_schema_for_type_handles_primitive() { + let schema = schema_for_type::(); + + assert_eq!(schema.get("type"), Some(&serde_json::json!("integer"))); + } + + #[test] + fn test_schema_for_type_handles_array() { + let schema = schema_for_type::>(); + + assert_eq!(schema.get("type"), Some(&serde_json::json!("array"))); + let items = schema.get("items").and_then(|v| v.as_object()); + assert_eq!( + items.unwrap().get("type"), + Some(&serde_json::json!("integer")) + ); + } + + #[test] + fn test_schema_for_type_handles_struct() { + let schema = schema_for_type::(); + + assert_eq!(schema.get("type"), Some(&serde_json::json!("object"))); + let properties = schema.get("properties").and_then(|v| v.as_object()); + assert!(properties.unwrap().contains_key("value")); + } + + #[test] + fn test_schema_for_type_caches_primitive_types() { + let schema1 = schema_for_type::(); + let schema2 = schema_for_type::(); + + assert!(Arc::ptr_eq(&schema1, &schema2)); + } + + #[test] + fn test_schema_for_type_caches_struct_types() { + let schema1 = schema_for_type::(); + let schema2 = schema_for_type::(); + + assert!(Arc::ptr_eq(&schema1, &schema2)); + } + + #[test] + fn test_schema_for_type_different_types_different_schemas() { + let schema1 = schema_for_type::(); + let schema2 = schema_for_type::(); + + assert!(!Arc::ptr_eq(&schema1, &schema2)); + } + + #[test] + fn test_schema_for_type_arc_can_be_shared() { + let schema = schema_for_type::(); + let cloned = schema.clone(); + + assert!(Arc::ptr_eq(&schema, &cloned)); + } + #[test] fn test_schema_for_output_rejects_primitive() { let result = schema_for_output::(); diff --git a/crates/rmcp/src/handler/server/prompt.rs b/crates/rmcp/src/handler/server/prompt.rs index 5c262e2b..826bee0d 100644 --- a/crates/rmcp/src/handler/server/prompt.rs +++ b/crates/rmcp/src/handler/server/prompt.rs @@ -325,7 +325,7 @@ impl_prompt_handler_for!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); /// as PromptArgument entries with name, description, and required status pub fn cached_arguments_from_schema() -> Option> { - let schema = super::common::cached_schema_for_type::(); + let schema = super::common::schema_for_type::(); let schema_value = serde_json::Value::Object((*schema).clone()); let properties = schema_value.get("properties").and_then(|p| p.as_object()); diff --git a/crates/rmcp/src/handler/server/router/tool.rs b/crates/rmcp/src/handler/server/router/tool.rs index 610c9326..03653124 100644 --- a/crates/rmcp/src/handler/server/router/tool.rs +++ b/crates/rmcp/src/handler/server/router/tool.rs @@ -154,8 +154,8 @@ where self.attr.description = Some(description.into()); self } - pub fn parameters(mut self) -> Self { - self.attr.input_schema = schema_for_type::().into(); + pub fn parameters(mut self) -> Self { + self.attr.input_schema = schema_for_type::(); self } pub fn parameters_value(mut self, schema: serde_json::Value) -> Self { diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index fa24a1c2..21e7e1a2 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -9,7 +9,7 @@ use serde::de::DeserializeOwned; use super::common::{AsRequestContext, FromContextPart}; pub use super::{ - common::{Extension, RequestId, cached_schema_for_type, schema_for_output, schema_for_type}, + common::{Extension, RequestId, schema_for_output, schema_for_type}, router::tool::{ToolRoute, ToolRouter}, }; use crate::{ diff --git a/crates/rmcp/src/model/tool.rs b/crates/rmcp/src/model/tool.rs index 4bab7efd..814d10ae 100644 --- a/crates/rmcp/src/model/tool.rs +++ b/crates/rmcp/src/model/tool.rs @@ -178,7 +178,7 @@ impl Tool { /// Set the input schema using a type that implements JsonSchema pub fn with_input_schema(mut self) -> Self { - self.input_schema = crate::handler::server::tool::cached_schema_for_type::(); + self.input_schema = crate::handler::server::tool::schema_for_type::(); self } diff --git a/crates/rmcp/tests/test_json_schema_detection.rs b/crates/rmcp/tests/test_json_schema_detection.rs index 89dd8586..af587319 100644 --- a/crates/rmcp/tests/test_json_schema_detection.rs +++ b/crates/rmcp/tests/test_json_schema_detection.rs @@ -55,7 +55,7 @@ impl TestServer { } /// Tool with explicit output_schema attribute - should have output schema - #[tool(name = "explicit-schema", output_schema = rmcp::handler::server::tool::cached_schema_for_type::())] + #[tool(name = "explicit-schema", output_schema = rmcp::handler::server::tool::schema_for_type::())] pub async fn explicit_schema(&self) -> Result { Ok("test".to_string()) }