Skip to content

Commit b90e674

Browse files
authored
Add support for a base path (#30)
Closes #23
1 parent 19d84bf commit b90e674

File tree

8 files changed

+253
-33
lines changed

8 files changed

+253
-33
lines changed

integration/base_path_test.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package integration
2+
3+
import (
4+
"net/http"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestBasePathRouting(t *testing.T) {
12+
waitForDB(t)
13+
14+
startMCPFront(t, "config/config.base-path-test.json")
15+
waitForMCPFront(t)
16+
17+
initialContainers := getMCPContainers()
18+
t.Cleanup(func() {
19+
cleanupContainers(t, initialContainers)
20+
})
21+
22+
t.Run("health at root", func(t *testing.T) {
23+
resp, err := http.Get("http://localhost:8080/health")
24+
require.NoError(t, err)
25+
defer resp.Body.Close()
26+
27+
assert.Equal(t, http.StatusOK, resp.StatusCode)
28+
})
29+
30+
t.Run("health not under base path", func(t *testing.T) {
31+
resp, err := http.Get("http://localhost:8080/mcp-api/health")
32+
require.NoError(t, err)
33+
defer resp.Body.Close()
34+
35+
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
36+
})
37+
38+
t.Run("MCP server at base path", func(t *testing.T) {
39+
req, err := http.NewRequest("GET", "http://localhost:8080/mcp-api/postgres/sse", nil)
40+
require.NoError(t, err)
41+
req.Header.Set("Authorization", "Bearer test-token")
42+
req.Header.Set("Accept", "text/event-stream")
43+
44+
client := &http.Client{}
45+
resp, err := client.Do(req)
46+
require.NoError(t, err)
47+
defer resp.Body.Close()
48+
49+
assert.Equal(t, http.StatusOK, resp.StatusCode)
50+
assert.Equal(t, "text/event-stream", resp.Header.Get("Content-Type"))
51+
})
52+
53+
t.Run("MCP server not at root", func(t *testing.T) {
54+
req, err := http.NewRequest("GET", "http://localhost:8080/postgres/sse", nil)
55+
require.NoError(t, err)
56+
req.Header.Set("Authorization", "Bearer test-token")
57+
req.Header.Set("Accept", "text/event-stream")
58+
59+
client := &http.Client{}
60+
resp, err := client.Do(req)
61+
require.NoError(t, err)
62+
defer resp.Body.Close()
63+
64+
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
65+
})
66+
67+
t.Run("full MCP connection", func(t *testing.T) {
68+
client := NewMCPSSEClient("http://localhost:8080/mcp-api")
69+
client.SetAuthToken("test-token")
70+
71+
err := client.ConnectToServer("postgres")
72+
require.NoError(t, err)
73+
defer client.Close()
74+
75+
result, err := client.SendMCPRequest("tools/list", map[string]any{})
76+
require.NoError(t, err)
77+
assert.NotNil(t, result)
78+
})
79+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
{
2+
"version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES",
3+
"proxy": {
4+
"baseURL": "http://localhost:8080/mcp-api",
5+
"addr": ":8080",
6+
"name": "mcp-front-base-path-test"
7+
},
8+
"mcpServers": {
9+
"postgres": {
10+
"transportType": "stdio",
11+
"command": "docker",
12+
"args": [
13+
"run",
14+
"-i",
15+
"--network",
16+
"host",
17+
"mcp/postgres",
18+
"postgresql://testuser:testpass@localhost:15432/testdb"
19+
],
20+
"serviceAuths": [
21+
{
22+
"type": "bearer",
23+
"tokens": ["test-token"]
24+
}
25+
]
26+
}
27+
}
28+
}

internal/config/load.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package config
33
import (
44
"encoding/json"
55
"fmt"
6+
"net/url"
67
"os"
78
"strings"
89

@@ -40,6 +41,11 @@ func Load(path string) (Config, error) {
4041
return Config{}, fmt.Errorf("parsing config: %w", err)
4142
}
4243

44+
// Extract base path from baseURL
45+
if err := extractBasePath(&config); err != nil {
46+
return Config{}, fmt.Errorf("extracting base path: %w", err)
47+
}
48+
4349
if err := ValidateConfig(&config); err != nil {
4450
return Config{}, fmt.Errorf("config validation failed: %w", err)
4551
}
@@ -208,3 +214,31 @@ func validateMCPServer(name string, server *MCPClientConfig) error {
208214

209215
return nil
210216
}
217+
218+
func extractBasePath(config *Config) error {
219+
u, err := url.Parse(config.Proxy.BaseURL)
220+
if err != nil {
221+
return fmt.Errorf("invalid baseURL: %w", err)
222+
}
223+
224+
basePath := u.Path
225+
if basePath == "" {
226+
basePath = "/"
227+
}
228+
229+
if !strings.HasPrefix(basePath, "/") {
230+
basePath = "/" + basePath
231+
}
232+
if len(basePath) > 1 && strings.HasSuffix(basePath, "/") {
233+
basePath = strings.TrimSuffix(basePath, "/")
234+
}
235+
236+
config.Proxy.BasePath = basePath
237+
238+
log.LogInfoWithFields("config", "Extracted base path from baseURL", map[string]any{
239+
"baseURL": config.Proxy.BaseURL,
240+
"basePath": basePath,
241+
})
242+
243+
return nil
244+
}

internal/config/load_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,67 @@ func TestValidateConfig_SessionConfig(t *testing.T) {
215215
})
216216
}
217217
}
218+
219+
func TestExtractBasePath(t *testing.T) {
220+
tests := []struct {
221+
name string
222+
baseURL string
223+
expectedPath string
224+
expectError bool
225+
}{
226+
{
227+
name: "root_path",
228+
baseURL: "http://localhost:8080",
229+
expectedPath: "/",
230+
},
231+
{
232+
name: "simple_path",
233+
baseURL: "http://localhost:8080/api",
234+
expectedPath: "/api",
235+
},
236+
{
237+
name: "nested_path",
238+
baseURL: "http://localhost:8080/api/v1",
239+
expectedPath: "/api/v1",
240+
},
241+
{
242+
name: "trailing_slash_removed",
243+
baseURL: "http://localhost:8080/api/",
244+
expectedPath: "/api",
245+
},
246+
{
247+
name: "root_with_trailing_slash",
248+
baseURL: "http://localhost:8080/",
249+
expectedPath: "/",
250+
},
251+
{
252+
name: "path_with_multiple_segments",
253+
baseURL: "https://mcp.company.com/mcp-api/v1",
254+
expectedPath: "/mcp-api/v1",
255+
},
256+
{
257+
name: "invalid_url",
258+
baseURL: "://invalid",
259+
expectError: true,
260+
},
261+
}
262+
263+
for _, tt := range tests {
264+
t.Run(tt.name, func(t *testing.T) {
265+
cfg := Config{
266+
Proxy: ProxyConfig{
267+
BaseURL: tt.baseURL,
268+
},
269+
}
270+
271+
err := extractBasePath(&cfg)
272+
273+
if tt.expectError {
274+
assert.Error(t, err)
275+
} else {
276+
assert.NoError(t, err)
277+
assert.Equal(t, tt.expectedPath, cfg.Proxy.BasePath)
278+
}
279+
})
280+
}
281+
}

internal/config/types.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ type OAuthAuthConfig struct {
221221
// ProxyConfig represents the proxy configuration with resolved values
222222
type ProxyConfig struct {
223223
BaseURL string `json:"baseURL"`
224+
BasePath string `json:"-"` // Extracted from BaseURL, not in JSON
224225
Addr string `json:"addr"`
225226
Name string `json:"name"`
226227
Auth *OAuthAuthConfig `json:"auth,omitempty"` // Only OAuth is supported

internal/mcpfront.go

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -294,9 +294,17 @@ func buildHTTPHandler(
294294
userTokenService *server.UserTokenService,
295295
baseURL string,
296296
info mcp.Implementation,
297-
) (*http.ServeMux, error) {
297+
) (http.Handler, error) {
298298
// Create mux and register all routes with dependency injection
299299
mux := http.NewServeMux()
300+
basePath := cfg.Proxy.BasePath
301+
302+
route := func(path string) string {
303+
if basePath == "/" {
304+
return path
305+
}
306+
return basePath + path
307+
}
300308

301309
// Build common middleware
302310
corsMiddleware := server.NewCORSMiddleware(authConfig.AllowedOrigins)
@@ -307,7 +315,6 @@ func buildHTTPHandler(
307315
mcpRecover := server.NewRecoverMiddleware("mcp")
308316
oauthRecover := server.NewRecoverMiddleware("oauth")
309317

310-
// Register health endpoint
311318
mux.Handle("/health", server.NewHealthHandler())
312319

313320
// Create browser state token for SSO middleware (used by both OAuth and admin routes)
@@ -337,13 +344,13 @@ func buildHTTPHandler(
337344
)
338345

339346
// Register OAuth endpoints
340-
mux.Handle("/.well-known/oauth-authorization-server", server.ChainMiddleware(http.HandlerFunc(authHandlers.WellKnownHandler), oauthMiddleware...))
341-
mux.Handle("/.well-known/oauth-protected-resource", server.ChainMiddleware(http.HandlerFunc(authHandlers.ProtectedResourceMetadataHandler), oauthMiddleware...))
342-
mux.Handle("/authorize", server.ChainMiddleware(http.HandlerFunc(authHandlers.AuthorizeHandler), oauthMiddleware...))
343-
mux.Handle("/oauth/callback", server.ChainMiddleware(http.HandlerFunc(authHandlers.GoogleCallbackHandler), oauthMiddleware...))
344-
mux.Handle("/token", server.ChainMiddleware(http.HandlerFunc(authHandlers.TokenHandler), oauthMiddleware...))
345-
mux.Handle("/register", server.ChainMiddleware(http.HandlerFunc(authHandlers.RegisterHandler), oauthMiddleware...))
346-
mux.Handle("/clients/{client_id}", server.ChainMiddleware(http.HandlerFunc(authHandlers.ClientMetadataHandler), oauthMiddleware...))
347+
mux.Handle(route("/.well-known/oauth-authorization-server"), server.ChainMiddleware(http.HandlerFunc(authHandlers.WellKnownHandler), oauthMiddleware...))
348+
mux.Handle(route("/.well-known/oauth-protected-resource"), server.ChainMiddleware(http.HandlerFunc(authHandlers.ProtectedResourceMetadataHandler), oauthMiddleware...))
349+
mux.Handle(route("/authorize"), server.ChainMiddleware(http.HandlerFunc(authHandlers.AuthorizeHandler), oauthMiddleware...))
350+
mux.Handle(route("/oauth/callback"), server.ChainMiddleware(http.HandlerFunc(authHandlers.GoogleCallbackHandler), oauthMiddleware...))
351+
mux.Handle(route("/token"), server.ChainMiddleware(http.HandlerFunc(authHandlers.TokenHandler), oauthMiddleware...))
352+
mux.Handle(route("/register"), server.ChainMiddleware(http.HandlerFunc(authHandlers.RegisterHandler), oauthMiddleware...))
353+
mux.Handle(route("/clients/{client_id}"), server.ChainMiddleware(http.HandlerFunc(authHandlers.ClientMetadataHandler), oauthMiddleware...))
347354

348355
// Register protected token endpoints
349356
tokenMiddleware := []server.MiddlewareFunc{
@@ -357,19 +364,19 @@ func buildHTTPHandler(
357364
tokenHandlers := server.NewTokenHandlers(storage, cfg.MCPServers, true, serviceOAuthClient)
358365

359366
// Token management UI endpoints
360-
mux.Handle("/my/tokens", server.ChainMiddleware(http.HandlerFunc(tokenHandlers.ListTokensHandler), tokenMiddleware...))
361-
mux.Handle("/my/tokens/set", server.ChainMiddleware(http.HandlerFunc(tokenHandlers.SetTokenHandler), tokenMiddleware...))
362-
mux.Handle("/my/tokens/delete", server.ChainMiddleware(http.HandlerFunc(tokenHandlers.DeleteTokenHandler), tokenMiddleware...))
367+
mux.Handle(route("/my/tokens"), server.ChainMiddleware(http.HandlerFunc(tokenHandlers.ListTokensHandler), tokenMiddleware...))
368+
mux.Handle(route("/my/tokens/set"), server.ChainMiddleware(http.HandlerFunc(tokenHandlers.SetTokenHandler), tokenMiddleware...))
369+
mux.Handle(route("/my/tokens/delete"), server.ChainMiddleware(http.HandlerFunc(tokenHandlers.DeleteTokenHandler), tokenMiddleware...))
363370

364371
// OAuth interstitial page and completion endpoint
365-
mux.Handle("/oauth/services", server.ChainMiddleware(http.HandlerFunc(authHandlers.ServiceSelectionHandler), tokenMiddleware...))
366-
mux.Handle("/oauth/complete", server.ChainMiddleware(http.HandlerFunc(authHandlers.CompleteOAuthHandler), tokenMiddleware...))
372+
mux.Handle(route("/oauth/services"), server.ChainMiddleware(http.HandlerFunc(authHandlers.ServiceSelectionHandler), tokenMiddleware...))
373+
mux.Handle(route("/oauth/complete"), server.ChainMiddleware(http.HandlerFunc(authHandlers.CompleteOAuthHandler), tokenMiddleware...))
367374

368375
// Register service OAuth endpoints
369376
serviceAuthHandlers := server.NewServiceAuthHandlers(serviceOAuthClient, cfg.MCPServers, storage)
370-
mux.HandleFunc("/oauth/callback/", serviceAuthHandlers.CallbackHandler)
371-
mux.Handle("/oauth/connect", server.ChainMiddleware(http.HandlerFunc(serviceAuthHandlers.ConnectHandler), tokenMiddleware...))
372-
mux.Handle("/oauth/disconnect", server.ChainMiddleware(http.HandlerFunc(serviceAuthHandlers.DisconnectHandler), tokenMiddleware...))
377+
mux.HandleFunc(route("/oauth/callback/{service}"), serviceAuthHandlers.CallbackHandler)
378+
mux.Handle(route("/oauth/connect"), server.ChainMiddleware(http.HandlerFunc(serviceAuthHandlers.ConnectHandler), tokenMiddleware...))
379+
mux.Handle(route("/oauth/disconnect"), server.ChainMiddleware(http.HandlerFunc(serviceAuthHandlers.DisconnectHandler), tokenMiddleware...))
373380
}
374381

375382
// Setup MCP server endpoints
@@ -411,8 +418,8 @@ func buildHTTPHandler(
411418
baseURL,
412419
info,
413420
sessionManager,
414-
sseServers[serverName], // Pass the shared SSE server (nil for non-stdio)
415-
mcpServer, // Pass the shared MCP server (nil for non-stdio)
421+
sseServers[serverName],
422+
mcpServer,
416423
userTokenService.GetUserToken,
417424
)
418425
}
@@ -436,8 +443,7 @@ func buildHTTPHandler(
436443
// Recovery middleware should be last (outermost)
437444
mcpMiddlewares = append(mcpMiddlewares, mcpRecover)
438445

439-
// Register handler - SSE server needs to handle all paths under the server name
440-
mux.Handle("/"+serverName+"/", server.ChainMiddleware(handler, mcpMiddlewares...))
446+
mux.Handle(route("/"+serverName+"/"), server.ChainMiddleware(handler, mcpMiddlewares...))
441447
}
442448

443449
// Setup admin routes if admin is enabled
@@ -474,10 +480,10 @@ func buildHTTPHandler(
474480
adminMiddleware = append(adminMiddleware, mcpRecover)
475481

476482
// Register admin routes
477-
mux.Handle("/admin", server.ChainMiddleware(http.HandlerFunc(adminHandlers.DashboardHandler), adminMiddleware...))
478-
mux.Handle("/admin/users", server.ChainMiddleware(http.HandlerFunc(adminHandlers.UserActionHandler), adminMiddleware...))
479-
mux.Handle("/admin/sessions", server.ChainMiddleware(http.HandlerFunc(adminHandlers.SessionActionHandler), adminMiddleware...))
480-
mux.Handle("/admin/logging", server.ChainMiddleware(http.HandlerFunc(adminHandlers.LoggingActionHandler), adminMiddleware...))
483+
mux.Handle(route("/admin"), server.ChainMiddleware(http.HandlerFunc(adminHandlers.DashboardHandler), adminMiddleware...))
484+
mux.Handle(route("/admin/users"), server.ChainMiddleware(http.HandlerFunc(adminHandlers.UserActionHandler), adminMiddleware...))
485+
mux.Handle(route("/admin/sessions"), server.ChainMiddleware(http.HandlerFunc(adminHandlers.SessionActionHandler), adminMiddleware...))
486+
mux.Handle(route("/admin/logging"), server.ChainMiddleware(http.HandlerFunc(adminHandlers.LoggingActionHandler), adminMiddleware...))
481487
}
482488

483489
log.LogInfoWithFields("server", "MCP proxy server initialized", nil)

internal/oauth/resource_indicators.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,21 @@ func BuildResourceURI(issuer string, serviceName string) (string, error) {
165165
// ValidateAudienceForService("/linear/sse", []string{"https://mcp.company.com/postgres"}, "https://mcp.company.com")
166166
// Returns: error (invalid - linear not in audience)
167167
func ValidateAudienceForService(requestPath string, tokenAudience []string, issuer string) error {
168-
path := strings.TrimPrefix(requestPath, "/")
168+
u, err := url.Parse(issuer)
169+
if err != nil {
170+
return fmt.Errorf("invalid issuer URL: %w", err)
171+
}
172+
173+
basePath := u.Path
174+
if basePath == "" {
175+
basePath = "/"
176+
}
177+
178+
path := requestPath
179+
if basePath != "/" {
180+
path = strings.TrimPrefix(path, basePath)
181+
}
182+
path = strings.TrimPrefix(path, "/")
169183
parts := strings.Split(path, "/")
170184

171185
if len(parts) == 0 || parts[0] == "" {

internal/server/service_auth_handlers.go

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,7 @@ func (h *ServiceAuthHandlers) CallbackHandler(w http.ResponseWriter, r *http.Req
9393
return
9494
}
9595

96-
// Extract service name from path: /oauth/callback/{service}
97-
pathParts := strings.Split(strings.TrimPrefix(r.URL.Path, "/"), "/")
98-
if len(pathParts) < 3 {
99-
jsonwriter.WriteBadRequest(w, "Invalid callback path")
100-
return
101-
}
102-
serviceName := pathParts[2]
96+
serviceName := r.PathValue("service")
10397
if serviceName == "" {
10498
jsonwriter.WriteBadRequest(w, "Service name is required")
10599
return

0 commit comments

Comments
 (0)