Skip to content

Commit e5dff5b

Browse files
committed
[session-resources] add support to streamable_http
1 parent b27df1d commit e5dff5b

File tree

2 files changed

+50
-8
lines changed

2 files changed

+50
-8
lines changed

server/streamable_http.go

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ func WithTLSCert(certFile, keyFile string) StreamableHTTPOption {
129129
type StreamableHTTPServer struct {
130130
server *MCPServer
131131
sessionTools *sessionToolsStore
132+
sessionResources *sessionResourcesStore
132133
sessionRequestIDs sync.Map // sessionId --> last requestID(*atomic.Int64)
133134
activeSessions sync.Map // sessionId --> *streamableHttpSession (for sampling responses)
134135

@@ -155,6 +156,7 @@ func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *S
155156
endpointPath: "/mcp",
156157
sessionIdManager: &InsecureStatefulSessionIdManager{},
157158
logger: util.DefaultLogger(),
159+
sessionResources: newSessionResourcesStore(),
158160
}
159161

160162
// Apply all options
@@ -299,7 +301,7 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
299301
}
300302
}
301303

302-
session := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels)
304+
session := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionResources, s.sessionLogLevels)
303305

304306
// Set the client context before handling the message
305307
ctx := s.server.WithContext(r.Context(), session)
@@ -410,7 +412,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
410412
sessionID = uuid.New().String()
411413
}
412414

413-
session := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels)
415+
session := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionResources, s.sessionLogLevels)
414416
if err := s.server.RegisterSession(r.Context(), session); err != nil {
415417
http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest)
416418
return
@@ -716,6 +718,35 @@ func (s *sessionLogLevelsStore) delete(sessionID string) {
716718
delete(s.logs, sessionID)
717719
}
718720

721+
type sessionResourcesStore struct {
722+
mu sync.RWMutex
723+
resources map[string]map[string]ServerResource // sessionID -> resourceURI -> resource
724+
}
725+
726+
func newSessionResourcesStore() *sessionResourcesStore {
727+
return &sessionResourcesStore{
728+
resources: make(map[string]map[string]ServerResource),
729+
}
730+
}
731+
732+
func (s *sessionResourcesStore) get(sessionID string) map[string]ServerResource {
733+
s.mu.RLock()
734+
defer s.mu.RUnlock()
735+
return s.resources[sessionID]
736+
}
737+
738+
func (s *sessionResourcesStore) set(sessionID string, resources map[string]ServerResource) {
739+
s.mu.Lock()
740+
defer s.mu.Unlock()
741+
s.resources[sessionID] = resources
742+
}
743+
744+
func (s *sessionResourcesStore) delete(sessionID string) {
745+
s.mu.Lock()
746+
defer s.mu.Unlock()
747+
delete(s.resources, sessionID)
748+
}
749+
719750
type sessionToolsStore struct {
720751
mu sync.RWMutex
721752
tools map[string]map[string]ServerTool // sessionID -> toolName -> tool
@@ -765,6 +796,7 @@ type streamableHttpSession struct {
765796
sessionID string
766797
notificationChannel chan mcp.JSONRPCNotification // server -> client notifications
767798
tools *sessionToolsStore
799+
resources *sessionResourcesStore
768800
upgradeToSSE atomic.Bool
769801
logLevels *sessionLogLevelsStore
770802

@@ -774,11 +806,12 @@ type streamableHttpSession struct {
774806
requestIDCounter atomic.Int64 // for generating unique request IDs
775807
}
776808

777-
func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, levels *sessionLogLevelsStore) *streamableHttpSession {
809+
func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, resourcesStore *sessionResourcesStore, levels *sessionLogLevelsStore) *streamableHttpSession {
778810
s := &streamableHttpSession{
779811
sessionID: sessionID,
780812
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
781813
tools: toolStore,
814+
resources: resourcesStore,
782815
logLevels: levels,
783816
samplingRequestChan: make(chan samplingRequestItem, 10),
784817
}
@@ -821,9 +854,18 @@ func (s *streamableHttpSession) SetSessionTools(tools map[string]ServerTool) {
821854
s.tools.set(s.sessionID, tools)
822855
}
823856

857+
func (s *streamableHttpSession) GetSessionResources() map[string]ServerResource {
858+
return s.resources.get(s.sessionID)
859+
}
860+
861+
func (s *streamableHttpSession) SetSessionResources(resources map[string]ServerResource) {
862+
s.resources.set(s.sessionID, resources)
863+
}
864+
824865
var (
825-
_ SessionWithTools = (*streamableHttpSession)(nil)
826-
_ SessionWithLogging = (*streamableHttpSession)(nil)
866+
_ SessionWithTools = (*streamableHttpSession)(nil)
867+
_ SessionWithResources = (*streamableHttpSession)(nil)
868+
_ SessionWithLogging = (*streamableHttpSession)(nil)
827869
)
828870

829871
func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() {

server/streamable_http_sampling_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func TestStreamableHTTPServer_SamplingBasic(t *testing.T) {
2626

2727
// Test session creation and interface implementation
2828
sessionID := "test-session"
29-
session := newStreamableHttpSession(sessionID, httpServer.sessionTools, httpServer.sessionLogLevels)
29+
session := newStreamableHttpSession(sessionID, httpServer.sessionTools, httpServer.sessionResources, httpServer.sessionLogLevels)
3030

3131
// Verify it implements SessionWithSampling
3232
_, ok := any(session).(SessionWithSampling)
@@ -139,7 +139,7 @@ func TestStreamableHTTPServer_SamplingInterface(t *testing.T) {
139139

140140
// Create a session
141141
sessionID := "test-session"
142-
session := newStreamableHttpSession(sessionID, httpServer.sessionTools, httpServer.sessionLogLevels)
142+
session := newStreamableHttpSession(sessionID, httpServer.sessionTools, httpServer.sessionResources, httpServer.sessionLogLevels)
143143

144144
// Verify it implements SessionWithSampling
145145
_, ok := any(session).(SessionWithSampling)
@@ -178,7 +178,7 @@ func TestStreamableHTTPServer_SamplingInterface(t *testing.T) {
178178
// TestStreamableHTTPServer_SamplingQueueFull tests queue overflow scenarios
179179
func TestStreamableHTTPServer_SamplingQueueFull(t *testing.T) {
180180
sessionID := "test-session"
181-
session := newStreamableHttpSession(sessionID, nil, nil)
181+
session := newStreamableHttpSession(sessionID, nil, nil, nil)
182182

183183
// Fill the sampling request queue
184184
for i := 0; i < cap(session.samplingRequestChan); i++ {

0 commit comments

Comments
 (0)