Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions provider/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,20 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace
_, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor")
defer tracing.EndSpanErr(span, &outErr)

cfg := p.cfg

// In centralized mode, http.go strips Authorization (it carried the
// Coder token), so the header is absent and cfg keeps the centralized
// key.
//
// In BYOK mode, http.go only strips the BYOK header and leaves the
// user's LLM credentials intact. OpenAI uses Authorization: Bearer
// for both API keys and OAuth tokens, so we just extract the token
// and overwrite cfg.Key.
if bearer := r.Header.Get("Authorization"); bearer != "" {
cfg.Key = strings.TrimPrefix(bearer, "Bearer ")
}

var interceptor intercept.Interceptor

path := strings.TrimPrefix(r.URL.Path, p.RoutePrefix())
Expand All @@ -105,9 +119,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace
}

if req.Stream {
interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.cfg, r.Header, p.AuthHeader(), tracer)
interceptor = chatcompletions.NewStreamingInterceptor(id, &req, cfg, r.Header, p.AuthHeader(), tracer)
} else {
interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.cfg, r.Header, p.AuthHeader(), tracer)
interceptor = chatcompletions.NewBlockingInterceptor(id, &req, cfg, r.Header, p.AuthHeader(), tracer)
}

case routeResponses:
Expand All @@ -120,9 +134,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace
return nil, fmt.Errorf("unmarshal request body: %w", err)
}
if req.Stream {
interceptor = responses.NewStreamingInterceptor(id, &req, payload, p.cfg, string(req.Model), r.Header, p.AuthHeader(), tracer)
interceptor = responses.NewStreamingInterceptor(id, &req, payload, cfg, string(req.Model), r.Header, p.AuthHeader(), tracer)
} else {
interceptor = responses.NewBlockingInterceptor(id, &req, payload, p.cfg, string(req.Model), r.Header, p.AuthHeader(), tracer)
interceptor = responses.NewBlockingInterceptor(id, &req, payload, cfg, string(req.Model), r.Header, p.AuthHeader(), tracer)
}

default:
Expand All @@ -146,6 +160,12 @@ func (p *OpenAI) InjectAuthHeader(headers *http.Header) {
headers = &http.Header{}
}

// BYOK: if the request already carries user-supplied credentials,
// do not overwrite them with the centralized key.
if headers.Get("Authorization") != "" {
return
}

headers.Set(p.AuthHeader(), "Bearer "+p.cfg.Key)
}

Expand Down
121 changes: 88 additions & 33 deletions provider/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,66 +162,121 @@ func generateResponsesPayload(payloadSize int, inputCount int, stream bool) []by
func TestOpenAI_CreateInterceptor(t *testing.T) {
t.Parallel()

tests := []struct {
routes := []struct {
name string
route string
requestBody string
responseBody string
}{
{
name: "ChatCompletions_ClientHeaders",
name: "ChatCompletions",
route: routeChatCompletions,
requestBody: `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}], "stream": false}`,
responseBody: chatCompletionResponse,
},
{
name: "Responses_ClientHeaders",
name: "Responses",
route: routeResponses,
requestBody: `{"model": "gpt-5", "input": "hello", "stream": false}`,
responseBody: responsesAPIResponse,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
byokCases := []struct {
name string
setHeaders map[string]string
wantAuthorization string
}{
{
name: "Centralized_UsesCentralizedKey",
setHeaders: map[string]string{},
wantAuthorization: "Bearer test-key",
},
{
name: "BYOK_BearerToken",
setHeaders: map[string]string{"Authorization": "Bearer user-oauth-token"},
wantAuthorization: "Bearer user-oauth-token",
},
}

var receivedHeaders http.Header
for _, route := range routes {
for _, bc := range byokCases {
t.Run(route.name+"_"+bc.name, func(t *testing.T) {
t.Parallel()

var receivedHeaders http.Header

mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHeaders = r.Header.Clone()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte(route.responseBody))
require.NoError(t, err)
}))
t.Cleanup(mockUpstream.Close)

provider := NewOpenAI(config.OpenAI{
BaseURL: mockUpstream.URL,
Key: "test-key",
})

req := httptest.NewRequest(http.MethodPost, provider.RoutePrefix()+route.route, bytes.NewBufferString(route.requestBody))
for k, v := range bc.setHeaders {
req.Header.Set(k, v)
}
w := httptest.NewRecorder()

mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHeaders = r.Header.Clone()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte(tc.responseBody))
interceptor, err := provider.CreateInterceptor(w, req, testTracer)
require.NoError(t, err)
}))
t.Cleanup(mockUpstream.Close)
require.NotNil(t, interceptor)

logger := slog.Make()
interceptor.Setup(logger, &testutil.MockRecorder{}, nil)

provider := NewOpenAI(config.OpenAI{
BaseURL: mockUpstream.URL,
Key: "test-key",
processReq := httptest.NewRequest(http.MethodPost, provider.RoutePrefix()+route.route, nil)
err = interceptor.ProcessRequest(w, processReq)
require.NoError(t, err)

assert.Equal(t, bc.wantAuthorization, receivedHeaders.Get("Authorization"))
})
}
}
}

func TestOpenAI_InjectAuthHeader_BYOK(t *testing.T) {
t.Parallel()

req := httptest.NewRequest(http.MethodPost, provider.RoutePrefix()+tc.route, bytes.NewBufferString(tc.requestBody))
// Simulate a client sending its own auth credential, which must be replaced
// by aibridge with the configured provider key.
req.Header.Set("Authorization", "Bearer fake-client-bearer")
w := httptest.NewRecorder()
provider := NewOpenAI(config.OpenAI{Key: "centralized-key"})

interceptor, err := provider.CreateInterceptor(w, req, testTracer)
require.NoError(t, err)
require.NotNil(t, interceptor)
tests := []struct {
name string
presetHeaders map[string]string
wantAuthorization string
}{
{
name: "no pre-existing auth injects centralized key",
presetHeaders: map[string]string{},
wantAuthorization: "Bearer centralized-key",
},
{
name: "pre-existing Authorization is not overwritten",
presetHeaders: map[string]string{"Authorization": "Bearer user-oauth-token"},
wantAuthorization: "Bearer user-oauth-token",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

logger := slog.Make()
interceptor.Setup(logger, &testutil.MockRecorder{}, nil)
headers := http.Header{}
for k, v := range tc.presetHeaders {
headers.Set(k, v)
}

processReq := httptest.NewRequest(http.MethodPost, provider.RoutePrefix()+tc.route, nil)
err = interceptor.ProcessRequest(w, processReq)
require.NoError(t, err)
provider.InjectAuthHeader(&headers)

// Verify aibridge's configured key was used and the client's auth credential was not forwarded.
assert.Equal(t, "Bearer test-key", receivedHeaders.Get("Authorization"), "upstream must receive configured provider key")
assert.Empty(t, receivedHeaders.Get("X-Api-Key"), "X-Api-Key must not be set upstream")
assert.Equal(t, tc.wantAuthorization, headers.Get("Authorization"))
})
}
}
Expand Down
Loading