Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions pkg/vmcp/cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,9 @@ func Serve(ctx context.Context, cfg ServeConfig) error {
}
sessionFactory := createSessionFactory(outgoingRegistry, agg)

// When the optimizer is enabled, its meta-tools must pass through the authz
// response filter so they appear in tools/list.
// When the optimizer is enabled, its meta-tools are pass-through tools.
// Authz uses this for optimizer-aware authorization/filtering, and rate
// limiting uses it to resolve call_tool to the inner backend tool name.
var passThroughTools map[string]struct{}
if optCfg != nil {
passThroughTools = map[string]struct{}{
Expand Down Expand Up @@ -382,10 +383,11 @@ func Serve(ctx context.Context, cfg ServeConfig) error {

namespace := vmcpNamespace()
rateLimitMiddleware, rateLimitCleanup, err := ratelimitfactory.NewMiddleware(ctx, ratelimitfactory.Config{
Namespace: namespace,
ServerName: vmcpCfg.Name,
RateLimiting: vmcpCfg.RateLimiting,
SessionStorage: vmcpCfg.SessionStorage,
Namespace: namespace,
ServerName: vmcpCfg.Name,
RateLimiting: vmcpCfg.RateLimiting,
SessionStorage: vmcpCfg.SessionStorage,
PassThroughTools: passThroughTools,
})
if err != nil {
return fmt.Errorf("failed to create rate limit middleware: %w", err)
Expand Down
67 changes: 62 additions & 5 deletions pkg/vmcp/ratelimit/factory/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,20 @@ import (
"fmt"
"net/http"

mcpparser "github.com/stacklok/toolhive/pkg/mcp"
"github.com/stacklok/toolhive/pkg/ratelimit"
ratelimittypes "github.com/stacklok/toolhive/pkg/ratelimit/types"
vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config"
"github.com/stacklok/toolhive/pkg/vmcp/session/optimizerdec"
)

// Config contains the vMCP rate-limit middleware inputs.
type Config struct {
Namespace string
ServerName string
RateLimiting *ratelimittypes.RateLimitConfig
SessionStorage *vmcpconfig.SessionStorageConfig
Namespace string
ServerName string
RateLimiting *ratelimittypes.RateLimitConfig
SessionStorage *vmcpconfig.SessionStorageConfig
PassThroughTools map[string]struct{}
}

// NewMiddleware creates Redis-backed rate-limit middleware for vMCP.
Expand Down Expand Up @@ -51,5 +54,59 @@ func NewMiddleware(
cleanup := func(context.Context) error {
return middleware.Close()
}
return middleware.Handler(), cleanup, nil
return withOptimizerToolNameResolution(middleware.Handler(), cfg.PassThroughTools), cleanup, nil
}

// withOptimizerToolNameResolution decorates rateLimitMiddleware so that
// requests rewritten by optimizerRateLimitRequest are rate limited against the
// rewritten parsed request, while next still receives the original request.
//
// When passThroughTools is empty the middleware is returned unchanged.
func withOptimizerToolNameResolution(
	rateLimitMiddleware func(http.Handler) http.Handler,
	passThroughTools map[string]struct{},
) func(http.Handler) http.Handler {
	if len(passThroughTools) == 0 {
		return rateLimitMiddleware
	}

	decorate := func(next http.Handler) http.Handler {
		// Chain used whenever no rewrite applies to the request.
		direct := rateLimitMiddleware(next)

		serve := func(w http.ResponseWriter, r *http.Request) {
			original := mcpparser.GetParsedMCPRequest(r.Context())
			effective := optimizerRateLimitRequest(original, passThroughTools)

			if effective != original {
				// The rate limiter is handed a request whose context carries
				// the rewritten parsed request. Once it lets the call through,
				// the handler below discards that overridden request and
				// forwards the original one, so the rest of the chain is
				// unaffected by the override.
				forwardOriginal := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
					next.ServeHTTP(rw, r)
				})
				overridden := r.WithContext(
					context.WithValue(r.Context(), mcpparser.MCPRequestContextKey, effective),
				)
				rateLimitMiddleware(forwardOriginal).ServeHTTP(w, overridden)
				return
			}

			direct.ServeHTTP(w, r)
		}
		return http.HandlerFunc(serve)
	}
	return decorate
}

// optimizerRateLimitRequest picks the parsed request that rate limiting should
// evaluate.
//
// It returns parsed itself unless all of the following hold: parsed is
// non-nil, its method is "tools/call", its ResourceID is a key of
// passThroughTools, and its arguments carry a non-empty string under
// optimizerdec.CallToolArgToolName. In that case it returns a shallow copy of
// parsed whose ResourceID is that string; parsed is never modified.
func optimizerRateLimitRequest(
	parsed *mcpparser.ParsedMCPRequest,
	passThroughTools map[string]struct{},
) *mcpparser.ParsedMCPRequest {
	if isToolCall := parsed != nil && parsed.Method == "tools/call"; !isToolCall {
		return parsed
	}
	if _, isPassThrough := passThroughTools[parsed.ResourceID]; !isPassThrough {
		return parsed
	}

	// A failed assertion leaves target empty, so a missing, non-string, or
	// empty argument all take the same fallback.
	target, _ := parsed.Arguments[optimizerdec.CallToolArgToolName].(string)
	if target == "" {
		return parsed
	}

	clone := *parsed
	clone.ResourceID = target
	return &clone
}
209 changes: 201 additions & 8 deletions pkg/vmcp/ratelimit/factory/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/stacklok/toolhive/pkg/ratelimit"
ratelimittypes "github.com/stacklok/toolhive/pkg/ratelimit/types"
vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config"
"github.com/stacklok/toolhive/pkg/vmcp/session/optimizerdec"
)

func TestNewMiddlewareDisabledWithoutConfig(t *testing.T) {
Expand Down Expand Up @@ -158,14 +159,174 @@ func TestRateLimitMiddlewareUsesPostAggregationToolNames(t *testing.T) {
assert.Equal(t, http.StatusTooManyRequests, secondMatchingTool.Code)
}

func TestRateLimitToolNameOptimizerExtractsInnerToolName(t *testing.T) {
t.Parallel()

parsed := parsedToolCall("call_tool", map[string]any{
optimizerdec.CallToolArgToolName: "backend_a_echo",
}, 1)

resolved := optimizerRateLimitRequest(parsed, map[string]struct{}{
optimizerdec.CallToolName: {},
})

require.NotSame(t, parsed, resolved)
assert.Equal(t, "backend_a_echo", resolved.ResourceID)
assert.Equal(t, "call_tool", parsed.ResourceID, "original parsed request should not be mutated")
}

// TestRateLimitToolNameFallsBackForInvalidInnerToolName checks that a
// pass-through call_tool request is returned as-is when its inner tool-name
// argument is missing, empty, not a string, or when the arguments map is nil.
func TestRateLimitToolNameFallsBackForInvalidInnerToolName(t *testing.T) {
	t.Parallel()

	type fallbackCase struct {
		name      string
		arguments map[string]any
	}
	cases := []fallbackCase{
		{name: "missing tool_name", arguments: map[string]any{}},
		{name: "empty tool_name", arguments: map[string]any{optimizerdec.CallToolArgToolName: ""}},
		{name: "non-string tool_name", arguments: map[string]any{optimizerdec.CallToolArgToolName: 123}},
		{name: "nil arguments", arguments: nil},
	}
	optimizerTools := map[string]struct{}{optimizerdec.CallToolName: {}}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			t.Parallel()

			input := parsedToolCall("call_tool", c.arguments, 1)
			output := optimizerRateLimitRequest(input, optimizerTools)

			// Same pointer back means no rewrite took place.
			assert.Same(t, input, output)
			assert.Equal(t, "call_tool", output.ResourceID)
		})
	}
}

