Skip to content

Commit

Permalink
Cleanup API
Browse files Browse the repository at this point in the history
  • Loading branch information
ignatov committed Dec 23, 2024
1 parent be6e473 commit a36c050
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public open class Server(
* Called when the server connection is closing.
* Invokes [onCloseCallback] if set.
*/
override fun onclose() {
override fun onClose() {
logger.info { "Server connection closing" }
onCloseCallback?.invoke()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public open class ProtocolOptions(
* Note that this DOES NOT affect checking of _local_ side capabilities, as it is
* considered a logic error to mis-specify those.
*
* Currently this defaults to false, for backwards compatibility with SDK versions
* Currently, this defaults to false, for backwards compatibility with SDK versions
* that did not advertise capabilities correctly. In future, this will default to true.
*/
public var enforceStrictCapabilities: Boolean = false,
Expand Down Expand Up @@ -114,14 +114,14 @@ public abstract class Protocol(
*
* This is invoked when close() is called as well.
*/
public open fun onclose() {}
public open fun onClose() {}

/**
* Callback for when an error occurs.
*
* Note that errors are not necessarily fatal they are used for reporting any kind of exceptional condition out of band.
*/
public open fun onerror(error: Throwable) {}
public open fun onError(error: Throwable) {}

/**
* A handler to invoke for any request types that do not have their own handler installed.
Expand All @@ -136,7 +136,7 @@ public abstract class Protocol(

init {
setNotificationHandler<ProgressNotification>(Method.Defined.NotificationsProgress) { notification ->
this.onProgress(notification)
onProgress(notification)
COMPLETED
}

Expand All @@ -153,11 +153,11 @@ public abstract class Protocol(
public open suspend fun connect(transport: Transport) {
this.transport = transport
transport.onClose = {
this.onClose()
doClose()
}

transport.onError = {
this.onError(it)
onError(it)
}

transport.onMessage = { message ->
Expand All @@ -172,22 +172,18 @@ public abstract class Protocol(
return transport.start()
}

private fun onClose() {
private fun doClose() {
responseHandlers.clear()
progressHandlers.clear()
transport = null
onclose()
onClose()

val error = McpError(ErrorCode.Defined.ConnectionClosed.code, "Connection closed")
for (handler in responseHandlers.values) {
handler(null, error)
}
}

private fun onError(error: Throwable) {
onerror(error)
}

private suspend fun onNotification(notification: JSONRPCNotification) {
LOGGER.trace { "Received notification: ${notification.method}" }
val function = notificationHandlers[notification.method]
Expand All @@ -208,7 +204,7 @@ public abstract class Protocol(

private suspend fun onRequest(request: JSONRPCRequest) {
LOGGER.trace { "Received request: ${request.method} (id: ${request.id})" }
val handler = requestHandlers[request.method] ?: this.fallbackRequestHandler
val handler = requestHandlers[request.method] ?: fallbackRequestHandler

if (handler === null) {
LOGGER.trace { "No handler found for request: ${request.method}" }
Expand Down Expand Up @@ -260,13 +256,13 @@ public abstract class Protocol(
val total = notification.total
val progressToken = notification.progressToken

val handler = this.progressHandlers[progressToken]
val handler = progressHandlers[progressToken]
if (handler == null) {
val error = Error(
"Received a progress notification for an unknown token: ${McpJson.encodeToString(notification)}",
)
LOGGER.error { error.message }
this.onError(error)
onError(error)
return
}

Expand All @@ -275,14 +271,14 @@ public abstract class Protocol(

private fun onResponse(response: JSONRPCResponse?, error: JSONRPCError?) {
val messageId = response?.id
val handler = this.responseHandlers[messageId]
val handler = responseHandlers[messageId]
if (handler == null) {
this.onError(Error("Received a response for an unknown message ID: ${McpJson.encodeToString(response)}"))
onError(Error("Received a response for an unknown message ID: ${McpJson.encodeToString(response)}"))
return
}

this.responseHandlers.remove(messageId)
this.progressHandlers.remove(messageId)
responseHandlers.remove(messageId)
progressHandlers.remove(messageId)
if (response != null) {
handler(response, null)
} else {
Expand Down Expand Up @@ -469,7 +465,7 @@ public abstract class Protocol(
* Note that this will replace any previous notification handler for the same method.
*/
public fun <T : Notification> setNotificationHandler(method: Method, handler: (notification: T) -> Deferred<Unit>) {
this.notificationHandlers[method.value] = {
notificationHandlers[method.value] = {
@Suppress("UNCHECKED_CAST")
handler(it.fromJSON() as T)
}
Expand All @@ -479,6 +475,6 @@ public abstract class Protocol(
* Removes the notification handler for the given method.
*/
public fun removeNotificationHandler(method: Method) {
this.notificationHandlers.remove(method.value)
notificationHandlers.remove(method.value)
}
}

0 comments on commit a36c050

Please sign in to comment.