Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
527 changes: 527 additions & 0 deletions docs/architecture.md

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion docs/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ tasks.matching {

knit {
rootDir = project.rootDir
files = files(project.rootDir.resolve("README.md"))
files = files(
project.rootDir.resolve("README.md"),
project.rootDir.resolve("docs/architecture.md"),
)
defaultLineSeparator = "\n"
siteRoot = "" // Disable site root validation
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package io.modelcontextprotocol.kotlin.sdk.client

import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
import io.modelcontextprotocol.kotlin.sdk.server.Server
import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions
import io.modelcontextprotocol.kotlin.sdk.server.ServerSession
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
import io.modelcontextprotocol.kotlin.sdk.shared.InMemoryTransport
import io.modelcontextprotocol.kotlin.sdk.shared.TransportSendOptions
import io.modelcontextprotocol.kotlin.sdk.types.BooleanSchema
import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult
import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities
import io.modelcontextprotocol.kotlin.sdk.types.CreateMessageRequest
import io.modelcontextprotocol.kotlin.sdk.types.CreateMessageResult
Expand Down Expand Up @@ -47,6 +49,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema
import io.modelcontextprotocol.kotlin.sdk.types.UntitledMultiSelectEnumSchema
import io.modelcontextprotocol.kotlin.sdk.types.UntitledSingleSelectEnumSchema
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.cancel
import kotlinx.coroutines.delay
Expand All @@ -67,6 +70,8 @@ import kotlin.test.assertIs
import kotlin.test.assertNull
import kotlin.test.assertTrue
import kotlin.test.fail
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds

class ClientTest {
@Test
Expand Down Expand Up @@ -387,7 +392,10 @@ class ClientTest {
val ex = assertFailsWith<IllegalStateException> {
client.listPrompts()
}
assertTrue(ex.message?.contains("Server does not support prompts") == true)
ex.message shouldContain "Server does not support prompts"

client.close()
server.close()
}

@Test
Expand Down Expand Up @@ -441,6 +449,10 @@ class ClientTest {
clientWithoutCapability.sendRootsListChanged()
}
assertTrue(ex.message?.startsWith("Client does not support") == true)

client.close()
clientWithoutCapability.close()
server.close()
}

@Test
Expand Down Expand Up @@ -499,6 +511,9 @@ class ClientTest {
serverSession.sendToolListChanged()
}
assertTrue(ex.message?.contains("Server does not support notifying of tool list changes") == true)

client.close()
server.close()
}

@Test
Expand Down Expand Up @@ -543,7 +558,7 @@ class ClientTest {
// Simulate delay
def.complete(Unit)
try {
delay(1000)
delay(1.seconds)
} catch (e: CancellationException) {
defTimeOut.complete(Unit)
throw e
Expand All @@ -565,6 +580,9 @@ class ClientTest {
runCatching { job.cancel("Cancelled by test") }
defCancel.await()
defTimeOut.await()

client.close()
server.close()
}

@Test
Expand Down Expand Up @@ -602,26 +620,20 @@ class ClientTest {

val serverSession = serverSessionResult.await()
serverSession.setRequestHandler<ListResourcesRequest>(Method.Defined.ResourcesList) { _, _ ->
// Simulate a delayed response
// Wait ~100ms unless canceled
try {
withTimeout(100L) {
// Just delay here, if timeout is 0 on the client side, this won't return in time
delay(100)
}
} catch (_: Exception) {
// If aborted, just rethrow or return early
}
delay(100.milliseconds)
ListResourcesResult(resources = emptyList())
}

// Request with 1 msec timeout should fail immediately
val ex = assertFailsWith<Exception> {
withTimeout(1) {
withTimeout(1.milliseconds) {
client.listResources()
}
}
assertTrue(ex is TimeoutCancellationException)

client.close()
server.close()
}

@Test
Expand Down Expand Up @@ -654,88 +666,43 @@ class ClientTest {
}

@Test
fun `JSONRPCRequest with ToolsList method and default params returns list of tools`() = runTest {
val serverOptions = ServerOptions(
capabilities = ServerCapabilities(
tools = ServerCapabilities.Tools(null),
fun `listTools returns list of tools`() = runTest {
val expectedTools = listOf(
Tool(
name = "testTool",
title = "testTool title",
description = "testTool description",
annotations = null,
inputSchema = ToolSchema(),
outputSchema = null,
),
)

val server = Server(
Implementation(name = "test server", version = "1.0"),
serverOptions,
)
ServerOptions(capabilities = ServerCapabilities(tools = ServerCapabilities.Tools(null))),
) {
addTool(expectedTools[0]) { CallToolResult(content = emptyList()) }
}

val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair()

val client = Client(
clientInfo = Implementation(name = "test client", version = "1.0"),
options = ClientOptions(
capabilities = ClientCapabilities(sampling = EmptyJsonObject),
),
)

var receivedMessage: JSONRPCMessage? = null
clientTransport.onMessage { msg ->
receivedMessage = msg
}

val serverSessionResult = CompletableDeferred<ServerSession>()

listOf(
launch {
client.connect(clientTransport)
println("Client connected")
},
launch {
serverSessionResult.complete(server.createSession(serverTransport))
println("Server connected")
},
).joinAll()

val serverSession = serverSessionResult.await()

serverSession.setRequestHandler<InitializeRequest>(Method.Defined.Initialize) { _, _ ->
InitializeResult(
protocolVersion = LATEST_PROTOCOL_VERSION,
capabilities = ServerCapabilities(
resources = ServerCapabilities.Resources(null, null),
tools = ServerCapabilities.Tools(null),
),
serverInfo = Implementation(name = "test", version = "1.0"),
)
}

val serverListToolsResult = ListToolsResult(
tools = listOf(
Tool(
name = "testTool",
title = "testTool title",
description = "testTool description",
annotations = null,
inputSchema = ToolSchema(),
outputSchema = null,
),
),
nextCursor = null,
options = ClientOptions(),
)

serverSession.setRequestHandler<ListToolsRequest>(Method.Defined.ToolsList) { _, _ ->
serverListToolsResult
}
connectClientServer(client, server, clientTransport, serverTransport)

val serverCapabilities = client.serverCapabilities
assertEquals(ServerCapabilities.Tools(null), serverCapabilities?.tools)
val result = client.listTools()

val request = JSONRPCRequest(
method = Method.Defined.ToolsList.value,
)
clientTransport.send(request)
assertEquals(expectedTools.size, result.tools.size)
assertEquals(expectedTools[0].name, result.tools[0].name)
assertEquals(expectedTools[0].title, result.tools[0].title)
assertEquals(expectedTools[0].description, result.tools[0].description)

assertIs<JSONRPCResponse>(receivedMessage)
val receivedAsResponse = receivedMessage as JSONRPCResponse
assertEquals(request.id, receivedAsResponse.id)
assertEquals(request.jsonrpc, receivedAsResponse.jsonrpc)
assertEquals(serverListToolsResult, receivedAsResponse.result)
client.close()
server.close()
}

@Test
Expand Down Expand Up @@ -785,6 +752,9 @@ class ClientTest {
val listRootsResult = serverSession.listRoots()

assertEquals(listRootsResult.roots, clientRoots)

client.close()
server.close()
}

@Test
Expand Down Expand Up @@ -923,6 +893,9 @@ class ClientTest {
rootListChangedNotificationReceived,
"Notification should be sent when sendRootsListChanged is called",
)

client.close()
server.close()
}

@Test
Expand Down Expand Up @@ -972,6 +945,9 @@ class ClientTest {
"Client does not support elicitation (required for elicitation/create)",
exception.message,
)

client.close()
server.close()
}

@Test
Expand Down Expand Up @@ -1039,7 +1015,7 @@ class ClientTest {
)
}

delay(100)
delay(100.milliseconds)

// Only warning and error should be received
assertEquals(2, receivedMessages.size, "Should receive only 2 messages (warning and error)")
Expand All @@ -1052,6 +1028,9 @@ class ClientTest {
"Received message with level ${message.params.level} should have severity >= $minLevel",
)
}

client.close()
server.close()
}

@Test
Expand Down Expand Up @@ -1118,6 +1097,9 @@ class ClientTest {

assertEquals(elicitationResultAction, result.action)
assertEquals(elicitationResultContent, result.content)

client.close()
server.close()
}

@Test
Expand Down Expand Up @@ -1290,4 +1272,18 @@ class ClientTest {

client to serverSessionResult.await()
}

private suspend fun CoroutineScope.connectClientServer(
client: Client,
server: Server,
clientTransport: InMemoryTransport,
serverTransport: InMemoryTransport,
): ServerSession {
val result = CompletableDeferred<ServerSession>()
listOf(
launch { client.connect(clientTransport) },
launch { result.complete(server.createSession(serverTransport)) },
).joinAll()
return result.await()
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.modelcontextprotocol.kotlin.sdk.server

import io.kotest.matchers.collections.shouldHaveSize
import io.kotest.matchers.shouldBe
import io.modelcontextprotocol.kotlin.sdk.shared.InMemoryTransport
import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities
import io.modelcontextprotocol.kotlin.sdk.types.Implementation
Expand All @@ -23,7 +25,6 @@ import java.util.concurrent.CopyOnWriteArrayList
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.assertNull
import kotlin.test.assertTrue

class ServerSessionInitializeTest {

Expand Down Expand Up @@ -94,13 +95,16 @@ class ServerSessionInitializeTest {

secondResponseDone.await()

assertEquals(2, responses.size)
assertTrue(responses[0] is JSONRPCResponse, "First response should be success")
assertTrue(responses[1] is JSONRPCError, "Second response should be error")
assertEquals(RPCError.ErrorCode.INVALID_REQUEST, (responses[1] as JSONRPCError).error.code)
responses.size shouldBe 2
// With concurrent dispatch, responses may arrive in any order
val successResponses = responses.filterIsInstance<JSONRPCResponse>()
val errorResponses = responses.filterIsInstance<JSONRPCError>()
successResponses.shouldHaveSize(1)
errorResponses.shouldHaveSize(1)
errorResponses[0].error.code shouldBe RPCError.ErrorCode.INVALID_REQUEST

// Capabilities still reflect the first client, not overwritten
assertEquals("first-client", session.clientVersion?.name)
session.clientVersion?.name shouldBe "first-client"
}

@Test
Expand Down
Loading
Loading