func TestRateLimitToolNameFallsBackForNilParsedRequest(t *testing.T) {
t.Parallel()

resolved := optimizerRateLimitRequest(nil, map[string]struct{}{
optimizerdec.CallToolName: {},
})

assert.Nil(t, resolved)
}

func TestRateLimitToolNameFallsBackForNonPassThroughTool(t *testing.T) {
t.Parallel()

parsed := parsedToolCall("backend_a_echo", map[string]any{
optimizerdec.CallToolArgToolName: "backend_b_echo",
}, 1)

resolved := optimizerRateLimitRequest(parsed, map[string]struct{}{
optimizerdec.CallToolName: {},
})

assert.Same(t, parsed, resolved)
assert.Equal(t, "backend_a_echo", resolved.ResourceID)
}

func TestRateLimitMiddlewareOptimizerUsesInnerToolName(t *testing.T) {
t.Parallel()

handler := newTestRateLimitHandlerWithPassThroughTools(t, &ratelimittypes.RateLimitConfig{
Tools: []ratelimittypes.ToolRateLimitConfig{
{
Name: "backend_a_echo",
Shared: &ratelimittypes.RateLimitBucket{
MaxTokens: 1,
RefillPeriod: metav1.Duration{Duration: time.Minute},
},
},
},
}, map[string]struct{}{
optimizerdec.CallToolName: {},
})

first := serveToolCallWithArguments(t, handler, optimizerdec.CallToolName, "", map[string]any{
optimizerdec.CallToolArgToolName: "backend_a_echo",
})
assert.Equal(t, http.StatusOK, first.Code)

otherTool := serveToolCallWithArguments(t, handler, optimizerdec.CallToolName, "", map[string]any{
optimizerdec.CallToolArgToolName: "backend_b_echo",
})
assert.Equal(t, http.StatusOK, otherTool.Code)

secondMatchingTool := serveToolCallWithArguments(t, handler, optimizerdec.CallToolName, "", map[string]any{
optimizerdec.CallToolArgToolName: "backend_a_echo",
})
assert.Equal(t, http.StatusTooManyRequests, secondMatchingTool.Code)
assertRateLimitedBody(t, secondMatchingTool)
}

