diff --git a/.dockerignore b/.dockerignore index ea1e38e..8124151 100644 --- a/.dockerignore +++ b/.dockerignore @@ -5,7 +5,6 @@ target/maven-* target/ROOT target/test-classes/ target/war -target/duplicate-finder-result.xml target/jacoco.exec target/*.original .idea diff --git a/README.md b/README.md index a5d58a1..711450e 100644 --- a/README.md +++ b/README.md @@ -588,3 +588,7 @@ docker run --rm -it -p 8080:8080 \ ``` You can configure the agent using environment variables or Java system properties, see for details. + +## Enable MCP + +MCP capabilities can be enabled by setting the `spring.ai.mcp.server.enabled` to `true`. This will enable the MCP server and expose the MCP endpoints. The MCP endpoint is currently hardcoded to `/mcp/message` and can be tried out by running e.g. `npx @modelcontextprotocol/inspector` and connect to http://localhost:8080/mcp/message using Streamable HTTP. Spring AI MCP Server Auto Configuration is currently not supported. diff --git a/pom.xml b/pom.xml index 57ed7b8..3373897 100644 --- a/pom.xml +++ b/pom.xml @@ -383,6 +383,32 @@ 5.0.1 + + org.springframework.ai + spring-ai-mcp + 1.0.2 + + + + + + io.modelcontextprotocol.sdk + mcp + 0.12.1 + + + + org.springframework + spring-test + + org.junit.jupiter diff --git a/src/main/java/ca/uhn/fhir/jpa/starter/AppProperties.java b/src/main/java/ca/uhn/fhir/jpa/starter/AppProperties.java index f1913e1..f10f693 100644 --- a/src/main/java/ca/uhn/fhir/jpa/starter/AppProperties.java +++ b/src/main/java/ca/uhn/fhir/jpa/starter/AppProperties.java @@ -21,9 +21,9 @@ import java.util.Set; import static org.apache.commons.lang3.ObjectUtils.defaultIfNull; +@EnableConfigurationProperties @ConfigurationProperties(prefix = "hapi.fhir") @Configuration -@EnableConfigurationProperties public class AppProperties { private final Set auto_version_reference_at_paths = new HashSet<>(); diff --git a/src/main/java/ca/uhn/fhir/jpa/starter/mcp/CallToolResultFactory.java b/src/main/java/ca/uhn/fhir/jpa/starter/mcp/CallToolResultFactory.java new file mode 100644 index 0000000..72b744b --- /dev/null +++ b/src/main/java/ca/uhn/fhir/jpa/starter/mcp/CallToolResultFactory.java @@ -0,0 +1,40 @@ +package ca.uhn.fhir.jpa.starter.mcp; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpSchema; +import org.springframework.stereotype.Component; + +import java.util.Map; + +@Component +public class CallToolResultFactory { + + public static McpSchema.CallToolResult success( + String resourceType, Interaction interaction, String response, int status) { + Map payload = Map.of( + "resourceType", resourceType, + "interaction", interaction, + "response", response, + "status", status); + + ObjectMapper objectMapper = new ObjectMapper(); + String jacksonData; + try { + jacksonData = objectMapper.writeValueAsString(payload); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + + return McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent(jacksonData)) + .build(); + } + + public static McpSchema.CallToolResult failure(String message) { + return McpSchema.CallToolResult.builder() + .isError(true) + .addTextContent(message) + .build(); + } +} diff --git a/src/main/java/ca/uhn/fhir/jpa/starter/mcp/Interaction.java b/src/main/java/ca/uhn/fhir/jpa/starter/mcp/Interaction.java new file mode 100644 index 0000000..e377d33 --- /dev/null +++ b/src/main/java/ca/uhn/fhir/jpa/starter/mcp/Interaction.java @@ -0,0 +1,34 @@ +package ca.uhn.fhir.jpa.starter.mcp; + +import ca.uhn.fhir.rest.api.RequestTypeEnum; + +public enum Interaction { + CALL_CDS_HOOK("call-cds-hook"), + SEARCH("search"), + READ("read"), + CREATE("create"), + UPDATE("update"), + DELETE("delete"), + PATCH("patch"), + TRANSACTION("transaction"); + + private final String name; + + Interaction(String name) { + this.name = name; + } + + public String getName() { + return name; + } + + public RequestTypeEnum asRequestType() { + return switch (this) { + case SEARCH, READ -> RequestTypeEnum.GET; + case CREATE, TRANSACTION, CALL_CDS_HOOK -> RequestTypeEnum.POST; + case UPDATE -> RequestTypeEnum.PUT; + case DELETE -> RequestTypeEnum.DELETE; + case PATCH -> RequestTypeEnum.PATCH; + }; + } +} diff --git a/src/main/java/ca/uhn/fhir/jpa/starter/mcp/McpServerConfig.java b/src/main/java/ca/uhn/fhir/jpa/starter/mcp/McpServerConfig.java new file mode 100644 index 0000000..9dec687 --- /dev/null +++ b/src/main/java/ca/uhn/fhir/jpa/starter/mcp/McpServerConfig.java @@ -0,0 +1,81 @@ +package ca.uhn.fhir.jpa.starter.mcp; + +import ca.uhn.fhir.context.FhirContext; +import ca.uhn.fhir.rest.server.McpBridge; +import ca.uhn.fhir.rest.server.McpCdsBridge; +import ca.uhn.fhir.rest.server.McpFhirBridge; +import ca.uhn.fhir.rest.server.RestfulServer; +import ca.uhn.hapi.fhir.cdshooks.api.ICdsServiceRegistry; +import ca.uhn.hapi.fhir.cdshooks.module.CdsHooksObjectMapperFactory; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.web.servlet.ServletRegistrationBean; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +import java.util.List; + +// https://mcp-cn.ssshooter.com/sdk/java/mcp-server#sse-servlet +// https://www.baeldung.com/spring-ai-model-context-protocol-mcp +// https://github.com/spring-projects/spring-ai-examples/blob/main/model-context-protocol/weather/manual-webflux-server/src/main/java/org/springframework/ai/mcp/sample/server/McpServerConfig.java +// https://github.com/spring-projects/spring-ai-examples/tree/main/model-context-protocol/weather/starter-stdio-server/src/main/java/org/springframework/ai/mcp/sample/server +// https://github.com/spring-projects/spring-ai-examples/blob/main/model-context-protocol/sampling/mcp-weather-webmvc-server/src/main/java/org/springframework/ai/mcp/sample/server/WeatherService.java +// https://docs.spring.io/spring-ai/reference/api/mcp/mcp-server-boot-starter-docs.html +@Configuration +@ConditionalOnProperty( + prefix = "spring.ai.mcp.server", + name = {"enabled"}, + havingValue = "true") +public class McpServerConfig { + + private static final String SSE_ENDPOINT = "/sse"; + private static final String SSE_MESSAGE_ENDPOINT = "/mcp/message"; + + @Bean + public McpSyncServer syncServer( + List mcpBridges, McpStreamableServerTransportProvider transportProvider) { + return McpServer.sync(transportProvider) + .tools(mcpBridges.stream() + .flatMap(bridge -> bridge.generateTools().stream()) + .toList()) + .build(); + } + + @Bean + public McpFhirBridge mcpFhirBridge(RestfulServer restfulServer) { + return new McpFhirBridge(restfulServer); + } + + @Bean + @ConditionalOnProperty( + prefix = "hapi.fhir.cr", + name = {"enabled"}, + havingValue = "true") + public McpCdsBridge mcpCdsBridge(FhirContext fhirContext, ICdsServiceRegistry cdsServiceRegistry) { + + return new McpCdsBridge( + fhirContext, cdsServiceRegistry, new CdsHooksObjectMapperFactory(fhirContext).newMapper()); + } + + @Bean + public HttpServletStreamableServerTransportProvider servletSseServerTransportProvider( + /*McpServerProperties properties*/ ) { + + return HttpServletStreamableServerTransportProvider.builder() + .disallowDelete(false) + .mcpEndpoint(SSE_MESSAGE_ENDPOINT) + .objectMapper(new ObjectMapper()) + // .contextExtractor((serverRequest, context) -> context) + .build(); + } + + @Bean + public ServletRegistrationBean customServletBean( + HttpServletStreamableServerTransportProvider transportProvider /*, McpServerProperties properties*/) { + return new ServletRegistrationBean<>(transportProvider, SSE_MESSAGE_ENDPOINT, SSE_ENDPOINT); + } +} diff --git a/src/main/java/ca/uhn/fhir/jpa/starter/mcp/RequestBuilder.java b/src/main/java/ca/uhn/fhir/jpa/starter/mcp/RequestBuilder.java new file mode 100644 index 0000000..985e711 --- /dev/null +++ b/src/main/java/ca/uhn/fhir/jpa/starter/mcp/RequestBuilder.java @@ -0,0 +1,120 @@ +package ca.uhn.fhir.jpa.starter.mcp; + +import ca.uhn.fhir.context.FhirContext; +import com.google.gson.Gson; +import org.hl7.fhir.instance.model.api.IBaseResource; +import org.springframework.mock.web.MockHttpServletRequest; + +import java.nio.charset.StandardCharsets; +import java.util.Map; + +public class RequestBuilder { + + private final FhirContext fhirContext; + private final String resourceType; + private final Interaction interaction; + private final Map config; + /** + * Constructs a RequestBuilder for a specific FHIR interaction. + * + * @param fhirContext the FHIR context + * @param contextMap a map containing configuration parameters, including 'resourceType' + * @param interaction the type of interaction (e.g., SEARCH, READ, CREATE, etc.) + */ + public RequestBuilder(FhirContext fhirContext, Map contextMap, Interaction interaction) { + this.config = contextMap; + if (interaction == Interaction.TRANSACTION) this.resourceType = ""; + else if (contextMap.get("resourceType") instanceof String rt && !rt.isBlank()) this.resourceType = rt; + else throw new IllegalArgumentException("Missing or invalid 'resourceType' in contextMap"); + + this.interaction = interaction; + this.fhirContext = fhirContext; + } + + public MockHttpServletRequest buildRequest() { + String basePath = "/" + resourceType; + String method; + MockHttpServletRequest req; + + switch (interaction) { + case SEARCH -> { + method = "GET"; + req = new MockHttpServletRequest(method, basePath); + Map sp = null; + if (config.get("query") instanceof Map q) { + sp = q; + } else if (config.get("searchParams") instanceof Map s) { + sp = s; + } + if (sp != null) { + sp.forEach((k, v) -> req.addParameter(k.toString(), v.toString())); + } + } + case READ -> { + method = "GET"; + String id = requireString(); + req = new MockHttpServletRequest(method, basePath + "/" + id); + } + case CREATE, TRANSACTION -> { + method = "POST"; + req = new MockHttpServletRequest(method, basePath); + applyResourceBody(req); + } + case UPDATE -> { + method = "PUT"; + String id = requireString(); + req = new MockHttpServletRequest(method, basePath + "/" + id); + applyResourceBody(req); + } + case DELETE -> { + method = "DELETE"; + String id = requireString(); + req = new MockHttpServletRequest(method, basePath + "/" + id); + } + case PATCH -> { + method = "PATCH"; + String id = requireString(); + req = new MockHttpServletRequest(method, basePath + "/" + id); + applyPatchBody(req); + } + default -> throw new IllegalArgumentException("Unsupported interaction: " + interaction); + } + + req.setContentType("application/fhir+json"); + req.addHeader("Accept", "application/fhir+json"); + return req; + } + + private void applyResourceBody(MockHttpServletRequest req) { + Object resourceObj = config.get("resource"); + String json; + if (resourceObj instanceof Map) json = new Gson().toJson(resourceObj, Map.class); + else if (resourceObj instanceof String) json = resourceObj.toString(); + else throw new IllegalArgumentException("Unsupported resource body type: " + resourceObj.getClass()); + req.setContent(json.getBytes(StandardCharsets.UTF_8)); + } + + private void applyPatchBody(MockHttpServletRequest req) { + Object patchBody = config.get("resource"); + if (patchBody == null) { + throw new IllegalArgumentException("Missing 'resource' for patch interaction"); + } + String content; + if (patchBody instanceof String s) { + content = s; + } else if (patchBody instanceof IBaseResource r) { + content = fhirContext.newJsonParser().encodeResourceToString(r); + } else { + throw new IllegalArgumentException("Unsupported patch body type: " + patchBody.getClass()); + } + req.setContent(content.getBytes(StandardCharsets.UTF_8)); + } + + private String requireString() { + Object val = config.get("id"); + if (!(val instanceof String s) || s.isBlank()) { + throw new IllegalArgumentException("Missing or invalid '" + "id" + "'"); + } + return s; + } +} diff --git a/src/main/java/ca/uhn/fhir/jpa/starter/mcp/ToolFactory.java b/src/main/java/ca/uhn/fhir/jpa/starter/mcp/ToolFactory.java new file mode 100644 index 0000000..32cd005 --- /dev/null +++ b/src/main/java/ca/uhn/fhir/jpa/starter/mcp/ToolFactory.java @@ -0,0 +1,337 @@ +package ca.uhn.fhir.jpa.starter.mcp; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.Tool; + +public class ToolFactory { + + private static final String READ_FHIR_RESOURCE_SCHEMA = + """ + { + "type": "object", + "properties": { + "resourceType": { + "type": "string", + "description": "type of the resource to read" + }, + "id": { + "type": "string", + "description": "id of the resource to read" + } + } + + } + """; + + private static final String CREATE_FHIR_RESOURCE_SCHEMA = + """ + { + "type": "object", + "properties": { + "resourceType": { + "type": "string", + "description": "Type of the resource to create" + }, + "resource": { + "type": "object", + "description": "Resource content in JSON format" + }, + "headers": { + "type": "object", + "description": "Headers for create request.\\nAvailable headers: If-None-Exist header for conditional create where the value is search param string.\\nFor example: {\\"If-None-Exist\\": \\"active=false\\"}" + } + }, + "required": ["resourceType", "resource"] + } + """; + + private static final String UPDATE_FHIR_RESOURCE_SCHEMA = + """ + { + "type": "object", + "properties": { + "resourceType": { + "type": "string", + "description": "Type of the resource to update" + }, + "id": { + "type": "string", + "description": "ID of the resource to update" + }, + "resource": { + "type": "object", + "description": "Updated resource content in JSON format" + } + }, + "required": ["resourceType", "id", "resource"] + } + """; + + private static final String CONDITIONAL_UPDATE_FHIR_RESOURCE_SCHEMA = + """ + { + "type": "object", + "properties": { + "resourceType": { + "type": "string", + "description": "Type of the resource to update" + }, + "resource": { + "type": "object", + "description": "Updated resource content in JSON format" + }, + "query": { + "type": "string", + "description": "Query string with search params separate by \\",\\". For example: \\"_id=pt-1,name=ivan\\". Uses for conditional update." + }, + "headers": { + "type": "object", + "description": "Headers for create request.\\nAvailable headers: If-None-Match header for conditional update where the value is ETag.\\nFor example: {\\"If-None-Match\\": \\"12345\\"}" + } + }, + "required": ["resourceType", "resource"] + } + """; + + private static final String CONDITIONAL_PATCH_FHIR_RESOURCE_SCHEMA = + """ + { + "type": "object", + "properties": { + "resourceType": { + "type": "string", + "description": "Type of the resource to patch" + }, + "resource": { + "type": "object", + "description": "Resource content to patch in JSON format" + }, + "query": { + "type": "string", + "description": "Query string with search params separate by \\",\\". For example: \\"_id=pt-1,name=ivan\\". Uses for conditional patch." + }, + "headers": { + "type": "object", + "description": "Headers for create request.\\nAvailable headers: If-None-Match header for conditional patch where the value is ETag.\\nFor example: {\\"If-None-Match\\": \\"12345\\"}" + } + }, + "required": ["resourceType", "resource"] + } + """; + + private static final String PATCH_FHIR_RESOURCE_SCHEMA = + """ + { + "type": "object", + "properties": { + "resourceType": { + "type": "string", + "description": "Type of the resource to patch" + }, + "id": { + "type": "string", + "description": "ID of the resource to patch" + }, + "resource": { + "type": "object", + "description": "Resource content to patch in JSON format" + } + }, + "required": ["resourceType", "id", "resource"] + } + """; + + private static final String DELETE_FHIR_RESOURCE_SCHEMA = + """ + { + "type": "object", + "properties": { + "resourceType": { + "type": "string", + "description": "Type of the resource to delete" + }, + "id": { + "type": "string", + "description": "ID of the resource to delete" + } + }, + "required": ["resourceType", "id"] + } + """; + + private static final String SEARCH_FHIR_RESOURCES_SCHEMA = + """ + { + "type": "object", + "properties": { + "resourceType": { + "type": "string", + "description": "Type of the resource to search" + }, + "query": { + "type": "string", + "description": "Query string with search params separate by \\",\\". For example: \\"_id=pt-1,name=ivan\\"" + } + }, + "required": ["resourceType", "query"] + } + """; + + private static final String CREATE_FHIR_TRANSACTION_SCHEMA = + """ + { + "type": "object", + "properties": { + "resourceType": { + "type": "string", + "description": "A Bundle resource type with type 'transaction' containing multiple FHIR resources" + }, + "resource": { + "type": "object", + "description": "A FHIR Bundle Resource content in JSON format" + } + }, + "required": ["resourceType", "resource"] + } + """; + + // TODO Add a tool for the CDS Hooks discovery endpoint + // Alternatively, should each service be a separate tool? + + // TODO Add other fields from https://cds-hooks.hl7.org/STU2/#http-request-1 + // TODO Context here is for the patient-view hook, https://cds-hooks.hl7.org/hooks/STU1/patient-view.html#context + private static final String CALL_CDS_HOOK_SCHEMA_2_0_1 = + """ + { + "type": "object", + "properties": { + "service": { + "type": "string", + "description": "The CDS Service to call." + }, + "hook": { + "type": "string", + "description": "The hook that triggered this CDS Service call." + }, + "hookInstance": { + "type": "string", + "description": "A universally unique identifier (UUID) for this particular hook call." + }, + "hookContext": { + "type": "object", + "description": "Hook-specific contextual data that the CDS service will need.", + "properties": { + "userId": { + "type": "string", + "description": "The id of the current user. Must be in the format [ResourceType]/[id]." + }, + "patientId": { + "type": "string", + "description": "The FHIR Patient.id of the current patient in context" + }, + "encounterId": { + "type": "string", + "description": "The FHIR Encounter.id of the current encounter in context." + } + } + }, + "prefetch": { + "type": "object", + "description": "Additional data to prefetch for the CDS service call." + } + }, + "required": ["service", "hook", "hookInstance", "hookContext"] + } + """; + + public static Tool readFhirResource() throws JsonProcessingException { + return new Tool.Builder() + .name("read-fhir-resource") + .description("Read an individual FHIR resource") + .inputSchema(mapper.readValue(READ_FHIR_RESOURCE_SCHEMA, McpSchema.JsonSchema.class)) + .build(); + } + + public static Tool createFhirResource() throws JsonProcessingException { + return new Tool.Builder() + .name("create-fhir-resource") + .description("Create a new FHIR resource") + .inputSchema(mapper.readValue(CREATE_FHIR_RESOURCE_SCHEMA, McpSchema.JsonSchema.class)) + .build(); + } + + public static Tool updateFhirResource() throws JsonProcessingException { + return new Tool.Builder() + .name("update-fhir-resource") + .description("Update an existing FHIR resource") + .inputSchema(mapper.readValue(UPDATE_FHIR_RESOURCE_SCHEMA, McpSchema.JsonSchema.class)) + .build(); + } + + public static Tool conditionalUpdateFhirResource() throws JsonProcessingException { + return new Tool.Builder() + .name("conditional-update-fhir-resource") + .description("Conditional update an existing FHIR resource") + .inputSchema(mapper.readValue(CONDITIONAL_UPDATE_FHIR_RESOURCE_SCHEMA, McpSchema.JsonSchema.class)) + .build(); + } + + public static Tool conditionalPatchFhirResource() throws JsonProcessingException { + return new Tool.Builder() + .name("conditional-patch-fhir-resource") + .description("Conditional patch an existing FHIR resource") + .inputSchema(mapper.readValue(CONDITIONAL_PATCH_FHIR_RESOURCE_SCHEMA, McpSchema.JsonSchema.class)) + .build(); + } + + public static Tool patchFhirResource() throws JsonProcessingException { + return new Tool.Builder() + .name("patch-fhir-resource") + .description("Patch an existing FHIR resource") + .inputSchema(mapper.readValue(PATCH_FHIR_RESOURCE_SCHEMA, McpSchema.JsonSchema.class)) + .build(); + } + + public static Tool deleteFhirResource() throws JsonProcessingException { + return new Tool.Builder() + .name("delete-fhir-resource") + .description("Delete an existing FHIR resource") + .inputSchema(mapper.readValue(DELETE_FHIR_RESOURCE_SCHEMA, McpSchema.JsonSchema.class)) + .build(); + } + + public static Tool searchFhirResources() throws JsonProcessingException { + return new Tool.Builder() + .name("search-fhir-resources") + .description("Search an existing FHIR resources") + .inputSchema(mapper.readValue(SEARCH_FHIR_RESOURCES_SCHEMA, McpSchema.JsonSchema.class)) + .build(); + } + + public static Tool createFhirTransaction() throws JsonProcessingException { + return new Tool.Builder() + .name("create-fhir-transaction") + .description("Create a FHIR transaction") + .inputSchema(mapper.readValue(CREATE_FHIR_TRANSACTION_SCHEMA, McpSchema.JsonSchema.class)) + .build(); + } + + public static Tool callCdsHook() throws JsonProcessingException { + return new Tool.Builder() + .name("call-cds-hook") + .description("Call a CDS Hook") + .inputSchema(mapper.readValue(CALL_CDS_HOOK_SCHEMA_2_0_1, McpSchema.JsonSchema.class)) + .build(); + } + + public static final ObjectMapper mapper = new ObjectMapper() + .enable(JsonParser.Feature.ALLOW_COMMENTS) + .enable(JsonParser.Feature.ALLOW_SINGLE_QUOTES) + .enable(JsonParser.Feature.ALLOW_UNQUOTED_FIELD_NAMES) + .enable(JsonParser.Feature.INCLUDE_SOURCE_IN_LOCATION) + .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); +} diff --git a/src/main/java/ca/uhn/fhir/rest/server/McpBridge.java b/src/main/java/ca/uhn/fhir/rest/server/McpBridge.java new file mode 100644 index 0000000..b568ec2 --- /dev/null +++ b/src/main/java/ca/uhn/fhir/rest/server/McpBridge.java @@ -0,0 +1,9 @@ +package ca.uhn.fhir.rest.server; + +import io.modelcontextprotocol.server.McpServerFeatures; + +import java.util.List; + +public interface McpBridge { + List generateTools(); +} diff --git a/src/main/java/ca/uhn/fhir/rest/server/McpCdsBridge.java b/src/main/java/ca/uhn/fhir/rest/server/McpCdsBridge.java new file mode 100644 index 0000000..1fe1de7 --- /dev/null +++ b/src/main/java/ca/uhn/fhir/rest/server/McpCdsBridge.java @@ -0,0 +1,117 @@ +package ca.uhn.fhir.rest.server; + +import ca.uhn.fhir.context.FhirContext; +import ca.uhn.fhir.jpa.starter.cdshooks.CdsHooksRequest; +import ca.uhn.fhir.jpa.starter.mcp.Interaction; +import ca.uhn.fhir.jpa.starter.mcp.ToolFactory; +import ca.uhn.fhir.rest.api.server.cdshooks.CdsServiceRequestContextJson; +import ca.uhn.hapi.fhir.cdshooks.api.ICdsServiceRegistry; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.gson.Gson; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.spec.McpSchema; +import org.jetbrains.annotations.NotNull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Component; + +import java.util.List; +import java.util.Map; + +@Component +public class McpCdsBridge implements McpBridge { + + private static final Logger logger = LoggerFactory.getLogger(McpCdsBridge.class); + + private final ICdsServiceRegistry cdsServiceRegistry; + private final ObjectMapper objectMapper; + private final FhirContext fhirContext; + + public McpCdsBridge(FhirContext fhirContext, ICdsServiceRegistry cdsServiceRegistry, ObjectMapper objectMapper) { + this.fhirContext = fhirContext; + this.cdsServiceRegistry = cdsServiceRegistry; + this.objectMapper = objectMapper; + } + + public List generateTools() { + + try { + return List.of(new McpServerFeatures.SyncToolSpecification.Builder() + .tool(ToolFactory.callCdsHook()) + .callHandler((exchange, request) -> getToolResult(request, Interaction.CALL_CDS_HOOK)) + .build()); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + private McpSchema.CallToolResult getToolResult(McpSchema.CallToolRequest contextMap, Interaction interaction) { + + if (interaction != Interaction.CALL_CDS_HOOK) + throw new UnsupportedOperationException("Unsupported interaction: " + interaction); + + var cdsInvocation = constructCdsHooksRequest(contextMap); + var serviceResponseJson = cdsServiceRegistry.callService( + contextMap.arguments().get("service").toString(), cdsInvocation); + + final String content; + try { + content = objectMapper.writeValueAsString(serviceResponseJson); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + return McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent(content)) + .build(); + } + + private @NotNull CdsHooksRequest constructCdsHooksRequest(McpSchema.CallToolRequest callToolRequest) { + + // TODO Build up CDS Hooks request JSON from contextMap + var contextMap = callToolRequest.arguments(); + var request = new CdsHooksRequest(); + request.setHook(contextMap.get("hook").toString()); + request.setHookInstance(contextMap.get("hookInstance").toString()); + + // Context + var context = new CdsServiceRequestContextJson(); + Object hookContextObj = contextMap.get("hookContext"); + if (hookContextObj instanceof Map hookContext) { + if (hookContext.containsKey("userId")) { + context.put("userId", String.valueOf(hookContext.get("userId"))); + } + if (hookContext.containsKey("patientId")) { + context.put("patientId", String.valueOf(hookContext.get("patientId"))); + } + if (hookContext.containsKey("encounterId")) { + context.put("encounterId", String.valueOf(hookContext.get("encounterId"))); + } + } + request.setContext(context); + + // Prefetch + if (contextMap.containsKey("prefetch")) { + var prefetch = contextMap.get("prefetch"); + if (prefetch instanceof Map) { + @SuppressWarnings("unchecked") + var prefetchMap = (Map) prefetch; + for (Map.Entry entry : prefetchMap.entrySet()) { + var key = entry.getKey(); + var value = entry.getValue(); + + // Object is a String -> Object map + // Use a standard JSON library to convert it + var resource = fhirContext.newJsonParser().parseResource(new Gson().toJson(value)); + request.addPrefetch(key, resource); + } + } else { + logger.warn( + "Prefetch object is not a Map: {}", + prefetch == null ? "null" : prefetch.getClass().getName()); + } + } + + return request; + } +} diff --git a/src/main/java/ca/uhn/fhir/rest/server/McpFhirBridge.java b/src/main/java/ca/uhn/fhir/rest/server/McpFhirBridge.java new file mode 100644 index 0000000..13c5c3a --- /dev/null +++ b/src/main/java/ca/uhn/fhir/rest/server/McpFhirBridge.java @@ -0,0 +1,101 @@ +package ca.uhn.fhir.rest.server; + +import ca.uhn.fhir.context.FhirContext; +import ca.uhn.fhir.jpa.starter.mcp.CallToolResultFactory; +import ca.uhn.fhir.jpa.starter.mcp.Interaction; +import ca.uhn.fhir.jpa.starter.mcp.RequestBuilder; +import ca.uhn.fhir.jpa.starter.mcp.ToolFactory; +import com.fasterxml.jackson.core.JsonProcessingException; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.stereotype.Component; + +import java.util.List; + +@Component +public class McpFhirBridge implements McpBridge { + + private static final Logger logger = LoggerFactory.getLogger(McpFhirBridge.class); + + private final RestfulServer restfulServer; + private final FhirContext fhirContext; + + public McpFhirBridge(RestfulServer restfulServer) { + this.restfulServer = restfulServer; + this.fhirContext = restfulServer.getFhirContext(); + } + + public List generateTools() { + + try { + return List.of( + new McpServerFeatures.SyncToolSpecification.Builder() + .tool(ToolFactory.createFhirResource()) + .callHandler((exchange, request) -> getToolResult(request, Interaction.CREATE)) + .build(), + new McpServerFeatures.SyncToolSpecification.Builder() + .tool(ToolFactory.readFhirResource()) + .callHandler((exchange, request) -> getToolResult(request, Interaction.READ)) + .build(), + new McpServerFeatures.SyncToolSpecification.Builder() + .tool(ToolFactory.updateFhirResource()) + .callHandler((exchange, request) -> getToolResult(request, Interaction.UPDATE)) + .build(), + new McpServerFeatures.SyncToolSpecification.Builder() + .tool(ToolFactory.deleteFhirResource()) + .callHandler((exchange, request) -> getToolResult(request, Interaction.DELETE)) + .build(), + new McpServerFeatures.SyncToolSpecification.Builder() + .tool(ToolFactory.conditionalPatchFhirResource()) + .callHandler((exchange, request) -> getToolResult(request, Interaction.PATCH)) + .build(), + new McpServerFeatures.SyncToolSpecification.Builder() + .tool(ToolFactory.searchFhirResources()) + .callHandler((exchange, request) -> getToolResult(request, Interaction.SEARCH)) + .build(), + new McpServerFeatures.SyncToolSpecification.Builder() + .tool(ToolFactory.conditionalUpdateFhirResource()) + .callHandler((exchange, request) -> getToolResult(request, Interaction.UPDATE)) + .build(), + new McpServerFeatures.SyncToolSpecification.Builder() + .tool(ToolFactory.patchFhirResource()) + .callHandler((exchange, request) -> getToolResult(request, Interaction.PATCH)) + .build(), + new McpServerFeatures.SyncToolSpecification.Builder() + .tool(ToolFactory.createFhirTransaction()) + .callHandler((exchange, request) -> getToolResult(request, Interaction.TRANSACTION)) + .build()); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + private McpSchema.CallToolResult getToolResult(McpSchema.CallToolRequest contextMap, Interaction interaction) { + + var response = new MockHttpServletResponse(); + var request = new RequestBuilder(fhirContext, contextMap.arguments(), interaction).buildRequest(); + + try { + restfulServer.handleRequest(interaction.asRequestType(), request, response); + var status = response.getStatus(); + var body = response.getContentAsString(); + + if (status >= 200 && status < 300) { + if (body.isBlank()) { + return CallToolResultFactory.failure("Empty successful response for " + interaction); + } + + return CallToolResultFactory.success( + contextMap.arguments().get("resourceType").toString(), interaction, body, status); + } else { + return CallToolResultFactory.failure(String.format("FHIR server error %d: %s", status, body)); + } + } catch (Exception e) { + logger.error(e.getMessage(), e); + return CallToolResultFactory.failure("Unexpected error: " + e.getMessage()); + } + } +} diff --git a/src/main/resources/application.yaml b/src/main/resources/application.yaml index b769a62..662608a 100644 --- a/src/main/resources/application.yaml +++ b/src/main/resources/application.yaml @@ -36,7 +36,65 @@ management: export: enabled: true spring: + ai: + # Run e.g. `npx @modelcontextprotocol/inspector` and connect to http://localhost:8080/mcp/message using Streamable HTTP + +# Add the following to the MCP server settings file in e.g. cursor or claude (Desktop applications) for local debugging: +# cursor: +# { +# "mcpServers": { +# "hapi": { +# "url": "http://localhost:8080/mcp/message" +# } +# } +# } +# or claude: +# { +# "mcpServers": { +# "hapi": { +# "command": "npx", +# "args": [ +# "mcp-remote@latest", +# "http://localhost:8080/mcp/message" +# ] +# } +# } +# } + + mcp: + server: + # Will be enabled once spring-ai-starter-mcp-server is added as dependency +# name: FHIR MCP Server +# version: 1.0.0 +# type: SYNC +# instructions: "This server provides access to a FHIR RESTful API. You can use it to query FHIR resources, perform operations, and retrieve data in a structured format." +# sse-message-endpoint: /mcp/message +# capabilities: +# tool: true +# resource: true +# prompt: true +# completion: true +# stdio: false + enabled: true + + #endpoint: /mcp + + #schema: + # fhir-enabled: true + # fhir: + # base-url: http://localhost:8080/fhir + + #query: + # prompt: + # template: | + # You are a FHIR assistant. Translate the following question into a valid FHIR RESTful API query: + # "{{query}}" + # Use the provided FHIR schema: + # {{schema}} + #base-url: /api/v1 + main: + allow-bean-definition-overriding: false allow-circular-references: true flyway: enabled: false diff --git a/src/test/java/ca/uhn/fhir/jpa/starter/McpTests.java b/src/test/java/ca/uhn/fhir/jpa/starter/McpTests.java new file mode 100644 index 0000000..19a8a5f --- /dev/null +++ b/src/test/java/ca/uhn/fhir/jpa/starter/McpTests.java @@ -0,0 +1,78 @@ +package ca.uhn.fhir.jpa.starter; + +import ca.uhn.fhir.context.FhirContext; +import ca.uhn.fhir.jpa.searchparam.config.NicknameServiceConfig; +import ca.uhn.fhir.jpa.starter.mcp.ToolFactory; +import ca.uhn.fhir.util.BundleUtil; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.google.gson.Gson; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpSchema; +import org.hl7.fhir.r4.model.Bundle; +import org.junit.jupiter.api.Test; +import org.opencds.cqf.fhir.cr.hapi.config.RepositoryConfig; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.web.server.LocalServerPort; + +import java.time.Duration; +import java.util.LinkedHashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + + +@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT, classes = {Application.class, NicknameServiceConfig.class, RepositoryConfig.class}, properties = {"spring.datasource.url=jdbc:h2:mem:dbr4", "hapi.fhir.fhir_version=r4", "hibernate.search.enabled=true", "spring.ai.mcp.server.enabled=true",}) +public class McpTests { + + @LocalServerPort + private int port; + + @Test + public void mcpTests() throws JsonProcessingException { + + var fhirContext = FhirContext.forR4(); + + var transport = HttpClientStreamableHttpTransport.builder("http://localhost:" + port).endpoint("/mcp/message").build(); + var client = McpClient.sync(transport).requestTimeout(Duration.ofSeconds(10)).capabilities(McpSchema.ClientCapabilities.builder().roots(true) // Enable roots capability + .sampling().build()).build(); + var initializationResult = client.initialize(); + + var tools = client.listTools().tools(); + assertThat(tools).isNotEmpty(); + + var searchToolName = ToolFactory.searchFhirResources().name(); + var createToolName = ToolFactory.createFhirResource().name(); + + assertThat(tools.stream().filter(tool -> tool.name().equals(searchToolName)).findFirst().get()).isNotNull(); + assertThat(tools.stream().filter(tool -> tool.name().equals(createToolName)).findFirst().get()).isNotNull(); + + + var createMcpRequest = new McpSchema.CallToolRequest.Builder().arguments(Map.of("operation", "create", "resourceType", "Patient", "resource", """ + { + "resourceType": "Patient", + "id": "example", + "identifier": [ + { + "system": "urn:something", + "value": "uncleScrooge" + } + ] + }""")).name(createToolName).build(); + assertThat(client.callTool(createMcpRequest).isError()).isFalse(); + + var searchMcpRequest = new McpSchema.CallToolRequest.Builder().arguments(Map.of("operation", "search", "resourceType", "Patient", "query", "identifier=urn:something|uncleScrooge")).name(searchToolName).build(); + + var searchResult = client.callTool(searchMcpRequest); + assertThat(searchResult.isError()).isFalse(); + assertThat(searchResult.content().size()).isEqualTo(1); + + var content = ((McpSchema.TextContent) searchResult.content().get(0)); + var embeddedResponseBundle = new Gson().fromJson(content.text(), LinkedHashMap.class).get("response"); + var responseBundle = fhirContext.newJsonParser().parseResource(Bundle.class, embeddedResponseBundle.toString()); + var entries = BundleUtil.toListOfEntries(fhirContext, responseBundle); + assertThat(entries.size()).isEqualTo(1); + + client.closeGracefully(); + } +} diff --git a/src/test/resources/mcp/hello-patient-request.json b/src/test/resources/mcp/hello-patient-request.json new file mode 100644 index 0000000..2a26967 --- /dev/null +++ b/src/test/resources/mcp/hello-patient-request.json @@ -0,0 +1,18 @@ +{ + "hook": "patient-view", + "hookInstance": "8d5a3a2e-6d8b-4f7c-bb2d-2f1b8cf1d7a1", + "context": { + "userId": "Practitioner/123", + "patientId": "123", + "encounterId": "456" + }, + "prefetch": { + "item1": { + "resourceType": "Patient", + "gender": "male", + "birthDate": "1989-10-23", + "id": "123", + "active": true + } + } +} diff --git a/src/test/resources/mcp/mcp-hookContext-object.json b/src/test/resources/mcp/mcp-hookContext-object.json new file mode 100644 index 0000000..b4648e6 --- /dev/null +++ b/src/test/resources/mcp/mcp-hookContext-object.json @@ -0,0 +1,5 @@ +{ + "userId": "Practitioner/123", + "patientId": "123", + "encounterId": "456" +} diff --git a/src/test/resources/mcp/mpc-prefetch-object.json b/src/test/resources/mcp/mpc-prefetch-object.json new file mode 100644 index 0000000..4b30088 --- /dev/null +++ b/src/test/resources/mcp/mpc-prefetch-object.json @@ -0,0 +1,9 @@ +{ + "item1": { + "resourceType": "Patient", + "gender": "male", + "birthDate": "1989-10-23", + "id": "123", + "active": true + } +} diff --git a/src/test/resources/mcp/plandefinition-hello-patient.xml b/src/test/resources/mcp/plandefinition-hello-patient.xml new file mode 100644 index 0000000..a9c4aee --- /dev/null +++ b/src/test/resources/mcp/plandefinition-hello-patient.xml @@ -0,0 +1,42 @@ + + + + + + + + + + + + + + + + <type> + <coding> + <system value="http://terminology.hl7.org/CodeSystem/plan-definition-type" /> + <code value="eca-rule" /> + <display value="ECA Rule" /> + </coding> + </type> + <status value="draft" /> + <experimental value="true" /> + <date value="2024-09-28" /> + <description value="Demo PlanDefinition for Hello Patient" /> + <action> + <title value="Hello, Patient!" /> + <description value="Please state the nature of the medical emergency." /> + <trigger> + <type value="named-event" /> + <name value="patient-view" /> + </trigger> + <condition> + <kind value="applicability" /> + <expression> + <language value="text/cql" /> + <expression value="true" /> + </expression> + </condition> + </action> +</PlanDefinition>