Skip to content

Commit da6f722

Browse files
authored
feat(mcp): Add SessionIdManagerResolver interface for request-based session management (#626)
* Add SessionIdManagerResolver * Rabbit AI Cr Comments * Nitpicks * Nitpicks * more tests
1 parent 5088c93 commit da6f722

File tree

2 files changed

+392
-18
lines changed

2 files changed

+392
-18
lines changed

server/streamable_http.go

Lines changed: 62 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,23 +41,41 @@ func WithEndpointPath(endpointPath string) StreamableHTTPOption {
4141
// as a new session. No session id returned to the client.
4242
// The default is false.
4343
//
44-
// Notice: This is a convenience method. It's identical to set WithSessionIdManager option
44+
// Note: This is a convenience method. It's identical to set WithSessionIdManager option
4545
// to StatelessSessionIdManager.
4646
func WithStateLess(stateLess bool) StreamableHTTPOption {
4747
return func(s *StreamableHTTPServer) {
4848
if stateLess {
49-
s.sessionIdManager = &StatelessSessionIdManager{}
49+
s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&StatelessSessionIdManager{})
5050
}
5151
}
5252
}
5353

5454
// WithSessionIdManager sets a custom session id generator for the server.
55-
// By default, the server will use SimpleStatefulSessionIdGenerator, which generates
56-
// session ids with uuid, and it's insecure.
57-
// Notice: it will override the WithStateLess option.
55+
// By default, the server uses InsecureStatefulSessionIdManager (UUID-based; insecure).
56+
// Note: Options are applied in order; the last one wins. If combined with
57+
// WithStateLess or WithSessionIdManagerResolver, whichever is applied last takes effect.
5858
func WithSessionIdManager(manager SessionIdManager) StreamableHTTPOption {
5959
return func(s *StreamableHTTPServer) {
60-
s.sessionIdManager = manager
60+
if manager == nil {
61+
s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&InsecureStatefulSessionIdManager{})
62+
return
63+
}
64+
s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(manager)
65+
}
66+
}
67+
68+
// WithSessionIdManagerResolver sets a custom session id manager resolver for the server.
69+
// This allows for request-based session id management strategies.
70+
// Note: Options are applied in order; the last one wins. If combined with
71+
// WithStateLess or WithSessionIdManager, whichever is applied last takes effect.
72+
func WithSessionIdManagerResolver(resolver SessionIdManagerResolver) StreamableHTTPOption {
73+
return func(s *StreamableHTTPServer) {
74+
if resolver == nil {
75+
s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&InsecureStatefulSessionIdManager{})
76+
return
77+
}
78+
s.sessionIdManagerResolver = resolver
6179
}
6280
}
6381