func TestRateLimitMiddlewareOptimizerFallsBackForInvalidInnerToolName(t *testing.T) {
t.Parallel()

handler := newTestRateLimitHandlerWithPassThroughTools(t, &ratelimittypes.RateLimitConfig{
Tools: []ratelimittypes.ToolRateLimitConfig{
{
Name: optimizerdec.CallToolName,
Shared: &ratelimittypes.RateLimitBucket{
MaxTokens: 1,
RefillPeriod: metav1.Duration{Duration: time.Minute},
},
},
},
}, map[string]struct{}{
optimizerdec.CallToolName: {},
})

first := serveToolCallWithArguments(t, handler, optimizerdec.CallToolName, "", map[string]any{
optimizerdec.CallToolArgToolName: "",
})
assert.Equal(t, http.StatusOK, first.Code)

second := serveToolCallWithArguments(t, handler, optimizerdec.CallToolName, "", map[string]any{
optimizerdec.CallToolArgToolName: "",
})
assert.Equal(t, http.StatusTooManyRequests, second.Code)
}

// newTestRateLimitHandler builds the test rate-limit handler for cfg with no
// pass-through tools configured.
func newTestRateLimitHandler(t *testing.T, cfg *ratelimittypes.RateLimitConfig) http.Handler {
	t.Helper()

	var noPassThroughTools map[string]struct{}
	return newTestRateLimitHandlerWithPassThroughTools(t, cfg, noPassThroughTools)
}

func newTestRateLimitHandlerWithPassThroughTools(
t *testing.T,
cfg *ratelimittypes.RateLimitConfig,
passThroughTools map[string]struct{},
) http.Handler {
t.Helper()

mr := miniredis.RunT(t)
middleware, cleanup, err := NewMiddleware(t.Context(), Config{
Namespace: "default",
ServerName: "vmcp",
RateLimiting: cfg,
Namespace: "default",
ServerName: "vmcp",
RateLimiting: cfg,
PassThroughTools: passThroughTools,
SessionStorage: &vmcpconfig.SessionStorageConfig{
Provider: "redis",
Address: mr.Addr(),
Expand All @@ -186,8 +347,20 @@ func newTestRateLimitHandler(t *testing.T, cfg *ratelimittypes.RateLimitConfig)
// serveToolCall sends a tools/call request for toolName with no arguments
// through handler and returns the recorded response.
func serveToolCall(t *testing.T, handler http.Handler, toolName, userID string) *httptest.ResponseRecorder {
	t.Helper()

	var noArguments map[string]any
	return serveToolCallWithArguments(t, handler, toolName, userID, noArguments)
}

func serveToolCallWithArguments(
t *testing.T,
handler http.Handler,
toolName string,
userID string,
arguments map[string]any,
) *httptest.ResponseRecorder {
t.Helper()

req := httptest.NewRequest(http.MethodPost, "/mcp", nil)
req = withParsedMCPRequest(req, "tools/call", toolName, 1)
req = withParsedMCPRequest(req, "tools/call", toolName, arguments, 1)
if userID != "" {
req = withIdentity(req, userID)
}
Expand All @@ -196,15 +369,35 @@ func serveToolCall(t *testing.T, handler http.Handler, toolName, userID string)
return w
}

func withParsedMCPRequest(r *http.Request, method, resourceID string, id any) *http.Request {
parsed := &mcpparser.ParsedMCPRequest{
func withParsedMCPRequest(
r *http.Request,
method string,
resourceID string,
arguments map[string]any,
id any,
) *http.Request {
parsed := parsedMCPRequest(method, resourceID, arguments, id)
ctx := context.WithValue(r.Context(), mcpparser.MCPRequestContextKey, parsed)
return r.WithContext(ctx)
}

// parsedToolCall builds a parsed "tools/call" request for resourceID with the
// given arguments and request id.
func parsedToolCall(resourceID string, arguments map[string]any, id any) *mcpparser.ParsedMCPRequest {
	const toolsCallMethod = "tools/call"
	return parsedMCPRequest(toolsCallMethod, resourceID, arguments, id)
}

func parsedMCPRequest(
method string,
resourceID string,
arguments map[string]any,
id any,
) *mcpparser.ParsedMCPRequest {
return &mcpparser.ParsedMCPRequest{
Method: method,
ResourceID: resourceID,
Arguments: arguments,
ID: id,
IsRequest: true,
}
ctx := context.WithValue(r.Context(), mcpparser.MCPRequestContextKey, parsed)
return r.WithContext(ctx)
}

func withIdentity(r *http.Request, subject string) *http.Request {
Expand Down
Loading
Loading