Skip to content
77 changes: 76 additions & 1 deletion client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type Client struct {
serverCapabilities mcp.ServerCapabilities
protocolVersion string
samplingHandler SamplingHandler
elicitationHandler ElicitationHandler
}

type ClientOption func(*Client)
Expand All @@ -44,6 +45,14 @@ func WithSamplingHandler(handler SamplingHandler) ClientOption {
}
}

// WithElicitationHandler sets the elicitation handler for the client.
// When set, the client will declare elicitation capability during initialization.
func WithElicitationHandler(handler ElicitationHandler) ClientOption {
return func(c *Client) {
c.elicitationHandler = handler
}
}

// WithSession assumes a MCP Session has already been initialized
func WithSession() ClientOption {
return func(c *Client) {
Expand Down Expand Up @@ -174,6 +183,10 @@ func (c *Client) Initialize(
if c.samplingHandler != nil {
capabilities.Sampling = &struct{}{}
}
// Add elicitation capability if handler is configured
if c.elicitationHandler != nil {
capabilities.Elicitation = &struct{}{}
}

// Ensure we send a params object with all required fields
params := struct {
Expand Down Expand Up @@ -458,11 +471,15 @@ func (c *Client) Complete(
}

// handleIncomingRequest processes incoming requests from the server.
// This is the main entry point for server-to-client requests like sampling.
// This is the main entry point for server-to-client requests like sampling and elicitation.
func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
switch request.Method {
case string(mcp.MethodSamplingCreateMessage):
return c.handleSamplingRequestTransport(ctx, request)
case string(mcp.MethodElicitationCreate):
return c.handleElicitationRequestTransport(ctx, request)
case string(mcp.MethodPing):
return c.handlePingRequestTransport(ctx, request)
default:
return nil, fmt.Errorf("unsupported request method: %s", request.Method)
}
Expand Down Expand Up @@ -515,6 +532,64 @@ func (c *Client) handleSamplingRequestTransport(ctx context.Context, request tra

return response, nil
}

// handleElicitationRequestTransport handles elicitation requests at the transport level.
func (c *Client) handleElicitationRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
if c.elicitationHandler == nil {
return nil, fmt.Errorf("no elicitation handler configured")
}

// Parse the request parameters
var params mcp.ElicitationParams
if request.Params != nil {
paramsBytes, err := json.Marshal(request.Params)
if err != nil {
return nil, fmt.Errorf("failed to marshal params: %w", err)
}
if err := json.Unmarshal(paramsBytes, &params); err != nil {
return nil, fmt.Errorf("failed to unmarshal params: %w", err)
}
}

// Create the MCP request
mcpRequest := mcp.ElicitationRequest{
Request: mcp.Request{
Method: string(mcp.MethodElicitationCreate),
},
Params: params,
}

// Call the elicitation handler
result, err := c.elicitationHandler.Elicit(ctx, mcpRequest)
if err != nil {
return nil, err
}

// Marshal the result
resultBytes, err := json.Marshal(result)
if err != nil {
return nil, fmt.Errorf("failed to marshal result: %w", err)
}

// Create the transport response
response := &transport.JSONRPCResponse{
JSONRPC: mcp.JSONRPC_VERSION,
ID: request.ID,
Result: json.RawMessage(resultBytes),
}

return response, nil
}

func (c *Client) handlePingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
b, _ := json.Marshal(&mcp.EmptyResult{})
return &transport.JSONRPCResponse{
JSONRPC: mcp.JSONRPC_VERSION,
ID: request.ID,
Result: b,
}, nil
}

func listByPage[T any](
ctx context.Context,
client *Client,
Expand Down
19 changes: 19 additions & 0 deletions client/elicitation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package client

import (
"context"

"github.com/mark3labs/mcp-go/mcp"
)

// ElicitationHandler defines the interface for handling elicitation requests from servers.
// Clients can implement this interface to request additional information from users.
type ElicitationHandler interface {
// Elicit handles an elicitation request from the server and returns the user's response.
// The implementation should:
// 1. Present the request message to the user
// 2. Validate input against the requested schema
// 3. Allow the user to accept, decline, or cancel
// 4. Return the appropriate response
Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error)
}
225 changes: 225 additions & 0 deletions client/elicitation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
package client

import (
"context"
"encoding/json"
"fmt"
"testing"

"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
)

// mockElicitationHandler implements ElicitationHandler for testing
type mockElicitationHandler struct {
result *mcp.ElicitationResult
err error
}

func (m *mockElicitationHandler) Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) {
if m.err != nil {
return nil, m.err
}
return m.result, nil
}

func TestClient_HandleElicitationRequest(t *testing.T) {
tests := []struct {
name string
handler ElicitationHandler
expectedError string
}{
{
name: "no handler configured",
handler: nil,
expectedError: "no elicitation handler configured",
},
{
name: "successful elicitation - accept",
handler: &mockElicitationHandler{
result: &mcp.ElicitationResult{
ElicitationResponse: mcp.ElicitationResponse{
Action: mcp.ElicitationResponseActionAccept,
Content: map[string]any{
"name": "test-project",
"framework": "react",
},
},
},
},
},
{
name: "successful elicitation - decline",
handler: &mockElicitationHandler{
result: &mcp.ElicitationResult{
ElicitationResponse: mcp.ElicitationResponse{
Action: mcp.ElicitationResponseActionDecline,
},
},
},
},
{
name: "successful elicitation - cancel",
handler: &mockElicitationHandler{
result: &mcp.ElicitationResult{
ElicitationResponse: mcp.ElicitationResponse{
Action: mcp.ElicitationResponseActionCancel,
},
},
},
},
{
name: "handler returns error",
handler: &mockElicitationHandler{
err: fmt.Errorf("user interaction failed"),
},
expectedError: "user interaction failed",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := &Client{elicitationHandler: tt.handler}

request := transport.JSONRPCRequest{
ID: mcp.NewRequestId(1),
Method: string(mcp.MethodElicitationCreate),
Params: map[string]any{
"message": "Please provide project details",
"requestedSchema": map[string]any{
"type": "object",
"properties": map[string]any{
"name": map[string]any{"type": "string"},
"framework": map[string]any{"type": "string"},
},
},
},
}

result, err := client.handleElicitationRequestTransport(context.Background(), request)

if tt.expectedError != "" {
if err == nil {
t.Errorf("expected error %q, got nil", tt.expectedError)
} else if err.Error() != tt.expectedError {
t.Errorf("expected error %q, got %q", tt.expectedError, err.Error())
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if result == nil {
t.Error("expected result, got nil")
} else {
// Verify the response is properly formatted
var elicitationResult mcp.ElicitationResult
if err := json.Unmarshal(result.Result, &elicitationResult); err != nil {
t.Errorf("failed to unmarshal result: %v", err)
}
}
}
})
}
}

func TestWithElicitationHandler(t *testing.T) {
handler := &mockElicitationHandler{}
client := &Client{}

option := WithElicitationHandler(handler)
option(client)

if client.elicitationHandler != handler {
t.Error("elicitation handler not set correctly")
}
}

func TestClient_Initialize_WithElicitationHandler(t *testing.T) {
mockTransport := &mockElicitationTransport{
sendRequestFunc: func(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
// Verify that elicitation capability is included
// The client internally converts the typed params to a map for transport
// So we check if we're getting the initialize request
if request.Method != "initialize" {
t.Fatalf("expected initialize method, got %s", request.Method)
}

// Return successful initialization response
result := mcp.InitializeResult{
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
ServerInfo: mcp.Implementation{
Name: "test-server",
Version: "1.0.0",
},
Capabilities: mcp.ServerCapabilities{},
}

resultBytes, _ := json.Marshal(result)
return &transport.JSONRPCResponse{
ID: request.ID,
Result: json.RawMessage(resultBytes),
}, nil
},
sendNotificationFunc: func(ctx context.Context, notification mcp.JSONRPCNotification) error {
return nil
},
}

handler := &mockElicitationHandler{}
client := NewClient(mockTransport, WithElicitationHandler(handler))

err := client.Start(context.Background())
if err != nil {
t.Fatalf("failed to start client: %v", err)
}

_, err = client.Initialize(context.Background(), mcp.InitializeRequest{
Params: mcp.InitializeParams{
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
ClientInfo: mcp.Implementation{
Name: "test-client",
Version: "1.0.0",
},
Capabilities: mcp.ClientCapabilities{},
},
})

if err != nil {
t.Fatalf("failed to initialize: %v", err)
}
}

// mockElicitationTransport implements transport.Interface for testing
type mockElicitationTransport struct {
sendRequestFunc func(context.Context, transport.JSONRPCRequest) (*transport.JSONRPCResponse, error)
sendNotificationFunc func(context.Context, mcp.JSONRPCNotification) error
}

func (m *mockElicitationTransport) Start(ctx context.Context) error {
return nil
}

func (m *mockElicitationTransport) Close() error {
return nil
}

func (m *mockElicitationTransport) SendRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
if m.sendRequestFunc != nil {
return m.sendRequestFunc(ctx, request)
}
return nil, nil
}

func (m *mockElicitationTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
if m.sendNotificationFunc != nil {
return m.sendNotificationFunc(ctx, notification)
}
return nil
}

func (m *mockElicitationTransport) SetNotificationHandler(handler func(mcp.JSONRPCNotification)) {
}

func (m *mockElicitationTransport) GetSessionId() string {
return "mock-session"
}
Loading