@@ -150,13 +168,13 @@ type StreamableHTTPServer struct {
150168
httpServer *http.Server
151169
mu sync.RWMutex
152170

153-
endpointPath string
154-
contextFunc HTTPContextFunc
155-
sessionIdManager SessionIdManager
156-
listenHeartbeatInterval time.Duration
157-
logger util.Logger
158-
sessionLogLevels *sessionLogLevelsStore
159-
disableStreaming bool
171+
endpointPath string
172+
contextFunc HTTPContextFunc
173+
sessionIdManagerResolver SessionIdManagerResolver
174+
listenHeartbeatInterval time.Duration
175+
logger util.Logger
176+
sessionLogLevels *sessionLogLevelsStore
177+
disableStreaming bool
160178

161179
tlsCertFile string
162180
tlsKeyFile string
@@ -169,7 +187,7 @@ func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *S
169187
sessionTools: newSessionToolsStore(),
170188
sessionLogLevels: newSessionLogLevelsStore(),
171189
endpointPath: "/mcp",
172-
sessionIdManager: &InsecureStatefulSessionIdManager{},
190+
sessionIdManagerResolver: NewDefaultSessionIdManagerResolver(&InsecureStatefulSessionIdManager{}),
173191
logger: util.DefaultLogger(),
174192
sessionResources: newSessionResourcesStore(),
175193
sessionResourceTemplates: newSessionResourceTemplatesStore(),
@@ -307,14 +325,15 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
307325
// The session is ephemeral. Its life is the same as the request. It's only created
308326
// for interaction with the mcp server.
309327
var sessionID string
328+
sessionIdManager := s.sessionIdManagerResolver.ResolveSessionIdManager(r)
310329
if isInitializeRequest {
311330
// generate a new one for initialize request
312-
sessionID = s.sessionIdManager.Generate()
331+
sessionID = sessionIdManager.Generate()
313332
} else {
314333
// Get session ID from header.
315334
// Stateful servers need the client to carry the session ID.
316335
sessionID = r.Header.Get(HeaderKeySessionID)
317-
isTerminated, err := s.sessionIdManager.Validate(sessionID)
336+
isTerminated, err := sessionIdManager.Validate(sessionID)
318337
if err != nil {
319338
http.Error(w, "Invalid session ID", http.StatusBadRequest)
320339
return
@@ -611,7 +630,8 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
611630
func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Request) {
612631
// delete request terminate the session
613632
sessionID := r.Header.Get(HeaderKeySessionID)
614-
notAllowed, err := s.sessionIdManager.Terminate(sessionID)
633+
sessionIdManager := s.sessionIdManagerResolver.ResolveSessionIdManager(r)
634+
notAllowed, err := sessionIdManager.Terminate(sessionID)
615635
if err != nil {
616636
http.Error(w, fmt.Sprintf("Session termination failed: %v", err), http.StatusInternalServerError)
617637
return
@@ -659,7 +679,8 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *
659679
}
660680

661681
// Validate session
662-
isTerminated, err := s.sessionIdManager.Validate(sessionID)
682+
sessionIdManager := s.sessionIdManagerResolver.ResolveSessionIdManager(r)
683+
isTerminated, err := sessionIdManager.Validate(sessionID)
663684
if err != nil {
664685
http.Error(w, "Invalid session ID", http.StatusBadRequest)
665686
return err
@@ -1128,6 +1149,11 @@ var _ SessionWithElicitation = (*streamableHttpSession)(nil)
11281149

11291150
// --- session id manager ---
11301151

1152+
// SessionIdManagerResolver resolves a SessionIdManager based on the HTTP request
1153+
type SessionIdManagerResolver interface {
1154+
ResolveSessionIdManager(r *http.Request) SessionIdManager
1155+
}
1156+
11311157
type SessionIdManager interface {
11321158
Generate() string
11331159
// Validate checks if a session ID is valid and not terminated.
@@ -1140,6 +1166,24 @@ type SessionIdManager interface {
11401166
Terminate(sessionID string) (isNotAllowed bool, err error)
11411167
}
11421168

1169+
// DefaultSessionIdManagerResolver is a simple resolver that returns the same SessionIdManager for all requests
1170+
type DefaultSessionIdManagerResolver struct {
1171+
manager SessionIdManager
1172+
}
1173+
1174+
// NewDefaultSessionIdManagerResolver creates a new DefaultSessionIdManagerResolver with the given SessionIdManager
1175+
func NewDefaultSessionIdManagerResolver(manager SessionIdManager) *DefaultSessionIdManagerResolver {
1176+
if manager == nil {
1177+
manager = &InsecureStatefulSessionIdManager{}
1178+
}
1179+
return &DefaultSessionIdManagerResolver{manager: manager}
1180+
}
1181+
1182+
// ResolveSessionIdManager returns the configured SessionIdManager for all requests
1183+
func (r *DefaultSessionIdManagerResolver) ResolveSessionIdManager(_ *http.Request) SessionIdManager {
1184+
return r.manager
1185+
}
1186+
11431187
// StatelessSessionIdManager does nothing, which means it has no session management, which is stateless.
11441188
type StatelessSessionIdManager struct{}
11451189

0 commit comments

Comments
 (0)