diff --git a/pkg/vmcp/cli/serve.go b/pkg/vmcp/cli/serve.go index c26061c974..4f8d8c04f1 100644 --- a/pkg/vmcp/cli/serve.go +++ b/pkg/vmcp/cli/serve.go @@ -352,8 +352,9 @@ func Serve(ctx context.Context, cfg ServeConfig) error { } sessionFactory := createSessionFactory(outgoingRegistry, agg) - // When the optimizer is enabled, its meta-tools must pass through the authz - // response filter so they appear in tools/list. + // When the optimizer is enabled, its meta-tools are pass-through tools. + // Authz uses this for optimizer-aware authorization/filtering, and rate + // limiting uses it to resolve call_tool to the inner backend tool name. var passThroughTools map[string]struct{} if optCfg != nil { passThroughTools = map[string]struct{}{ @@ -382,10 +383,11 @@ func Serve(ctx context.Context, cfg ServeConfig) error { namespace := vmcpNamespace() rateLimitMiddleware, rateLimitCleanup, err := ratelimitfactory.NewMiddleware(ctx, ratelimitfactory.Config{ - Namespace: namespace, - ServerName: vmcpCfg.Name, - RateLimiting: vmcpCfg.RateLimiting, - SessionStorage: vmcpCfg.SessionStorage, + Namespace: namespace, + ServerName: vmcpCfg.Name, + RateLimiting: vmcpCfg.RateLimiting, + SessionStorage: vmcpCfg.SessionStorage, + PassThroughTools: passThroughTools, }) if err != nil { return fmt.Errorf("failed to create rate limit middleware: %w", err) diff --git a/pkg/vmcp/ratelimit/factory/middleware.go b/pkg/vmcp/ratelimit/factory/middleware.go index 3460fa8e3d..56d7ebc43e 100644 --- a/pkg/vmcp/ratelimit/factory/middleware.go +++ b/pkg/vmcp/ratelimit/factory/middleware.go @@ -9,17 +9,20 @@ import ( "fmt" "net/http" + mcpparser "github.com/stacklok/toolhive/pkg/mcp" "github.com/stacklok/toolhive/pkg/ratelimit" ratelimittypes "github.com/stacklok/toolhive/pkg/ratelimit/types" vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config" + "github.com/stacklok/toolhive/pkg/vmcp/session/optimizerdec" ) // Config contains the vMCP rate-limit middleware inputs. type Config struct { - Namespace string - ServerName string - RateLimiting *ratelimittypes.RateLimitConfig - SessionStorage *vmcpconfig.SessionStorageConfig + Namespace string + ServerName string + RateLimiting *ratelimittypes.RateLimitConfig + SessionStorage *vmcpconfig.SessionStorageConfig + PassThroughTools map[string]struct{} } // NewMiddleware creates Redis-backed rate-limit middleware for vMCP. @@ -51,5 +54,59 @@ func NewMiddleware( cleanup := func(context.Context) error { return middleware.Close() } - return middleware.Handler(), cleanup, nil + return withOptimizerToolNameResolution(middleware.Handler(), cfg.PassThroughTools), cleanup, nil +} + +func withOptimizerToolNameResolution( + rateLimitMiddleware func(http.Handler) http.Handler, + passThroughTools map[string]struct{}, +) func(http.Handler) http.Handler { + if len(passThroughTools) == 0 { + return rateLimitMiddleware + } + + return func(next http.Handler) http.Handler { + normalHandler := rateLimitMiddleware(next) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + parsed := mcpparser.GetParsedMCPRequest(r.Context()) + resolved := optimizerRateLimitRequest(parsed, passThroughTools) + if resolved == parsed { + normalHandler.ServeHTTP(w, r) + return + } + + // Rate limiting needs the inner backend tool name for optimizer + // call_tool requests, but downstream middleware must still see the + // original call_tool request. Override the parsed request only while + // invoking the shared rate-limit middleware, then restore the original + // request before continuing the vMCP handler chain. + ctx := context.WithValue(r.Context(), mcpparser.MCPRequestContextKey, resolved) + rateLimitRequest := r.WithContext(ctx) + restoreOriginalRequest := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + next.ServeHTTP(w, r) + }) + rateLimitMiddleware(restoreOriginalRequest).ServeHTTP(w, rateLimitRequest) + }) + } +} + +func optimizerRateLimitRequest( + parsed *mcpparser.ParsedMCPRequest, + passThroughTools map[string]struct{}, +) *mcpparser.ParsedMCPRequest { + if parsed == nil || parsed.Method != "tools/call" { + return parsed + } + if _, ok := passThroughTools[parsed.ResourceID]; !ok { + return parsed + } + + innerToolName, ok := parsed.Arguments[optimizerdec.CallToolArgToolName].(string) + if !ok || innerToolName == "" { + return parsed + } + + resolved := *parsed + resolved.ResourceID = innerToolName + return &resolved } diff --git a/pkg/vmcp/ratelimit/factory/middleware_test.go b/pkg/vmcp/ratelimit/factory/middleware_test.go index 7175593e0c..be269c8033 100644 --- a/pkg/vmcp/ratelimit/factory/middleware_test.go +++ b/pkg/vmcp/ratelimit/factory/middleware_test.go @@ -21,6 +21,7 @@ import ( "github.com/stacklok/toolhive/pkg/ratelimit" ratelimittypes "github.com/stacklok/toolhive/pkg/ratelimit/types" vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config" + "github.com/stacklok/toolhive/pkg/vmcp/session/optimizerdec" ) func TestNewMiddlewareDisabledWithoutConfig(t *testing.T) { @@ -158,14 +159,174 @@ func TestRateLimitMiddlewareUsesPostAggregationToolNames(t *testing.T) { assert.Equal(t, http.StatusTooManyRequests, secondMatchingTool.Code) } +func TestRateLimitToolNameOptimizerExtractsInnerToolName(t *testing.T) { + t.Parallel() + + parsed := parsedToolCall("call_tool", map[string]any{ + optimizerdec.CallToolArgToolName: "backend_a_echo", + }, 1) + + resolved := optimizerRateLimitRequest(parsed, map[string]struct{}{ + optimizerdec.CallToolName: {}, + }) + + require.NotSame(t, parsed, resolved) + assert.Equal(t, "backend_a_echo", resolved.ResourceID) + assert.Equal(t, "call_tool", parsed.ResourceID, "original parsed request should not be mutated") +} + +func TestRateLimitToolNameFallsBackForInvalidInnerToolName(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + arguments map[string]any + }{ + { + name: "missing tool_name", + arguments: map[string]any{}, + }, + { + name: "empty tool_name", + arguments: map[string]any{ + optimizerdec.CallToolArgToolName: "", + }, + }, + { + name: "non-string tool_name", + arguments: map[string]any{ + optimizerdec.CallToolArgToolName: 123, + }, + }, + { + name: "nil arguments", + arguments: nil, + }, + } + + passThroughTools := map[string]struct{}{ + optimizerdec.CallToolName: {}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + parsed := parsedToolCall("call_tool", tc.arguments, 1) + + resolved := optimizerRateLimitRequest(parsed, passThroughTools) + + assert.Same(t, parsed, resolved) + assert.Equal(t, "call_tool", resolved.ResourceID) + }) + } +} + +func TestRateLimitToolNameFallsBackForNilParsedRequest(t *testing.T) { + t.Parallel() + + resolved := optimizerRateLimitRequest(nil, map[string]struct{}{ + optimizerdec.CallToolName: {}, + }) + + assert.Nil(t, resolved) +} + +func TestRateLimitToolNameFallsBackForNonPassThroughTool(t *testing.T) { + t.Parallel() + + parsed := parsedToolCall("backend_a_echo", map[string]any{ + optimizerdec.CallToolArgToolName: "backend_b_echo", + }, 1) + + resolved := optimizerRateLimitRequest(parsed, map[string]struct{}{ + optimizerdec.CallToolName: {}, + }) + + assert.Same(t, parsed, resolved) + assert.Equal(t, "backend_a_echo", resolved.ResourceID) +} + +func TestRateLimitMiddlewareOptimizerUsesInnerToolName(t *testing.T) { + t.Parallel() + + handler := newTestRateLimitHandlerWithPassThroughTools(t, &ratelimittypes.RateLimitConfig{ + Tools: []ratelimittypes.ToolRateLimitConfig{ + { + Name: "backend_a_echo", + Shared: &ratelimittypes.RateLimitBucket{ + MaxTokens: 1, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + }, + }, + }, map[string]struct{}{ + optimizerdec.CallToolName: {}, + }) + + first := serveToolCallWithArguments(t, handler, optimizerdec.CallToolName, "", map[string]any{ + optimizerdec.CallToolArgToolName: "backend_a_echo", + }) + assert.Equal(t, http.StatusOK, first.Code) + + otherTool := serveToolCallWithArguments(t, handler, optimizerdec.CallToolName, "", map[string]any{ + optimizerdec.CallToolArgToolName: "backend_b_echo", + }) + assert.Equal(t, http.StatusOK, otherTool.Code) + + secondMatchingTool := serveToolCallWithArguments(t, handler, optimizerdec.CallToolName, "", map[string]any{ + optimizerdec.CallToolArgToolName: "backend_a_echo", + }) + assert.Equal(t, http.StatusTooManyRequests, secondMatchingTool.Code) + assertRateLimitedBody(t, secondMatchingTool) +} + +func TestRateLimitMiddlewareOptimizerFallsBackForInvalidInnerToolName(t *testing.T) { + t.Parallel() + + handler := newTestRateLimitHandlerWithPassThroughTools(t, &ratelimittypes.RateLimitConfig{ + Tools: []ratelimittypes.ToolRateLimitConfig{ + { + Name: optimizerdec.CallToolName, + Shared: &ratelimittypes.RateLimitBucket{ + MaxTokens: 1, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + }, + }, + }, map[string]struct{}{ + optimizerdec.CallToolName: {}, + }) + + first := serveToolCallWithArguments(t, handler, optimizerdec.CallToolName, "", map[string]any{ + optimizerdec.CallToolArgToolName: "", + }) + assert.Equal(t, http.StatusOK, first.Code) + + second := serveToolCallWithArguments(t, handler, optimizerdec.CallToolName, "", map[string]any{ + optimizerdec.CallToolArgToolName: "", + }) + assert.Equal(t, http.StatusTooManyRequests, second.Code) +} + func newTestRateLimitHandler(t *testing.T, cfg *ratelimittypes.RateLimitConfig) http.Handler { t.Helper() + return newTestRateLimitHandlerWithPassThroughTools(t, cfg, nil) +} + +func newTestRateLimitHandlerWithPassThroughTools( + t *testing.T, + cfg *ratelimittypes.RateLimitConfig, + passThroughTools map[string]struct{}, +) http.Handler { + t.Helper() + mr := miniredis.RunT(t) middleware, cleanup, err := NewMiddleware(t.Context(), Config{ - Namespace: "default", - ServerName: "vmcp", - RateLimiting: cfg, + Namespace: "default", + ServerName: "vmcp", + RateLimiting: cfg, + PassThroughTools: passThroughTools, SessionStorage: &vmcpconfig.SessionStorageConfig{ Provider: "redis", Address: mr.Addr(), @@ -186,8 +347,20 @@ func newTestRateLimitHandler(t *testing.T, cfg *ratelimittypes.RateLimitConfig) func serveToolCall(t *testing.T, handler http.Handler, toolName, userID string) *httptest.ResponseRecorder { t.Helper() + return serveToolCallWithArguments(t, handler, toolName, userID, nil) +} + +func serveToolCallWithArguments( + t *testing.T, + handler http.Handler, + toolName string, + userID string, + arguments map[string]any, +) *httptest.ResponseRecorder { + t.Helper() + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) - req = withParsedMCPRequest(req, "tools/call", toolName, 1) + req = withParsedMCPRequest(req, "tools/call", toolName, arguments, 1) if userID != "" { req = withIdentity(req, userID) } @@ -196,15 +369,35 @@ func serveToolCall(t *testing.T, handler http.Handler, toolName, userID string) return w } -func withParsedMCPRequest(r *http.Request, method, resourceID string, id any) *http.Request { - parsed := &mcpparser.ParsedMCPRequest{ +func withParsedMCPRequest( + r *http.Request, + method string, + resourceID string, + arguments map[string]any, + id any, +) *http.Request { + parsed := parsedMCPRequest(method, resourceID, arguments, id) + ctx := context.WithValue(r.Context(), mcpparser.MCPRequestContextKey, parsed) + return r.WithContext(ctx) +} + +func parsedToolCall(resourceID string, arguments map[string]any, id any) *mcpparser.ParsedMCPRequest { + return parsedMCPRequest("tools/call", resourceID, arguments, id) +} + +func parsedMCPRequest( + method string, + resourceID string, + arguments map[string]any, + id any, +) *mcpparser.ParsedMCPRequest { + return &mcpparser.ParsedMCPRequest{ Method: method, ResourceID: resourceID, + Arguments: arguments, ID: id, IsRequest: true, } - ctx := context.WithValue(r.Context(), mcpparser.MCPRequestContextKey, parsed) - return r.WithContext(ctx) } func withIdentity(r *http.Request, subject string) *http.Request { diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_rate_limiting_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_rate_limiting_test.go index 3a4f89f85b..9476c956c1 100644 --- a/test/e2e/thv-operator/virtualmcp/virtualmcp_rate_limiting_test.go +++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_rate_limiting_test.go @@ -199,6 +199,177 @@ var _ = ginkgo.Describe("VirtualMCPServer Rate Limiting", ginkgo.Ordered, func() }) }) +var _ = ginkgo.Describe("VirtualMCPServer Rate Limiting with Optimizer", ginkgo.Ordered, func() { + const ( + timeout = 5 * time.Minute + pollInterval = 2 * time.Second + optimizerEchoTool = "optimizer_rl_echo" + optimizerEchoPrompt = "ratelimit optimizer test" + ) + + var ( + mcpGroupName string + backendName string + fakeEmbeddingName string + vmcpName string + redisName string + vmcpLocalPort int + vmcpPortForwardCleanup func() + ) + + ginkgo.BeforeAll(func() { + ts := time.Now().UnixNano() + mcpGroupName = fmt.Sprintf("e2e-rl-opt-group-%d", ts) + backendName = fmt.Sprintf("e2e-rl-opt-backend-%d", ts) + fakeEmbeddingName = fmt.Sprintf("e2e-rl-opt-embedding-%d", ts) + vmcpName = fmt.Sprintf("e2e-rl-opt-vmcp-%d", ts) + redisName = fmt.Sprintf("e2e-rl-opt-redis-%d", ts) + + ginkgo.By("Deploying Redis") + deployRedis(redisName) + + ginkgo.By("Deploying fake embedding server") + embeddingURL := DeployFakeEmbeddingServer(ctx, k8sClient, + fakeEmbeddingName, defaultNamespace, timeout, pollInterval) + + ginkgo.By("Creating MCPGroup") + CreateMCPGroupAndWait(ctx, k8sClient, mcpGroupName, defaultNamespace, + "E2E vMCP optimizer rate limiting group", timeout, pollInterval) + + ginkgo.By("Creating backend MCPServer") + gomega.Expect(k8sClient.Create(ctx, &mcpv1beta1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: backendName, Namespace: defaultNamespace}, + Spec: mcpv1beta1.MCPServerSpec{ + GroupRef: &mcpv1beta1.MCPGroupRef{Name: mcpGroupName}, + Image: images.YardstickServerImage, + Transport: "streamable-http", + ProxyPort: 8080, + MCPPort: 8080, + }, + })).To(gomega.Succeed()) + + ginkgo.By("Waiting for backend MCPServer to be ready") + gomega.Eventually(func() error { + server := &mcpv1beta1.MCPServer{} + if err := k8sClient.Get(ctx, types.NamespacedName{ + Name: backendName, + Namespace: defaultNamespace, + }, server); err != nil { + return err + } + if server.Status.Phase != mcpv1beta1.MCPServerPhaseReady { + return fmt.Errorf("backend not ready yet, phase: %s", server.Status.Phase) + } + return nil + }, timeout, pollInterval).Should(gomega.Succeed()) + + redisAddr := fmt.Sprintf("%s.%s.svc.cluster.local:6379", redisName, defaultNamespace) + ginkgo.By("Creating VirtualMCPServer with optimizer and per-tool rate limiting") + gomega.Expect(k8sClient.Create(ctx, &mcpv1beta1.VirtualMCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: vmcpName, Namespace: defaultNamespace}, + Spec: mcpv1beta1.VirtualMCPServerSpec{ + GroupRef: &mcpv1beta1.MCPGroupRef{Name: mcpGroupName}, + Config: vmcpconfig.Config{ + Group: mcpGroupName, + Optimizer: &vmcpconfig.OptimizerConfig{ + EmbeddingService: embeddingURL, + }, + Aggregation: &vmcpconfig.AggregationConfig{ + ConflictResolution: "prefix", + Tools: []*vmcpconfig.WorkloadToolConfig{ + { + Workload: backendName, + Overrides: map[string]*vmcpconfig.ToolOverride{ + "echo": { + Name: optimizerEchoTool, + Description: "Echo tool for optimizer rate-limit E2E", + }, + }, + }, + }, + }, + RateLimiting: &mcpv1beta1.RateLimitConfig{ + Tools: []mcpv1beta1.ToolRateLimitConfig{ + { + Name: optimizerEchoTool, + Shared: &mcpv1beta1.RateLimitBucket{ + MaxTokens: 1, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + }, + }, + }, + }, + IncomingAuth: &mcpv1beta1.IncomingAuthConfig{ + Type: "anonymous", + }, + OutgoingAuth: &mcpv1beta1.OutgoingAuthConfig{ + Source: "discovered", + }, + SessionStorage: &mcpv1beta1.SessionStorageConfig{ + Provider: mcpv1beta1.SessionStorageProviderRedis, + Address: redisAddr, + }, + }, + })).To(gomega.Succeed()) + + ginkgo.By("Waiting for VirtualMCPServer to be ready") + WaitForVirtualMCPServerReady(ctx, k8sClient, vmcpName, defaultNamespace, timeout, pollInterval) + + ginkgo.By("Port-forwarding VirtualMCPServer service") + var err error + vmcpLocalPort, vmcpPortForwardCleanup, err = startRateLimitServicePortForward(VMCPServiceName(vmcpName), 4483) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + }) + + ginkgo.AfterAll(func() { + if vmcpPortForwardCleanup != nil { + vmcpPortForwardCleanup() + } + _ = k8sClient.Delete(ctx, &mcpv1beta1.VirtualMCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: vmcpName, Namespace: defaultNamespace}, + }) + _ = k8sClient.Delete(ctx, &mcpv1beta1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: backendName, Namespace: defaultNamespace}, + }) + _ = k8sClient.Delete(ctx, &mcpv1beta1.MCPGroup{ + ObjectMeta: metav1.ObjectMeta{Name: mcpGroupName, Namespace: defaultNamespace}, + }) + CleanupFakeEmbeddingServer(ctx, k8sClient, fakeEmbeddingName, defaultNamespace) + cleanupRedis(redisName) + }) + + ginkgo.It("rate-limits call_tool by the inner backend tool name", func() { + mcpClient := newAnonymousRateLimitMCPClient(vmcpLocalPort) + defer mcpClient.Close() + + tools, err := mcpClient.ListTools(ctx, mcp.ListToolsRequest{}) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + gomega.Expect(toolNames(tools.Tools)).To(gomega.ConsistOf("find_tool", "call_tool")) + + _, err = callToolViaOptimizer(&InitializedMCPClient{ + Client: mcpClient, + Ctx: ctx, + }, optimizerEchoTool, map[string]any{ + "input": optimizerEchoPrompt, + }) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + + _, err = callToolViaOptimizer(&InitializedMCPClient{ + Client: mcpClient, + Ctx: ctx, + }, optimizerEchoTool, map[string]any{ + "input": optimizerEchoPrompt, + }) + gomega.Expect(err).To(gomega.HaveOccurred()) + gomega.Expect(err.Error()).To(gomega.Or( + gomega.ContainSubstring("429"), + gomega.ContainSubstring("-32029"), + gomega.ContainSubstring("Rate limit exceeded"), + )) + }) +}) + func fetchRateLimitOIDCToken(oidcPort int, subject string) string { url := fmt.Sprintf("http://localhost:%d/token?subject=%s", oidcPort, subject) resp, err := http.Post(url, "application/x-www-form-urlencoded", nil) //nolint:noctx @@ -223,6 +394,11 @@ func newRateLimitMCPClient(vmcpPort int, token string) *mcpclient.Client { return InitializeMCPClientWithRetries(serverURL, 2*time.Minute, transport.WithHTTPBasicClient(httpClient)) } +func newAnonymousRateLimitMCPClient(vmcpPort int) *mcpclient.Client { + serverURL := fmt.Sprintf("http://localhost:%d/mcp", vmcpPort) + return InitializeMCPClientWithRetries(serverURL, 2*time.Minute) +} + func startRateLimitServicePortForward(serviceName string, servicePort int32) (int, func(), error) { listener, err := net.Listen("tcp", ":0") if err != nil {