diff --git a/cmd/server_main.go b/cmd/server_main.go index d13e781c0c7..c7a49b507a5 100644 --- a/cmd/server_main.go +++ b/cmd/server_main.go @@ -218,8 +218,20 @@ func startServer(config *server.Config) { &handler.SearchbotRealLLM{Svc: modelProviderService}, ) + // Dify retrieval handler + docDAO := dao.NewDocumentDAO() + retrievalService := nlp.NewRetrievalService(docEngine, docDAO) + difyRetrievalHandler := handler.NewDifyRetrievalHandler( + knowledgebaseService, + modelProviderService, + metadataService, + retrievalService, + docDAO, + docEngine, + ) + // Initialize router - r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, knowledgebaseHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, mcpHandler, skillSearchHandler, providerHandler, agentHandler, relatedQuestionsHandler) + r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, knowledgebaseHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, mcpHandler, skillSearchHandler, providerHandler, agentHandler, relatedQuestionsHandler, difyRetrievalHandler) // Create Gin engine ginEngine := gin.New() diff --git a/internal/handler/dify_retrieval_handler.go b/internal/handler/dify_retrieval_handler.go new file mode 100644 index 00000000000..51cb929d1dd --- /dev/null +++ b/internal/handler/dify_retrieval_handler.go @@ -0,0 +1,373 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "context" + "errors" + "fmt" + "net/http" + "strconv" + "strings" + + "ragflow/internal/common" + "gorm.io/gorm" + "go.uber.org/zap" + "ragflow/internal/engine" + "ragflow/internal/entity" + modelModule "ragflow/internal/entity/models" + "ragflow/internal/service" + "ragflow/internal/service/kg" + "ragflow/internal/service/nlp" + + "github.com/gin-gonic/gin" +) + +// --- Interfaces (for testability) --- + +// KBServiceIface abstracts KnowledgebaseService for the Dify handler. +type KBServiceIface interface { + GetByID(kbID string) (*entity.Knowledgebase, error) + Accessible(kbID, userID string) bool +} + +// ModelServiceIface abstracts ModelProviderService for the Dify handler. +type ModelServiceIface interface { + GetEmbeddingModel(tenantID, embdID string) (*modelModule.EmbeddingModel, error) + GetChatModel(tenantID, compositeModelName string) (*modelModule.ChatModel, error) +} + +// MetadataServiceIface abstracts MetadataService for the Dify handler. +type MetadataServiceIface interface { + GetFlattedMetaByKBs(kbIDs []string) (common.MetaData, error) + LabelQuestion(question string, kbs []*entity.Knowledgebase) map[string]float64 +} + +// RetrievalServiceIface abstracts RetrievalService for the Dify handler. +type RetrievalServiceIface interface { + Retrieval(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error) +} + +// DocumentDAOIface abstracts DocumentDAO for the Dify handler. +type DocumentDAOIface interface { + GetByIDs(ids []string) ([]*entity.Document, error) +} + +// --- Request / Response types --- + +// difyRetrievalRequest is the JSON body / query params for the Dify retrieval endpoint. +type difyRetrievalRequest struct { + KnowledgeID string `json:"knowledge_id" form:"knowledge_id"` + Query string `json:"query" form:"query"` + UseKG bool `json:"use_kg" form:"use_kg"` + RetrievalSetting *difyRetrievalSetting `json:"retrieval_setting"` + MetadataCondition *difyMetadataCondition `json:"metadata_condition"` +} + +type difyRetrievalSetting struct { + TopK *int `json:"top_k" form:"top_k"` + ScoreThreshold *float64 `json:"score_threshold" form:"score_threshold"` +} + +// difyCondition is a Dify-format metadata filter condition. +// Dify uses "name"/"comparison_operator" instead of MetaFilterCondition's "key"/"op". +type difyCondition struct { + Name string `json:"name"` + ComparisonOperator string `json:"comparison_operator"` + Value interface{} `json:"value"` +} + +type difyMetadataCondition struct { + Conditions []difyCondition `json:"conditions"` + Logic string `json:"logic"` +} + +// toMetaFilterConditions converts Dify-format conditions to internal MetaFilterConditions. +func (c difyMetadataCondition) toMetaFilterConditions() []service.MetaFilterCondition { + if len(c.Conditions) == 0 { + return nil + } + result := make([]service.MetaFilterCondition, len(c.Conditions)) + for i, dc := range c.Conditions { + v := "" + if dc.Value != nil { + v = fmt.Sprint(dc.Value) + } + result[i] = service.MetaFilterCondition{ + Key: dc.Name, + Op: dc.ComparisonOperator, + Value: v, + } + } + return result +} + +// difyRecord is one item in the response records array. +type difyRecord struct { + Content string `json:"content"` + Score float64 `json:"score"` + Title string `json:"title"` + Metadata map[string]interface{} `json:"metadata"` +} + +// --- Handler --- + +// DifyRetrievalHandler handles Dify-compatible retrieval requests. +type DifyRetrievalHandler struct { + kbSvc KBServiceIface + modelSvc ModelServiceIface + metadataSvc MetadataServiceIface + retrievalSvc RetrievalServiceIface + docDAO DocumentDAOIface + docEngine engine.DocEngine +} + +// NewDifyRetrievalHandler creates a new DifyRetrievalHandler. +// The KG pipeline is created inline when use_kg=true to avoid injecting +// a pipeline that depends on per-request model configuration. +func NewDifyRetrievalHandler( + kbSvc KBServiceIface, + modelSvc ModelServiceIface, + metadataSvc MetadataServiceIface, + retrievalSvc RetrievalServiceIface, + docDAO DocumentDAOIface, + docEngine engine.DocEngine, +) *DifyRetrievalHandler { + return &DifyRetrievalHandler{ + kbSvc: kbSvc, + modelSvc: modelSvc, + metadataSvc: metadataSvc, + retrievalSvc: retrievalSvc, + docDAO: docDAO, + docEngine: docEngine, + } +} + +// Retrieval handles POST/GET /api/v1/dify/retrieval. +// Matches Python: api/apps/restful_apis/dify_retrieval_api.py::retrieval() +func (h *DifyRetrievalHandler) Retrieval(c *gin.Context) { + user, errCode, errMsg := GetUser(c) + if errCode != common.CodeSuccess { + c.JSON(http.StatusUnauthorized, gin.H{"code": errCode, "message": errMsg}) + return + } + + var req difyRetrievalRequest + if c.Request.Method == http.MethodGet { + if err := c.ShouldBindQuery(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "message": "invalid query parameters"}) + return + } + // Manually extract top_k and score_threshold from query (flat params, not nested) + if v := c.Query("top_k"); v != "" { + if parsed, err := strconv.Atoi(v); err == nil { + if req.RetrievalSetting == nil { + req.RetrievalSetting = &difyRetrievalSetting{} + } + req.RetrievalSetting.TopK = &parsed + } + } + if v := c.Query("score_threshold"); v != "" { + if parsed, err := strconv.ParseFloat(v, 64); err == nil { + if req.RetrievalSetting == nil { + req.RetrievalSetting = &difyRetrievalSetting{} + } + req.RetrievalSetting.ScoreThreshold = &parsed + } + } + } else { + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "message": "invalid request body"}) + return + } + } + + if req.KnowledgeID == "" || req.Query == "" { + c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "message": "knowledge_id and query are required"}) + return + } + + kb, err := h.kbSvc.GetByID(req.KnowledgeID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + c.JSON(http.StatusNotFound, gin.H{"code": common.CodeNotFound, "message": "Knowledgebase not found!"}) + } else { + c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "message": "failed to query knowledgebase"}) + } + return + } + + if !h.kbSvc.Accessible(req.KnowledgeID, user.ID) { + c.JSON(http.StatusUnauthorized, gin.H{"code": common.CodeAuthenticationError, "message": "No authorization."}) + return + } + + // Parse retrieval options (nil means service uses defaults) + var topK *int + if req.RetrievalSetting != nil && req.RetrievalSetting.TopK != nil { + topK = req.RetrievalSetting.TopK + } + var scoreThreshold *float64 + if req.RetrievalSetting != nil && req.RetrievalSetting.ScoreThreshold != nil { + scoreThreshold = req.RetrievalSetting.ScoreThreshold + } + pageSize := 1024 + if topK != nil { + pageSize = *topK + } + + // Get embedding model + embModel, err := h.modelSvc.GetEmbeddingModel(kb.TenantID, kb.EmbdID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "message": fmt.Sprintf("failed to get embedding model: %v", err)}) + return + } + + // Metadata filter + metas, metaErr := h.metadataSvc.GetFlattedMetaByKBs([]string{req.KnowledgeID}) + docIDs := make([]string, 0) + if metaErr == nil && req.MetadataCondition != nil { + logic := req.MetadataCondition.Logic + if logic == "" { + logic = "and" + } + filteredIDs := service.ApplyMetaFilter(metas, req.MetadataCondition.toMetaFilterConditions(), logic) + docIDs = append(docIDs, filteredIDs...) + } + if len(docIDs) == 0 && req.MetadataCondition != nil { + docIDs = []string{"-999"} + } + + // Label question for rank features + kbs := []*entity.Knowledgebase{kb} + rankFeature := h.metadataSvc.LabelQuestion(req.Query, kbs) + + // Chunk retrieval + sr := &nlp.RetrievalRequest{ + Question: req.Query, + TenantIDs: []string{kb.TenantID}, + KbIDs: []string{req.KnowledgeID}, + DocIDs: docIDs, + Page: 1, + PageSize: pageSize, + Top: topK, + SimilarityThreshold: scoreThreshold, + EmbeddingModel: embModel, + } + if rankFeature != nil { + sr.RankFeature = &rankFeature + } + + result, err := h.retrievalSvc.Retrieval(c.Request.Context(), sr) + if err != nil { + if strings.Contains(err.Error(), "not_found") { + c.JSON(http.StatusNotFound, gin.H{"code": common.CodeNotFound, "message": "No chunk found! Check the chunk status please!"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "message": err.Error()}) + return + } + + // Enrich with child chunks + chunks := nlp.RetrievalByChildren(result.Chunks, []string{kb.TenantID}, h.docEngine, c.Request.Context()) + + // KG retrieval (optional) + if req.UseKG { + chatModel, kgErr := h.modelSvc.GetChatModel(kb.TenantID, "") + if kgErr != nil { + common.Warn("KG retrieval: failed to get chat model", zap.String("kbID", req.KnowledgeID), zap.Error(kgErr)) + } else if chatModel != nil { + kgPipeline := kg.NewPipeline( + h.docEngine, + []string{req.KnowledgeID}, + []string{kb.TenantID}, + req.Query, + ) + kgPipeline.SetChatModel(chatModel) + kgPipeline.SetEmbModel(embModel) + if kgResult, kgErr := kgPipeline.Retrieval(c.Request.Context()); kgErr == nil { + if content, ok := kgResult["content_with_weight"].(string); ok && content != "" { + chunks = append([]map[string]interface{}{kgResult}, chunks...) + } + } + } + } + + // Collect doc IDs and fetch documents + docIDSet := make(map[string]struct{}) + for _, ch := range chunks { + if docID, ok := ch["doc_id"].(string); ok && docID != "" { + docIDSet[docID] = struct{}{} + } + } + allDocIDs := make([]string, 0, len(docIDSet)) + for id := range docIDSet { + allDocIDs = append(allDocIDs, id) + } + + docMap := make(map[string]*entity.Document) + if len(allDocIDs) > 0 { + docs, err := h.docDAO.GetByIDs(allDocIDs) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "message": fmt.Sprintf("failed to load documents: %v", err)}) + return + } + for _, d := range docs { + docMap[d.ID] = d + } + } + + // Build response + records := make([]difyRecord, 0, len(chunks)) + for _, ch := range chunks { + docID, _ := ch["doc_id"].(string) + doc := docMap[docID] + if doc == nil { + continue + } + + // Remove vector to reduce response size + delete(ch, "vector") + + meta := make(map[string]interface{}) + if doc.MetaFields != nil { + for k, v := range *doc.MetaFields { + meta[k] = v + } + } + meta["doc_id"] = docID + meta["document_id"] = docID + + score, _ := ch["similarity"].(float64) + title, _ := ch["docnm_kwd"].(string) + content, _ := ch["content_with_weight"].(string) + + records = append(records, difyRecord{ + Content: content, + Score: score, + Title: title, + Metadata: meta, + }) + } + + c.JSON(http.StatusOK, gin.H{"records": records}) +} + +// HealthCheck returns a simple health check response. +func (h *DifyRetrievalHandler) HealthCheck(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"code": 0, "data": true}) +} diff --git a/internal/handler/dify_retrieval_handler_test.go b/internal/handler/dify_retrieval_handler_test.go new file mode 100644 index 00000000000..912bb693ba9 --- /dev/null +++ b/internal/handler/dify_retrieval_handler_test.go @@ -0,0 +1,401 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "context" + "encoding/json" + "errors" + "gorm.io/gorm" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "ragflow/internal/common" + "ragflow/internal/engine" + "ragflow/internal/entity" + modelModule "ragflow/internal/entity/models" + "ragflow/internal/service/nlp" + "ragflow/internal/engine/types" + + "github.com/gin-gonic/gin" +) + +// --- Mock implementations --- + +type mockKBService struct { + KBServiceIface + getByIDFn func(kbID string) (*entity.Knowledgebase, error) + accessibleFn func(kbID, userID string) bool +} + +func (m *mockKBService) GetByID(kbID string) (*entity.Knowledgebase, error) { + if m.getByIDFn != nil { + return m.getByIDFn(kbID) + } + return &entity.Knowledgebase{ + ID: kbID, TenantID: "tenant1", EmbdID: "text-embedding", + }, nil +} + +func (m *mockKBService) Accessible(kbID, userID string) bool { + if m.accessibleFn != nil { + return m.accessibleFn(kbID, userID) + } + return true +} + +type mockModelService struct { + ModelServiceIface + getEmbeddingFn func(tenantID, embdID string) (*modelModule.EmbeddingModel, error) + getChatModelFn func(tenantID, llmID string) (*modelModule.ChatModel, error) +} + +func (m *mockModelService) GetEmbeddingModel(tenantID, embdID string) (*modelModule.EmbeddingModel, error) { + if m.getEmbeddingFn != nil { + return m.getEmbeddingFn(tenantID, embdID) + } + return &modelModule.EmbeddingModel{}, nil +} + +func (m *mockModelService) GetChatModel(tenantID, llmID string) (*modelModule.ChatModel, error) { + if m.getChatModelFn != nil { + return m.getChatModelFn(tenantID, llmID) + } + return &modelModule.ChatModel{}, nil +} + +type mockMetadataService struct { + MetadataServiceIface + getFlattedMetaFn func(kbIDs []string) (common.MetaData, error) + labelQuestionFn func(question string, kbs []*entity.Knowledgebase) map[string]float64 +} + +func (m *mockMetadataService) GetFlattedMetaByKBs(kbIDs []string) (common.MetaData, error) { + if m.getFlattedMetaFn != nil { + return m.getFlattedMetaFn(kbIDs) + } + return common.MetaData{}, nil +} + +func (m *mockMetadataService) LabelQuestion(question string, kbs []*entity.Knowledgebase) map[string]float64 { + if m.labelQuestionFn != nil { + return m.labelQuestionFn(question, kbs) + } + return nil +} + +type mockRetrievalService struct { + RetrievalServiceIface + retrievalFn func(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error) +} + +func (m *mockRetrievalService) Retrieval(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error) { + if m.retrievalFn != nil { + return m.retrievalFn(ctx, req) + } + return &nlp.RetrievalResult{ + Chunks: []map[string]interface{}{ + {"doc_id": "doc1", "docnm_kwd": "Test Doc", "content_with_weight": "test content", "similarity": 0.85}, + }, + }, nil +} + +type mockDocDAO struct { + DocumentDAOIface + getByIDsFn func(ids []string) ([]*entity.Document, error) +} + +func (m *mockDocDAO) GetByIDs(ids []string) ([]*entity.Document, error) { + if m.getByIDsFn != nil { + return m.getByIDsFn(ids) + } + return []*entity.Document{ + {ID: "doc1", Name: strPtr("Test Doc"), MetaFields: &entity.JSONMap{"author": "Zhang San"}}, + }, nil +} + +// mockDocEngine stubs the DocEngine interface (embed = panic on unimplemented). +type mockDocEngine struct { + engine.DocEngine +} + +func (m *mockDocEngine) Close() error { return nil } +func (m *mockDocEngine) Ping(ctx context.Context) error { return nil } +func (m *mockDocEngine) GetType() string { return "mock" } + func (m *mockDocEngine) Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) { + return &types.SearchResult{}, nil + } +func (m *mockDocEngine) GetChunk(ctx context.Context, _, _ string, _ []string) (interface{}, error) { + return map[string]interface{}{}, nil +} + +// --- Helper --- + +func setupDifyTest(userID string) (*DifyRetrievalHandler, *gin.Engine) { + h := &DifyRetrievalHandler{ + kbSvc: &mockKBService{}, + modelSvc: &mockModelService{}, + metadataSvc: &mockMetadataService{}, + retrievalSvc: &mockRetrievalService{}, + docDAO: &mockDocDAO{}, + docEngine: &mockDocEngine{}, + } + + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("user", &entity.User{ID: userID}) + }) + r.POST("/api/v1/dify/retrieval", h.Retrieval) + r.GET("/api/v1/dify/retrieval", h.Retrieval) + r.GET("/api/v1/dify/retrieval/health", h.HealthCheck) + return h, r +} + +func setupDifyTestNoAuth() (*DifyRetrievalHandler, *gin.Engine) { + h := &DifyRetrievalHandler{ + kbSvc: &mockKBService{}, + modelSvc: &mockModelService{}, + metadataSvc: &mockMetadataService{}, + retrievalSvc: &mockRetrievalService{}, + docDAO: &mockDocDAO{}, + docEngine: &mockDocEngine{}, + } + gin.SetMode(gin.TestMode) + r := gin.New() + r.POST("/api/v1/dify/retrieval", h.Retrieval) + return h, r +} + +// --- Tests --- + +func TestDifyRetrieval_HealthCheck(t *testing.T) { + _, r := setupDifyTest("user1") + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/api/v1/dify/retrieval/health", nil) + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d", w.Code) + } + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if resp["data"] != true { + t.Errorf("expected data=true, got %v", resp["data"]) + } +} + +func TestDifyRetrieval_Basic(t *testing.T) { + _, r := setupDifyTest("user1") + body := `{"knowledge_id": "kb1", "query": "test question"}` + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + records, ok := resp["records"].([]interface{}) + if !ok || len(records) == 0 { + t.Errorf("expected non-empty records, got %v", resp["records"]) + } +} + +func TestDifyRetrieval_GET(t *testing.T) { + _, r := setupDifyTest("user1") + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/api/v1/dify/retrieval?knowledge_id=kb1&query=test", nil) + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestDifyRetrieval_MissingArgs(t *testing.T) { + _, r := setupDifyTest("user1") + tests := []struct { + name string + body string + }{ + {"no knowledge_id", `{"query": "test"}`}, + {"no query", `{"knowledge_id": "kb1"}`}, + {"empty body", `{}`}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader(tc.body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } + }) + } +} + +func TestDifyRetrieval_KBNotFound(t *testing.T) { + h, r := setupDifyTest("user1") + h.kbSvc = &mockKBService{ + getByIDFn: func(kbID string) (*entity.Knowledgebase, error) { + return nil, gorm.ErrRecordNotFound + }, + } + w := httptest.NewRecorder() + body := `{"knowledge_id": "nonexistent", "query": "test"}` + req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusNotFound { + t.Errorf("expected 404, got %d", w.Code) + } +} + +func TestDifyRetrieval_NoAuth(t *testing.T) { + _, r := setupDifyTestNoAuth() + w := httptest.NewRecorder() + body := `{"knowledge_id": "kb1", "query": "test"}` + req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", w.Code) + } +} + +func TestDifyRetrieval_Unauthorized(t *testing.T) { + h, r := setupDifyTest("user1") + h.kbSvc = &mockKBService{ + accessibleFn: func(kbID, userID string) bool { return false }, + } + w := httptest.NewRecorder() + body := `{"knowledge_id": "kb1", "query": "test"}` + req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", w.Code) + } +} + +func TestDifyRetrieval_WithMetadataFilter(t *testing.T) { + h, r := setupDifyTest("user1") + h.metadataSvc = &mockMetadataService{ + getFlattedMetaFn: func(kbIDs []string) (common.MetaData, error) { + return common.MetaData{}, nil + }, + } + body := `{"knowledge_id":"kb1","query":"test","metadata_condition":{"conditions":[{"name":"author","comparison_operator":"eq","value":"Zhang San"}],"logic":"and"}}` + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestDifyRetrieval_InvalidJSON(t *testing.T) { + _, r := setupDifyTest("user1") + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader("{invalid json")) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } +} + +func TestDifyRetrieval_UseKG(t *testing.T) { + h, r := setupDifyTest("user1") + h.metadataSvc = &mockMetadataService{ + labelQuestionFn: func(question string, kbs []*entity.Knowledgebase) map[string]float64 { + return map[string]float64{"tag_1": 0.8} + }, + } + body := `{"knowledge_id":"kb1","query":"test","use_kg":true}` + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String()) + } +} + +func strPtr(s string) *string { return &s } + +func TestDifyRetrieval_KBDBError(t *testing.T) { + h, r := setupDifyTest("user1") + h.kbSvc = &mockKBService{ + getByIDFn: func(kbID string) (*entity.Knowledgebase, error) { + return nil, errors.New("connection refused") + }, + } + w := httptest.NewRecorder() + body := `{"knowledge_id": "kb1", "query": "test"}` + req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusInternalServerError { + t.Errorf("expected 500 for DB error, got %d", w.Code) + } +} + +func TestDifyRetrieval_DocLoadError(t *testing.T) { + h, r := setupDifyTest("user1") + h.docDAO = &mockDocDAO{ + getByIDsFn: func(ids []string) ([]*entity.Document, error) { + return nil, errors.New("db unavailable") + }, + } + w := httptest.NewRecorder() + body := `{"knowledge_id": "kb1", "query": "test"}` + req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusInternalServerError { + t.Errorf("expected 500 for doc load error, got %d", w.Code) + } +} + +func TestDifyRetrieval_RetrievalNotFound(t *testing.T) { + h, r := setupDifyTest("user1") + h.retrievalSvc = &mockRetrievalService{ + retrievalFn: func(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error) { + return nil, errors.New("no chunk found: not_found") + }, + } + w := httptest.NewRecorder() + body := `{"knowledge_id": "kb1", "query": "test"}` + req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusNotFound { + t.Errorf("expected 404 for not_found, got %d", w.Code) + } +} diff --git a/internal/router/router.go b/internal/router/router.go index 765a8a2aa79..77d7ebdf2be 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -43,6 +43,7 @@ type Router struct { providerHandler *handler.ProviderHandler agentHandler *handler.AgentHandler relatedQuestionsHandler *handler.SearchbotHandler + difyRetrievalHandler *handler.DifyRetrievalHandler } // NewRouter create router @@ -67,6 +68,7 @@ func NewRouter( providerHandler *handler.ProviderHandler, agentHandler *handler.AgentHandler, relatedQuestionsHandler *handler.SearchbotHandler, + difyRetrievalHandler *handler.DifyRetrievalHandler, ) *Router { return &Router{ authHandler: authHandler, @@ -89,6 +91,7 @@ func NewRouter( providerHandler: providerHandler, agentHandler: agentHandler, relatedQuestionsHandler: relatedQuestionsHandler, + difyRetrievalHandler: difyRetrievalHandler, } } @@ -517,6 +520,14 @@ func (r *Router) Setup(engine *gin.Engine) { } + // Dify retrieval routes + dify := authorized.Group("/api/v1/dify") + { + dify.POST("/retrieval", r.difyRetrievalHandler.Retrieval) + dify.GET("/retrieval", r.difyRetrievalHandler.Retrieval) + } + apiNoAuth.GET("/dify/retrieval/health", r.difyRetrievalHandler.HealthCheck) + // Handle undefined routes engine.NoRoute(handler.HandleNoRoute) } diff --git a/internal/service/kg_pipeline.go b/internal/service/kg/pipeline.go similarity index 79% rename from internal/service/kg_pipeline.go rename to internal/service/kg/pipeline.go index fd2a01835f2..3dc72679077 100644 --- a/internal/service/kg_pipeline.go +++ b/internal/service/kg/pipeline.go @@ -14,7 +14,7 @@ // limitations under the License. // -package service +package kg import ( "context" @@ -30,9 +30,9 @@ import ( "go.uber.org/zap" ) -// KGSearchPipeline encapsulates the knowledge graph retrieval pipeline. +// Pipeline encapsulates the knowledge graph retrieval pipeline. // Matches Python: rag/graphrag/search.py::KGSearch -type KGSearchPipeline struct { +type Pipeline struct { docEngine engine.DocEngine chatModel *modelModule.ChatModel embModel *modelModule.EmbeddingModel @@ -50,51 +50,51 @@ type KGSearchPipeline struct { maxToken int } -// KGSearchOption configures a KGSearchPipeline. -type KGSearchOption func(*KGSearchPipeline) +// Option configures a Pipeline. +type Option func(*Pipeline) -// WithKGSimThreshold sets the similarity threshold for entity and relation search. +// WithSimThreshold sets the similarity threshold for entity and relation search. // Default: 0.3 (matches Python ent_sim_threshold, rel_sim_threshold). -func WithKGSimThreshold(v float64) KGSearchOption { - return func(p *KGSearchPipeline) { p.entSimThreshold = v; p.relSimThreshold = v } +func WithSimThreshold(v float64) Option { + return func(p *Pipeline) { p.entSimThreshold = v; p.relSimThreshold = v } } -// WithKGDenseTopK sets the TopK for dense vector search. +// WithDenseTopK sets the TopK for dense vector search. // Default: 1024 (matches Python get_vector topk). -func WithKGDenseTopK(v int) KGSearchOption { - return func(p *KGSearchPipeline) { p.denseTopK = v } +func WithDenseTopK(v int) Option { + return func(p *Pipeline) { p.denseTopK = v } } -// NewKGSearchPipeline creates a KG search pipeline with the given dependencies. +// NewPipeline creates a KG search pipeline with the given dependencies. // // docEngine: search engine backend // kbIDs: knowledge base IDs to search // tenantIDs: tenant IDs (converted to index names internally) // question: user query string -// opts: optional configuration (WithKGSimThreshold, WithKGDenseTopK) +// opts: optional configuration (WithSimThreshold, WithDenseTopK) // // chatModel and embModel should be set via WithChatModel/WithEmbModel setters // or passed directly after construction. -func NewKGSearchPipeline( +func NewPipeline( docEngine engine.DocEngine, kbIDs []string, tenantIDs []string, question string, - opts ...KGSearchOption, -) *KGSearchPipeline { + opts ...Option, +) *Pipeline { idxnms := make([]string, len(tenantIDs)) for i, tid := range tenantIDs { idxnms[i] = indexName(tid) } - p := &KGSearchPipeline{ + p := &Pipeline{ docEngine: docEngine, kbIDs: kbIDs, idxnms: idxnms, question: question, - entSimThreshold: defaultKGSimThreshold, - relSimThreshold: defaultKGSimThreshold, - denseTopK: defaultKGDenseTopK, + entSimThreshold: defaultSimThreshold, + relSimThreshold: defaultSimThreshold, + denseTopK: defaultDenseTopK, entTopN: 6, relTopN: 6, commTopN: 1, @@ -106,12 +106,22 @@ func NewKGSearchPipeline( return p } +// SetChatModel sets the chat model for LLM-based query rewrite. +func (p *Pipeline) SetChatModel(chatModel *modelModule.ChatModel) { + p.chatModel = chatModel +} + +// SetEmbModel sets the embedding model for dense/hybrid search. +func (p *Pipeline) SetEmbModel(embModel *modelModule.EmbeddingModel) { + p.embModel = embModel +} + // Retrieval runs the full KG retrieval pipeline and returns a synthetic chunk. -func (p *KGSearchPipeline) Retrieval(ctx context.Context) (map[string]interface{}, error) { +func (p *Pipeline) Retrieval(ctx context.Context) (map[string]interface{}, error) { // 1. Query rewrite via LLM, or fall back to raw question ty2entsJSON := "" if p.chatModel != nil { - typeSamples, err := searchKGTypeSamples(ctx, p.docEngine, p.idxnms, p.kbIDs) + typeSamples, err := searchTypeSamples(ctx, p.docEngine, p.idxnms, p.kbIDs) if err != nil { common.Warn("KG type samples search failed", zap.String("kbIDs", fmt.Sprint(p.kbIDs))) } @@ -157,11 +167,11 @@ func (p *KGSearchPipeline) Retrieval(ctx context.Context) (map[string]interface{ scoredRels := SortAndTrimRelations(relsFromText, p.relTopN) // 7. Build KG content with token budget - entsRelsContent := BuildKGContent(scoredEnts, scoredRels, p.maxToken) + entsRelsContent := BuildContent(scoredEnts, scoredRels, p.maxToken) used := NumTokensFromString(entsRelsContent) remaining := p.maxToken - used // 8. Search community reports with remaining token budget - communityContent := searchKGCommunityContent(ctx, p.docEngine, p.idxnms, p.kbIDs, scoredEnts, p.commTopN, &remaining) + communityContent := searchCommunityContent(ctx, p.docEngine, p.idxnms, p.kbIDs, scoredEnts, p.commTopN, &remaining) // 9. Build synthetic chunk return map[string]interface{}{ @@ -182,7 +192,7 @@ func (p *KGSearchPipeline) Retrieval(ctx context.Context) (map[string]interface{ } // searchEntities searches KG entities by keyword text and optional dense vector. -func (p *KGSearchPipeline) searchEntities(ctx context.Context, entities []string) (map[string]*KGEntity, error) { +func (p *Pipeline) searchEntities(ctx context.Context, entities []string) (map[string]*KGEntity, error) { entsReq := &types.SearchRequest{ IndexNames: p.idxnms, KbIDs: p.kbIDs, @@ -207,14 +217,14 @@ func (p *KGSearchPipeline) searchEntities(ctx context.Context, entities []string if name == "" { continue } - e := kgEntityFromChunk(name, chunk) + e := entityFromChunk(name, chunk) result[name] = &e } return result, nil } // searchEntityTypes searches KG entities by type keywords. -func (p *KGSearchPipeline) searchEntityTypes(ctx context.Context, typeKeywords []string) map[string]struct{} { +func (p *Pipeline) searchEntityTypes(ctx context.Context, typeKeywords []string) map[string]struct{} { typesReq := &types.SearchRequest{ IndexNames: p.idxnms, KbIDs: p.kbIDs, @@ -244,7 +254,7 @@ func (p *KGSearchPipeline) searchEntityTypes(ctx context.Context, typeKeywords [ } // searchRelations searches KG relations by entity text and optional dense vector. -func (p *KGSearchPipeline) searchRelations(ctx context.Context, entities []string) map[Edge]*KGRelation { +func (p *Pipeline) searchRelations(ctx context.Context, entities []string) map[Edge]*KGRelation { relsReq := &types.SearchRequest{ IndexNames: p.idxnms, KbIDs: p.kbIDs, @@ -265,7 +275,7 @@ func (p *KGSearchPipeline) searchRelations(ctx context.Context, entities []strin common.Warn("KG relations search failed", zap.String("kbIDs", fmt.Sprint(p.kbIDs))) } else { for _, chunk := range FilterChunksByScore(relsResult.Chunks, p.relSimThreshold) { - edge, rel := kgRelationFromChunk(chunk) + edge, rel := relationFromChunk(chunk) if edge.From == "" || edge.To == "" { continue } diff --git a/internal/service/kg_retrieval.go b/internal/service/kg/retrieval.go similarity index 74% rename from internal/service/kg_retrieval.go rename to internal/service/kg/retrieval.go index 000f21a7831..3fcf357948c 100644 --- a/internal/service/kg_retrieval.go +++ b/internal/service/kg/retrieval.go @@ -1,20 +1,4 @@ -// -// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package service +package kg import ( "context" @@ -27,67 +11,9 @@ import ( modelModule "ragflow/internal/entity/models" ) -// indexName builds the search index name from a tenant ID. -// Matches Python: rag/nlp/search.py::index_name() -func indexName(tenantID string) string { - return "ragflow_" + tenantID -} - -// Python alignment defaults — match rag/graphrag/search.py retrieval() params -const ( - defaultKGSimThreshold = 0.3 // Python: ent_sim_threshold, rel_sim_threshold - defaultKGDenseTopK = 1024 // Python: get_vector() topk -) - -// kgEntityFromChunk parses a single entity chunk into a KGEntity. -func kgEntityFromChunk(name string, chunk map[string]interface{}) KGEntity { - e := KGEntity{} - if v, ok := chunk["_score"].(float64); ok { - e.Similarity = v - } else if v, ok := chunk["score"].(float64); ok { - e.Similarity = v - } - if v, ok := chunk["rank_flt"].(float64); ok { - e.PageRank = v - } - e.Description, _ = chunk["content_with_weight"].(string) - if raw, ok := chunk["n_hop_with_weight"].(string); ok && raw != "" { - var nhopData []struct { - Path []string `json:"path"` - Weights []float64 `json:"weights"` - } - if err := json.Unmarshal([]byte(raw), &nhopData); err == nil { - for _, item := range nhopData { - e.NhopEnts = append(e.NhopEnts, NhopEntity{ - Path: item.Path, - Weights: item.Weights, - }) - } - } - } - return e -} - -// kgRelationFromChunk parses a single relation chunk into a KGRelation. -func kgRelationFromChunk(chunk map[string]interface{}) (Edge, KGRelation) { - r := KGRelation{} - r.Description, _ = chunk["content_with_weight"].(string) - if v, ok := chunk["weight_int"].(float64); ok { - r.PageRank = float64(v) - } else if v, ok := chunk["weight_int"].(int); ok { - r.PageRank = float64(v) - } - from, _ := chunk["from_entity_kwd"].(string) - to, _ := chunk["to_entity_kwd"].(string) - return Edge{From: from, To: to}, r -} - -// KGSearchRetrieval performs a full knowledge graph retrieval and returns -// a synthetic chunk to be inserted into search results. -// Corresponds to Python: rag/graphrag/search.py::KGSearch.retrieval() -// -// This is a convenience wrapper around KGSearchPipeline. -func KGSearchRetrieval( +// Retrieval performs a full knowledge graph retrieval and returns +// a synthetic chunk. Convenience wrapper around Pipeline. +func Retrieval( ctx context.Context, docEngine engine.DocEngine, chatModel *modelModule.ChatModel, @@ -96,16 +22,16 @@ func KGSearchRetrieval( tenantIDs []string, question string, ) (map[string]interface{}, error) { - p := &KGSearchPipeline{ + p := &Pipeline{ docEngine: docEngine, chatModel: chatModel, embModel: embModel, kbIDs: kbIDs, idxnms: makeIndexNames(tenantIDs), question: question, - entSimThreshold: defaultKGSimThreshold, - relSimThreshold: defaultKGSimThreshold, - denseTopK: defaultKGDenseTopK, + entSimThreshold: defaultSimThreshold, + relSimThreshold: defaultSimThreshold, + denseTopK: defaultDenseTopK, entTopN: 6, relTopN: 6, commTopN: 1, @@ -123,8 +49,13 @@ func makeIndexNames(tenantIDs []string) []string { return idxnms } -// searchKGTypeSamples searches for ty2ents data. -func searchKGTypeSamples(ctx context.Context, docEngine engine.DocEngine, idxnms []string, kbIDs []string) (map[string][]string, error) { +// indexName builds the search index name from a tenant ID. +func indexName(tenantID string) string { + return "ragflow_" + tenantID +} + +// searchTypeSamples searches for ty2ents data. +func searchTypeSamples(ctx context.Context, docEngine engine.DocEngine, idxnms []string, kbIDs []string) (map[string][]string, error) { req := &types.SearchRequest{ IndexNames: idxnms, KbIDs: kbIDs, @@ -153,8 +84,8 @@ func searchKGTypeSamples(ctx context.Context, docEngine engine.DocEngine, idxnms return typeMap, nil } -// searchKGCommunityContent searches for community reports and formats them. -func searchKGCommunityContent(ctx context.Context, docEngine engine.DocEngine, idxnms []string, kbIDs []string, scoredEnts []ScoredEntity, topN int, maxToken *int) string { +// searchCommunityContent searches for community reports and formats them. +func searchCommunityContent(ctx context.Context, docEngine engine.DocEngine, idxnms []string, kbIDs []string, scoredEnts []ScoredEntity, topN int, maxToken *int) string { if maxToken == nil || len(scoredEnts) == 0 || *maxToken <= 0 { return "" } @@ -189,7 +120,6 @@ func searchKGCommunityContent(ctx context.Context, docEngine engine.DocEngine, i if title == "" && raw == "" { continue } - // Parse JSON for nested report/evidences fields (Python: json.loads) report := raw evidence := "" var parsed map[string]interface{} @@ -212,8 +142,73 @@ func searchKGCommunityContent(ctx context.Context, docEngine engine.DocEngine, i return bld } +// entityFromChunk parses a single entity chunk into a KGEntity. +func entityFromChunk(name string, chunk map[string]interface{}) KGEntity { + e := KGEntity{} + if v, ok := chunk["_score"].(float64); ok { + e.Similarity = v + } else if v, ok := chunk["score"].(float64); ok { + e.Similarity = v + } + if v, ok := chunk["rank_flt"].(float64); ok { + e.PageRank = v + } + e.Description, _ = chunk["content_with_weight"].(string) + if raw, ok := chunk["n_hop_with_weight"].(string); ok && raw != "" { + var nhopData []struct { + Path []string `json:"path"` + Weights []float64 `json:"weights"` + } + if err := json.Unmarshal([]byte(raw), &nhopData); err == nil { + for _, item := range nhopData { + e.NhopEnts = append(e.NhopEnts, NhopEntity{ + Path: item.Path, + Weights: item.Weights, + }) + } + } + } + return e +} + +// relationFromChunk parses a single relation chunk into a KGRelation. +func relationFromChunk(chunk map[string]interface{}) (Edge, KGRelation) { + r := KGRelation{} + r.Description, _ = chunk["content_with_weight"].(string) + if v, ok := chunk["_score"].(float64); ok { + r.Sim = v + } else if v, ok := chunk["score"].(float64); ok { + r.Sim = v + } + if v, ok := chunk["weight_int"].(float64); ok { + r.PageRank = float64(v) + } else if v, ok := chunk["weight_int"].(int); ok { + r.PageRank = float64(v) + } + from, _ := chunk["from_entity_kwd"].(string) + to, _ := chunk["to_entity_kwd"].(string) + return Edge{From: from, To: to}, r +} + +// buildSearchExprs constructs MatchExprs for KG entity/relation search. +// When embModel is nil, returns text-only match expression. +// When embModel is non-nil, embeds the question and returns hybrid +// (text + dense + fusion) expressions for vector+keyword search. +func buildSearchExprs(embModel *modelModule.EmbeddingModel, matchText *types.MatchTextExpr, simThreshold float64, denseTopK int) []interface{} { + if embModel == nil || embModel.ModelDriver == nil { + return []interface{}{matchText} + } + embeddingConfig := &modelModule.EmbeddingConfig{Dimension: 0} + embeddings, err := embModel.ModelDriver.Embed(embModel.ModelName, []string{matchText.MatchingText}, embModel.APIConfig, embeddingConfig) + if err != nil || len(embeddings) == 0 { + return []interface{}{matchText} + } + denseExpr := buildMatchDenseExpr(embeddings[0].Embedding, denseTopK, simThreshold) + fusionExpr := buildFusionExpr(defaultTextWeight, defaultVectorWeight, matchText.TopN) + return []interface{}{matchText, denseExpr, fusionExpr} +} + // buildMatchDenseExpr constructs a MatchDenseExpr from an embedding vector. -// This is a pure function — no I/O, no external dependencies. func buildMatchDenseExpr(vector []float64, topN int, similarity float64) *types.MatchDenseExpr { vectorColumnName := fmt.Sprintf("q_%d_vec", len(vector)) return &types.MatchDenseExpr{ @@ -227,7 +222,6 @@ func buildMatchDenseExpr(vector []float64, topN int, similarity float64) *types. } // buildFusionExpr constructs a FusionExpr for weighted-sum hybrid search. -// This is a pure function — no I/O, no external dependencies. func buildFusionExpr(textWeight, vectorWeight float64, topN int) *types.FusionExpr { return &types.FusionExpr{ Method: "weighted_sum", @@ -238,26 +232,7 @@ func buildFusionExpr(textWeight, vectorWeight float64, topN int) *types.FusionEx } } -// buildSearchExprs constructs MatchExprs for KG entity/relation search. -// When embModel is nil, returns text-only match expression. -// When embModel is non-nil, embeds the question and returns hybrid -// (text + dense + fusion) expressions for vector+keyword search. -func buildSearchExprs(embModel *modelModule.EmbeddingModel, matchText *types.MatchTextExpr, simThreshold float64, denseTopK int) []interface{} { - if embModel == nil || embModel.ModelDriver == nil { - return []interface{}{matchText} - } - embeddingConfig := &modelModule.EmbeddingConfig{Dimension: 0} - embeddings, err := embModel.ModelDriver.Embed(embModel.ModelName, []string{matchText.MatchingText}, embModel.APIConfig, embeddingConfig) - if err != nil || len(embeddings) == 0 { - return []interface{}{matchText} - } - denseExpr := buildMatchDenseExpr(embeddings[0].Embedding, denseTopK, simThreshold) - fusionExpr := buildFusionExpr(0.5, 0.5, matchText.TopN) - return []interface{}{matchText, denseExpr, fusionExpr} -} - // queryRewrite attempts LLM-based query rewrite, falling back to raw question. -// ty2entsJSON is the JSON-encoded type→entities mapping for prompt context. func queryRewrite(chatModel *modelModule.ChatModel, question string, ty2entsJSON string) (typeKeywords, entities []string) { if question == "" { return nil, nil @@ -276,6 +251,14 @@ func queryRewrite(chatModel *modelModule.ChatModel, question string, ty2entsJSON } } } - // Fallback: use raw question as single entity return nil, []string{question} } + +// Python alignment defaults +const ( + defaultSimThreshold = 0.3 + defaultDenseTopK = 1024 + // defaultTextWeight / defaultVectorWeight are fusion weights for hybrid search (equal by default). + defaultTextWeight = 0.5 + defaultVectorWeight = 0.5 +) diff --git a/internal/service/kg_retrieval_test.go b/internal/service/kg/retrieval_test.go similarity index 85% rename from internal/service/kg_retrieval_test.go rename to internal/service/kg/retrieval_test.go index 7e3989abd76..ca8bf36919f 100644 --- a/internal/service/kg_retrieval_test.go +++ b/internal/service/kg/retrieval_test.go @@ -14,7 +14,7 @@ // limitations under the License. // -package service +package kg import ( "context" @@ -56,16 +56,16 @@ func (m *mockRetrievalEngine) Search(ctx context.Context, req *types.SearchReque return &types.SearchResult{}, nil } -// --- kgEntityFromChunk --- +// --- entityFromChunk --- -func TestKgEntityFromChunk_Basic(t *testing.T) { +func TestEntityFromChunk_Basic(t *testing.T) { chunk := map[string]interface{}{ "_score": 0.85, "rank_flt": 0.9, "content_with_weight": "Founder of SpaceX", "n_hop_with_weight": `[{"path":["A","B"],"weights":[0.8]}]`, } - e := kgEntityFromChunk("Elon Musk", chunk) + e := entityFromChunk("Elon Musk", chunk) if e.Similarity != 0.85 { t.Errorf("expected Sim=0.85, got %f", e.Similarity) } @@ -80,32 +80,32 @@ func TestKgEntityFromChunk_Basic(t *testing.T) { } } -func TestKgEntityFromChunk_ScoreFallback(t *testing.T) { +func TestEntityFromChunk_ScoreFallback(t *testing.T) { chunk := map[string]interface{}{"score": 0.75} - e := kgEntityFromChunk("Test", chunk) + e := entityFromChunk("Test", chunk) if e.Similarity != 0.75 { t.Errorf("expected Sim=0.75 from score field, got %f", e.Similarity) } } -func TestKgEntityFromChunk_MissingFields(t *testing.T) { +func TestEntityFromChunk_MissingFields(t *testing.T) { chunk := map[string]interface{}{} - e := kgEntityFromChunk("Empty", chunk) + e := entityFromChunk("Empty", chunk) if e.Similarity != 0 || e.PageRank != 0 || len(e.NhopEnts) != 0 { t.Errorf("expected zero defaults, got %+v", e) } } -// --- kgRelationFromChunk --- +// --- relationFromChunk --- -func TestKgRelationFromChunk_Basic(t *testing.T) { +func TestRelationFromChunk_Basic(t *testing.T) { chunk := map[string]interface{}{ "from_entity_kwd": "Elon Musk", "to_entity_kwd": "SpaceX", "weight_int": float64(5), "content_with_weight": "Founder", } - edge, rel := kgRelationFromChunk(chunk) + edge, rel := relationFromChunk(chunk) if edge.From != "Elon Musk" || edge.To != "SpaceX" { t.Errorf("expected Elon Musk→SpaceX, got %v", edge) } @@ -114,17 +114,17 @@ func TestKgRelationFromChunk_Basic(t *testing.T) { } } -func TestKgRelationFromChunk_MissingFrom(t *testing.T) { +func TestRelationFromChunk_MissingFrom(t *testing.T) { chunk := map[string]interface{}{"to_entity_kwd": "B"} - edge, _ := kgRelationFromChunk(chunk) + edge, _ := relationFromChunk(chunk) if edge.From != "" { t.Error("expected empty from") } } -// --- searchKGTypeSamples --- +// --- searchTypeSamples --- -func TestSearchKGTypeSamples_Success(t *testing.T) { +func TestSearchTypeSamples_Success(t *testing.T) { data, _ := json.Marshal(map[string][]string{"PERSON": {"Elon Musk"}}) mock := &mockRetrievalEngine{ results: map[string]*types.SearchResult{ @@ -133,7 +133,7 @@ func TestSearchKGTypeSamples_Success(t *testing.T) { }}, }, } - result, err := searchKGTypeSamples(context.Background(), mock, []string{"ragflow_tenant1"}, []string{"kb1"}) + result, err := searchTypeSamples(context.Background(), mock, []string{"ragflow_tenant1"}, []string{"kb1"}) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -142,9 +142,9 @@ func TestSearchKGTypeSamples_Success(t *testing.T) { } } -func TestSearchKGTypeSamples_Empty(t *testing.T) { +func TestSearchTypeSamples_Empty(t *testing.T) { mock := &mockRetrievalEngine{} - result, err := searchKGTypeSamples(context.Background(), mock, []string{"ragflow_tenant1"}, []string{"kb1"}) + result, err := searchTypeSamples(context.Background(), mock, []string{"ragflow_tenant1"}, []string{"kb1"}) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -153,9 +153,9 @@ func TestSearchKGTypeSamples_Empty(t *testing.T) { } } -// --- KGSearchRetrieval --- +// --- Retrieval --- -func TestKGSearchRetrieval_Basic(t *testing.T) { +func TestRetrieval_Basic(t *testing.T) { mock := &mockRetrievalEngine{ results: map[string]*types.SearchResult{ "entity": {Chunks: []map[string]interface{}{ @@ -172,9 +172,9 @@ func TestKGSearchRetrieval_Basic(t *testing.T) { }}, }, } - result, err := KGSearchRetrieval(context.Background(), mock, nil, nil, []string{"kb1"}, []string{"tenant1"}, "Elon Musk") + result, err := Retrieval(context.Background(), mock, nil, nil, []string{"kb1"}, []string{"tenant1"}, "Elon Musk") if err != nil { - t.Fatalf("KGSearchRetrieval failed: %v", err) + t.Fatalf("Retrieval failed: %v", err) } if result == nil { t.Fatal("expected non-nil result") @@ -194,11 +194,11 @@ func TestKGSearchRetrieval_Basic(t *testing.T) { } } -func TestKGSearchRetrieval_NoEntities(t *testing.T) { +func TestRetrieval_NoEntities(t *testing.T) { mock := &mockRetrievalEngine{} - result, err := KGSearchRetrieval(context.Background(), mock, nil, nil, []string{"kb1"}, []string{"tenant1"}, "test") + result, err := Retrieval(context.Background(), mock, nil, nil, []string{"kb1"}, []string{"tenant1"}, "test") if err != nil { - t.Fatalf("KGSearchRetrieval failed: %v", err) + t.Fatalf("Retrieval failed: %v", err) } if result == nil { t.Fatal("expected non-nil result") @@ -211,7 +211,7 @@ func TestKGSearchRetrieval_NoEntities(t *testing.T) { // TestEntitySearch_MultiEntities verifies that all entities are used in search query. -func TestKGSearchRetrieval_WithChatModel(t *testing.T) { +func TestRetrieval_WithChatModel(t *testing.T) { mock := &mockRetrievalEngine{ results: map[string]*types.SearchResult{ "entity": {Chunks: []map[string]interface{}{ @@ -225,9 +225,9 @@ func TestKGSearchRetrieval_WithChatModel(t *testing.T) { // chatModel with nil ModelName so queryRewrite falls back to raw question, // but the ty2entsJSON construction path is still exercised. chatModel := &modelModule.ChatModel{ModelName: nil, APIConfig: nil} - result, err := KGSearchRetrieval(context.Background(), mock, chatModel, nil, []string{"kb1"}, []string{"tenant1"}, "Elon Musk") + result, err := Retrieval(context.Background(), mock, chatModel, nil, []string{"kb1"}, []string{"tenant1"}, "Elon Musk") if err != nil { - t.Fatalf("KGSearchRetrieval failed: %v", err) + t.Fatalf("Retrieval failed: %v", err) } if result == nil { t.Fatal("expected non-nil result") @@ -415,7 +415,7 @@ func TestBuildSearchExprs_WithEmbModel(t *testing.T) { MatchingText: "Elon Musk SpaceX", TopN: 50, } - exprs := buildSearchExprs(embModel, matchText, defaultKGSimThreshold, defaultKGDenseTopK) + exprs := buildSearchExprs(embModel, matchText, defaultSimThreshold, defaultDenseTopK) // Verify Embed was called with matchText.MatchingText, not raw question if len(driver.capturedTexts) != 1 || driver.capturedTexts[0] != "Elon Musk SpaceX" { t.Errorf("expected Embed to receive %q, got %v", "Elon Musk SpaceX", driver.capturedTexts) @@ -439,11 +439,11 @@ func TestBuildSearchExprs_WithEmbModel(t *testing.T) { if md.VectorColumnName != "q_3_vec" { t.Errorf("expected q_3_vec, got %q", md.VectorColumnName) } - if md.TopN != defaultKGDenseTopK { - t.Errorf("expected TopN=%d (Python alignment), got %d", defaultKGDenseTopK, md.TopN) + if md.TopN != defaultDenseTopK { + t.Errorf("expected TopN=%d (Python alignment), got %d", defaultDenseTopK, md.TopN) } - if md.ExtraOptions["similarity"] != defaultKGSimThreshold { - t.Errorf("expected similarity=%v (Python alignment), got %v", defaultKGSimThreshold, md.ExtraOptions["similarity"]) + if md.ExtraOptions["similarity"] != defaultSimThreshold { + t.Errorf("expected similarity=%v (Python alignment), got %v", defaultSimThreshold, md.ExtraOptions["similarity"]) } // Index 2: FusionExpr fu, ok := exprs[2].(*types.FusionExpr) @@ -463,7 +463,7 @@ func TestBuildSearchExprs_EmbModelFallback(t *testing.T) { MatchingText: "fallback test", TopN: 10, } - exprs := buildSearchExprs(embModel, matchText, defaultKGSimThreshold, defaultKGDenseTopK) + exprs := buildSearchExprs(embModel, matchText, defaultSimThreshold, defaultDenseTopK) // Should fall back to text-only when Embed fails if len(exprs) != 1 { t.Fatalf("expected 1 expr (text-only fallback), got %d", len(exprs)) @@ -476,11 +476,11 @@ func TestBuildSearchExprs_EmbModelFallback(t *testing.T) { // --- Python alignment defaults --- func TestDefaultValuesMatchPython(t *testing.T) { - if defaultKGSimThreshold != 0.3 { - t.Errorf("expected 0.3 (Python ent_sim_threshold), got %f", defaultKGSimThreshold) + if defaultSimThreshold != 0.3 { + t.Errorf("expected 0.3 (Python ent_sim_threshold), got %f", defaultSimThreshold) } - if defaultKGDenseTopK != 1024 { - t.Errorf("expected 1024 (Python get_vector topk), got %d", defaultKGDenseTopK) + if defaultDenseTopK != 1024 { + t.Errorf("expected 1024 (Python get_vector topk), got %d", defaultDenseTopK) } } @@ -506,11 +506,11 @@ func TestIndexName_Empty(t *testing.T) { } } -// --- searchKGCommunityContent --- +// --- searchCommunityContent --- func TestSearchKGCommunityContent_EmptyEntities(t *testing.T) { mock := &mockRetrievalEngine{} - result := searchKGCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, nil, 1, intPtr(100)) + result := searchCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, nil, 1, intPtr(100)) if result != "" { t.Errorf("expected empty, got %q", result) } @@ -527,7 +527,7 @@ func TestSearchKGCommunityContent_WithContent(t *testing.T) { }}, }, } - result := searchKGCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, []ScoredEntity{{Entity: "E1"}}, 1, intPtr(500)) + result := searchCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, []ScoredEntity{{Entity: "E1"}}, 1, intPtr(500)) if result == "" { t.Fatal("expected non-empty result") } @@ -547,7 +547,7 @@ func TestSearchKGCommunityContent_WithContent(t *testing.T) { func TestSearchKGCommunityContent_NilMaxToken(t *testing.T) { mock := &mockRetrievalEngine{} - result := searchKGCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, []ScoredEntity{{Entity: "E1"}}, 1, nil) + result := searchCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, []ScoredEntity{{Entity: "E1"}}, 1, nil) if result != "" { t.Errorf("expected empty when maxToken is nil, got %q", result) } @@ -555,7 +555,7 @@ func TestSearchKGCommunityContent_NilMaxToken(t *testing.T) { func TestSearchKGCommunityContent_ZeroMaxToken(t *testing.T) { mock := &mockRetrievalEngine{} - result := searchKGCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, []ScoredEntity{{Entity: "E1"}}, 1, intPtr(0)) + result := searchCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, []ScoredEntity{{Entity: "E1"}}, 1, intPtr(0)) if result != "" { t.Errorf("expected empty when maxToken=0, got %q", result) } diff --git a/internal/service/kg_scoring_funcs.go b/internal/service/kg/scoring.go similarity index 98% rename from internal/service/kg_scoring_funcs.go rename to internal/service/kg/scoring.go index f739e62fe68..ff2fb334a61 100644 --- a/internal/service/kg_scoring_funcs.go +++ b/internal/service/kg/scoring.go @@ -14,7 +14,7 @@ // limitations under the License. // -package service +package kg import ( "bytes" @@ -227,9 +227,9 @@ func FormatRelationsToCSV(relations []ScoredRelation, maxToken int) (csv string, return b.String(), maxToken } -// BuildKGContent assembles the final knowledge graph content string. +// BuildContent assembles the final knowledge graph content string. // Python equivalent: lines 267-291 -func BuildKGContent( +func BuildContent( entities []ScoredEntity, relations []ScoredRelation, maxToken int, diff --git a/internal/service/kg/search.go b/internal/service/kg/search.go new file mode 100644 index 00000000000..29fc057e548 --- /dev/null +++ b/internal/service/kg/search.go @@ -0,0 +1,306 @@ +package kg + +import ( + "context" + "fmt" + + "ragflow/internal/engine" + "encoding/json" + "ragflow/internal/engine/types" + modelModule "ragflow/internal/entity/models" +) + +// NhopEntityNames extracts unique entity names from an n_hop_with_weight JSON string. +func NhopEntityNames(nHopJSON string) []string { + if nHopJSON == "" { + return nil + } + var nhopData []struct { + Path []string `json:"path"` + Weights []float64 `json:"weights"` + } + if err := json.Unmarshal([]byte(nHopJSON), &nhopData); err != nil { + return nil + } + seen := make(map[string]struct{}) + for _, item := range nhopData { + for _, name := range item.Path { + seen[name] = struct{}{} + } + } + result := make([]string, 0, len(seen)) + for name := range seen { + result = append(result, name) + } + return result +} + +// SearchEntities searches for KG entities matching a question. +func SearchEntities(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, question string, embModel *modelModule.EmbeddingModel, topN int) ([]KGEntity, error) { + dense, err := buildDenseExpr(embModel, question, topN) + if err != nil { + return nil, err + } + searchReq := buildEntitySearchRequest(kbIDs, question, dense, topN) + result, err := docEngine.Search(ctx, searchReq) + if err != nil { + return nil, fmt.Errorf("KG entity search failed: %w", err) + } + return ParseEntityChunks(result.Chunks), nil +} + +// SearchEntitiesByTypes searches for KG entities by type keywords. +func SearchEntitiesByTypes(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, typeKeywords []string, topN int) ([]KGEntity, error) { + searchReq := buildEntityTypeSearchRequest(kbIDs, typeKeywords, topN) + result, err := docEngine.Search(ctx, searchReq) + if err != nil { + return nil, fmt.Errorf("KG entity type search failed: %w", err) + } + return ParseEntityChunks(result.Chunks), nil +} + +// SearchRelations searches for KG relations matching a question. +func SearchRelations(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, question string, embModel *modelModule.EmbeddingModel, topN int) ([]KGRelation, error) { + dense, err := buildDenseExpr(embModel, question, topN) + if err != nil { + return nil, err + } + searchReq := buildRelationSearchRequest(kbIDs, question, dense, topN) + result, err := docEngine.Search(ctx, searchReq) + if err != nil { + return nil, fmt.Errorf("KG relation search failed: %w", err) + } + return ParseRelationChunks(result.Chunks), nil +} + +// SearchCommunityReports searches for community reports related to given entities. +func SearchCommunityReports(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, entityNames []string, topN int) ([]KGCommunityReport, error) { + searchReq := buildCommunitySearchRequest(kbIDs, entityNames, topN) + result, err := docEngine.Search(ctx, searchReq) + if err != nil { + return nil, fmt.Errorf("KG community search failed: %w", err) + } + return ParseCommunityReportChunks(result.Chunks), nil +} + +// SearchTypeSamples retrieves the typeu2192entities mapping from ES. +func SearchTypeSamples(ctx context.Context, docEngine engine.DocEngine, kbIDs []string) (map[string][]string, error) { + searchReq := buildTypeSamplesSearchRequest(kbIDs) + result, err := docEngine.Search(ctx, searchReq) + if err != nil { + return nil, err + } + return ParseTypeSamplesChunks(result.Chunks), nil +} + +// buildDenseExpr computes the query vector and returns a MatchDenseExpr. +func buildDenseExpr(embModel *modelModule.EmbeddingModel, question string, topN int) (*types.MatchDenseExpr, error) { + if embModel == nil || question == "" { + return nil, nil + } + embCfg := &modelModule.EmbeddingConfig{Dimension: 0} + embeddings, err := embModel.ModelDriver.Embed(embModel.ModelName, []string{question}, embModel.APIConfig, embCfg) + if err != nil { + return nil, fmt.Errorf("KG entity embed failed: %w", err) + } + if len(embeddings) == 0 || len(embeddings[0].Embedding) == 0 { + return nil, nil + } + vector := embeddings[0].Embedding + return &types.MatchDenseExpr{ + VectorColumnName: fmt.Sprintf("q_%d_vec", len(vector)), + EmbeddingData: vector, + EmbeddingDataType: "float", + DistanceType: "cosine", + TopN: topN, + ExtraOptions: map[string]interface{}{"similarity": 0.3}, + }, nil +} + +// buildHybridExpr returns MatchExprs for hybrid search (dense + text + fusion). +func buildHybridExpr(dense *types.MatchDenseExpr, text *types.MatchTextExpr, topN int) []interface{} { + if dense == nil { + return []interface{}{text} + } + fusion := buildFusionExpr(defaultTextWeight, defaultVectorWeight, topN) + return []interface{}{dense, text, fusion} +} + +// buildEntitySearchRequest constructs a SearchRequest for KG entities. +func buildEntitySearchRequest(kbIDs []string, question string, dense *types.MatchDenseExpr, topN int) *types.SearchRequest { + req := &types.SearchRequest{ + KbIDs: kbIDs, + SelectFields: []string{"entity_kwd", "entity_type_kwd", "rank_flt", "content_with_weight", "n_hop_with_weight", "_score"}, + Limit: topN, + Filter: map[string]interface{}{"knowledge_graph_kwd": "entity"}, + } + if question != "" { + textExpr := &types.MatchTextExpr{ + Fields: []string{"entity_kwd^10", "content_ltks^2"}, + MatchingText: question, + TopN: topN, + } + req.MatchExprs = buildHybridExpr(dense, textExpr, topN) + } + return req +} + +// buildEntityTypeSearchRequest constructs a SearchRequest for KG entities by type. +func buildEntityTypeSearchRequest(kbIDs []string, typeKeywords []string, topN int) *types.SearchRequest { + req := &types.SearchRequest{ + KbIDs: kbIDs, + SelectFields: []string{"entity_kwd", "entity_type_kwd"}, + Limit: topN, + Filter: map[string]interface{}{"knowledge_graph_kwd": "entity"}, + } + if len(typeKeywords) > 0 { + filters := make([]interface{}, len(typeKeywords)) + for i, t := range typeKeywords { + filters[i] = t + } + req.Filter["entity_type_kwd"] = filters + } + return req +} + +// buildRelationSearchRequest constructs a SearchRequest for KG relations. +func buildRelationSearchRequest(kbIDs []string, question string, dense *types.MatchDenseExpr, topN int) *types.SearchRequest { + req := &types.SearchRequest{ + KbIDs: kbIDs, + SelectFields: []string{"from_entity_kwd", "to_entity_kwd", "weight_int", "content_with_weight", "_score"}, + Limit: topN, + Filter: map[string]interface{}{"knowledge_graph_kwd": "relation"}, + } + if question != "" { + textExpr := &types.MatchTextExpr{ + Fields: []string{"content_ltks", "from_entity_kwd", "to_entity_kwd"}, + MatchingText: question, + TopN: topN, + } + req.MatchExprs = buildHybridExpr(dense, textExpr, topN) + } + return req +} + +// buildCommunitySearchRequest constructs a SearchRequest for KG community reports. +func buildCommunitySearchRequest(kbIDs []string, entityNames []string, topN int) *types.SearchRequest { + req := &types.SearchRequest{ + KbIDs: kbIDs, + SelectFields: []string{"docnm_kwd", "content_with_weight", "weight_flt", "entities_kwd"}, + Limit: topN, + Filter: map[string]interface{}{"knowledge_graph_kwd": "community_report"}, + OrderBy: (&types.OrderByExpr{}).Desc("weight_flt"), + } + if len(entityNames) > 0 { + filters := make([]interface{}, len(entityNames)) + for i, name := range entityNames { + filters[i] = name + } + req.Filter["entities_kwd"] = filters + } + return req +} + +// buildTypeSamplesSearchRequest constructs a SearchRequest for type samples. +func buildTypeSamplesSearchRequest(kbIDs []string) *types.SearchRequest { + return &types.SearchRequest{ + KbIDs: kbIDs, + SelectFields: []string{"content_with_weight"}, + Limit: 10000, + Filter: map[string]interface{}{"knowledge_graph_kwd": "ty2ents"}, + } +} + +// ParseEntityChunks converts raw search result chunks into KGEntity slices. +func ParseEntityChunks(chunks []map[string]interface{}) []KGEntity { + var entities []KGEntity + for _, chunk := range chunks { + name, _ := chunk["entity_kwd"].(string) + if name == "" { + // Try extracting from list + if list, ok := chunk["entity_kwd"].([]interface{}); ok && len(list) > 0 { + name, _ = list[0].(string) + } + } + if name == "" { + continue + } + typ, _ := chunk["entity_type_kwd"].(string) + e := KGEntity{Name: name, Type: typ} + if v, ok := chunk["rank_flt"].(float64); ok { + e.PageRank = v + } + if v, ok := chunk["_score"].(float64); ok { + e.Similarity = v + } else if v, ok := chunk["score"].(float64); ok { + e.Similarity = v + } + e.Description, _ = chunk["content_with_weight"].(string) + entities = append(entities, e) + } + return entities +} + +// ParseRelationChunks converts raw search result chunks into KGRelation slices. +func ParseRelationChunks(chunks []map[string]interface{}) []KGRelation { + var relations []KGRelation + for _, chunk := range chunks { + from, _ := chunk["from_entity_kwd"].(string) + to, _ := chunk["to_entity_kwd"].(string) + if from == "" || to == "" { + continue + } + r := KGRelation{From: from, To: to} + if v, ok := chunk["_score"].(float64); ok { + r.Sim = v + } else if v, ok := chunk["score"].(float64); ok { + r.Sim = v + } + if v, ok := chunk["weight_int"].(float64); ok { + r.PageRank = v + } else if v, ok := chunk["weight_int"].(int); ok { + r.PageRank = float64(v) + } + r.Description, _ = chunk["content_with_weight"].(string) + relations = append(relations, r) + } + return relations +} + +// ParseCommunityReportChunks converts raw search result chunks into KGCommunityReport slices. +func ParseCommunityReportChunks(chunks []map[string]interface{}) []KGCommunityReport { + var reports []KGCommunityReport + for _, chunk := range chunks { + title, _ := chunk["docnm_kwd"].(string) + content, _ := chunk["content_with_weight"].(string) + if title == "" && content == "" { + continue + } + r := KGCommunityReport{Title: title, Content: content} + if v, ok := chunk["weight_flt"].(float64); ok { + r.Weight = v + } + r.Entities, _ = chunk["entities_kwd"].(string) + reports = append(reports, r) + } + return reports +} + +// ParseTypeSamplesChunks converts raw search result chunks into a typeu2192entities map. +func ParseTypeSamplesChunks(chunks []map[string]interface{}) map[string][]string { + typeMap := make(map[string][]string) + for _, chunk := range chunks { + content, ok := chunk["content_with_weight"].(string) + if !ok || content == "" { + continue + } + var parsed map[string][]string + if err := json.Unmarshal([]byte(content), &parsed); err != nil { + continue + } + for typ, entities := range parsed { + typeMap[typ] = append(typeMap[typ], entities...) + } + } + return typeMap +} diff --git a/internal/service/kg_search_test.go b/internal/service/kg/search_test.go similarity index 81% rename from internal/service/kg_search_test.go rename to internal/service/kg/search_test.go index 79801b0df35..fd5045cfd01 100644 --- a/internal/service/kg_search_test.go +++ b/internal/service/kg/search_test.go @@ -14,7 +14,7 @@ // limitations under the License. // -package service +package kg import ( "context" @@ -159,13 +159,13 @@ func TestBuildTypeSamplesSearchRequest(t *testing.T) { } } -// --- ParseKGEntityChunks --- +// --- ParseEntityChunks --- -func TestParseKGEntityChunks_Basic(t *testing.T) { +func TestParseEntityChunks_Basic(t *testing.T) { chunks := []map[string]interface{}{ {"entity_kwd": "Elon Musk", "entity_type_kwd": "PERSON", "rank_flt": 0.9, "_score": 0.85, "content_with_weight": "Founder of SpaceX"}, } - entities := ParseKGEntityChunks(chunks) + entities := ParseEntityChunks(chunks) if len(entities) != 1 { t.Fatalf("expected 1, got %d", len(entities)) } @@ -174,98 +174,98 @@ func TestParseKGEntityChunks_Basic(t *testing.T) { } } -func TestParseKGEntityChunks_List(t *testing.T) { +func TestParseEntityChunks_List(t *testing.T) { chunks := []map[string]interface{}{ {"entity_kwd": []interface{}{"Elon Musk", "elon_musk"}}, } - entities := ParseKGEntityChunks(chunks) + entities := ParseEntityChunks(chunks) if len(entities) != 1 || entities[0].Name != "Elon Musk" { t.Errorf("expected first list element, got %q", entities[0].Name) } } -func TestParseKGEntityChunks_EmptyName(t *testing.T) { +func TestParseEntityChunks_EmptyName(t *testing.T) { chunks := []map[string]interface{}{{"entity_type_kwd": "PERSON"}} - if len(ParseKGEntityChunks(chunks)) != 0 { + if len(ParseEntityChunks(chunks)) != 0 { t.Error("expected 0 for missing name") } } -func TestParseKGEntityChunks_ScoreFallback(t *testing.T) { +func TestParseEntityChunks_ScoreFallback(t *testing.T) { chunks := []map[string]interface{}{{"entity_kwd": "Test", "score": 0.75}} - if ParseKGEntityChunks(chunks)[0].Similarity != 0.75 { + if ParseEntityChunks(chunks)[0].Similarity != 0.75 { t.Error("expected 0.75 from score field") } } -func TestParseKGEntityChunks_NilInput(t *testing.T) { - if len(ParseKGEntityChunks(nil)) != 0 { +func TestParseEntityChunks_NilInput(t *testing.T) { + if len(ParseEntityChunks(nil)) != 0 { t.Error("expected 0 for nil input") } } -// --- ParseKGRelationChunks --- +// --- ParseRelationChunks --- -func TestParseKGRelationChunks_Basic(t *testing.T) { +func TestParseRelationChunks_Basic(t *testing.T) { chunks := []map[string]interface{}{ {"from_entity_kwd": "Elon Musk", "to_entity_kwd": "SpaceX", "weight_int": float64(5), "content_with_weight": "Founder"}, } - relations := ParseKGRelationChunks(chunks) - if len(relations) != 1 || relations[0].From != "Elon Musk" || relations[0].Weight != 5 { + relations := ParseRelationChunks(chunks) + if len(relations) != 1 || relations[0].From != "Elon Musk" || relations[0].PageRank != 5 { t.Errorf("unexpected: %+v", relations[0]) } } -func TestParseKGRelationChunks_IntWeight(t *testing.T) { +func TestParseRelationChunks_IntWeight(t *testing.T) { chunks := []map[string]interface{}{{"from_entity_kwd": "A", "to_entity_kwd": "B", "weight_int": 3}} - if ParseKGRelationChunks(chunks)[0].Weight != 3 { + if ParseRelationChunks(chunks)[0].PageRank != 3 { t.Error("expected weight 3") } } -func TestParseKGRelationChunks_EmptyFrom(t *testing.T) { - if len(ParseKGRelationChunks([]map[string]interface{}{{"to_entity_kwd": "B"}})) != 0 { +func TestParseRelationChunks_EmptyFrom(t *testing.T) { + if len(ParseRelationChunks([]map[string]interface{}{{"to_entity_kwd": "B"}})) != 0 { t.Error("expected 0 for missing from") } } -func TestParseKGRelationChunks_NilInput(t *testing.T) { - if len(ParseKGRelationChunks(nil)) != 0 { +func TestParseRelationChunks_NilInput(t *testing.T) { + if len(ParseRelationChunks(nil)) != 0 { t.Error("expected 0 for nil") } } -// --- ParseKGCommunityReportChunks --- +// --- ParseCommunityReportChunks --- -func TestParseKGCommunityReportChunks_Basic(t *testing.T) { +func TestParseCommunityReportChunks_Basic(t *testing.T) { chunks := []map[string]interface{}{ {"docnm_kwd": "Report 1", "content_with_weight": "content", "weight_flt": 0.95, "entities_kwd": "A, B"}, } - reports := ParseKGCommunityReportChunks(chunks) + reports := ParseCommunityReportChunks(chunks) if len(reports) != 1 || reports[0].Title != "Report 1" || reports[0].Weight != 0.95 { t.Errorf("unexpected: %+v", reports[0]) } } -func TestParseKGCommunityReportChunks_EmptyTitle(t *testing.T) { - if len(ParseKGCommunityReportChunks([]map[string]interface{}{{"weight_flt": 0.5}})) != 0 { +func TestParseCommunityReportChunks_EmptyTitle(t *testing.T) { + if len(ParseCommunityReportChunks([]map[string]interface{}{{"weight_flt": 0.5}})) != 0 { t.Error("expected 0 for empty title and content") } } -func TestParseKGCommunityReportChunks_NilInput(t *testing.T) { - if len(ParseKGCommunityReportChunks(nil)) != 0 { +func TestParseCommunityReportChunks_NilInput(t *testing.T) { + if len(ParseCommunityReportChunks(nil)) != 0 { t.Error("expected 0 for nil") } } -// --- ParseKGTypeSamplesChunks --- +// --- ParseTypeSamplesChunks --- -func TestParseKGTypeSamplesChunks_ValidJSON(t *testing.T) { +func TestParseTypeSamplesChunks_ValidJSON(t *testing.T) { chunks := []map[string]interface{}{ {"content_with_weight": `{"PERSON": ["Elon Musk", "Einstein"], "ORGANIZATION": ["SpaceX"]}`}, } - result := ParseKGTypeSamplesChunks(chunks) + result := ParseTypeSamplesChunks(chunks) if len(result) != 2 { t.Fatalf("expected 2 types, got %d: %v", len(result), result) } @@ -277,18 +277,18 @@ func TestParseKGTypeSamplesChunks_ValidJSON(t *testing.T) { } } -func TestParseKGTypeSamplesChunks_InvalidJSON(t *testing.T) { +func TestParseTypeSamplesChunks_InvalidJSON(t *testing.T) { chunks := []map[string]interface{}{ {"content_with_weight": "not json"}, } - result := ParseKGTypeSamplesChunks(chunks) + result := ParseTypeSamplesChunks(chunks) if len(result) != 0 { t.Error("expected empty for invalid JSON") } } -func TestParseKGTypeSamplesChunks_Empty(t *testing.T) { - result := ParseKGTypeSamplesChunks(nil) +func TestParseTypeSamplesChunks_Empty(t *testing.T) { + result := ParseTypeSamplesChunks(nil) if len(result) != 0 { t.Error("expected empty for nil") } @@ -356,9 +356,9 @@ func TestBuildKGDenseExpr_WithModel(t *testing.T) { }, APIConfig: &modelModule.APIConfig{}, } - dense, err := buildKGDenseExpr(embModel, "test question", 10) + dense, err := buildDenseExpr(embModel, "test question", 10) if err != nil { - t.Fatalf("buildKGDenseExpr failed: %v", err) + t.Fatalf("buildDenseExpr failed: %v", err) } if dense == nil { t.Fatal("expected non-nil MatchDenseExpr") @@ -372,14 +372,14 @@ func TestBuildKGDenseExpr_WithModel(t *testing.T) { } func TestBuildKGDenseExpr_NilModel(t *testing.T) { - dense, err := buildKGDenseExpr(nil, "test", 10) + dense, err := buildDenseExpr(nil, "test", 10) if dense != nil || err != nil { t.Errorf("expected nil,nil for nil model, got dense=%v err=%v", dense, err) } } func TestBuildKGDenseExpr_EmptyQuestion(t *testing.T) { - dense, err := buildKGDenseExpr(&modelModule.EmbeddingModel{}, "", 10) + dense, err := buildDenseExpr(&modelModule.EmbeddingModel{}, "", 10) if dense != nil || err != nil { t.Errorf("expected nil,nil for empty question, got dense=%v err=%v", dense, err) } @@ -387,7 +387,7 @@ func TestBuildKGDenseExpr_EmptyQuestion(t *testing.T) { // --- Search integration with mock --- -func TestSearchKGEntities_WithMock(t *testing.T) { +func TestSearchEntities_WithMock(t *testing.T) { mock := &mockKGEngine{ searchFunc: func(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) { if req.Filter["knowledge_graph_kwd"] != "entity" { @@ -400,16 +400,16 @@ func TestSearchKGEntities_WithMock(t *testing.T) { }, nil }, } - entities, err := SearchKGEntities(context.Background(), mock, []string{"kb1"}, "Elon", nil, 10) + entities, err := SearchEntities(context.Background(), mock, []string{"kb1"}, "Elon", nil, 10) if err != nil { - t.Fatalf("SearchKGEntities failed: %v", err) + t.Fatalf("SearchEntities failed: %v", err) } if len(entities) != 1 || entities[0].Name != "Elon Musk" { t.Errorf("expected [Elon Musk], got %v", entities) } } -func TestSearchKGEntitiesByTypes_WithMock(t *testing.T) { +func TestSearchEntitiesByTypes_WithMock(t *testing.T) { mock := &mockKGEngine{ searchFunc: func(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) { return &types.SearchResult{ @@ -419,20 +419,20 @@ func TestSearchKGEntitiesByTypes_WithMock(t *testing.T) { }, nil }, } - entities, err := SearchKGEntitiesByTypes(context.Background(), mock, []string{"kb1"}, []string{"ORGANIZATION"}, 10) + entities, err := SearchEntitiesByTypes(context.Background(), mock, []string{"kb1"}, []string{"ORGANIZATION"}, 10) if err != nil { - t.Fatalf("SearchKGEntitiesByTypes failed: %v", err) + t.Fatalf("SearchEntitiesByTypes failed: %v", err) } if len(entities) != 1 || entities[0].Type != "ORGANIZATION" { t.Errorf("expected ORGANIZATION, got %v", entities) } } -func TestSearchKGTypeSamples_WithMock(t *testing.T) { +func TestSearchTypeSamples_WithMock(t *testing.T) { mock := &mockKGEngine{} - samples, err := SearchKGTypeSamples(context.Background(), mock, []string{"kb1"}) + samples, err := SearchTypeSamples(context.Background(), mock, []string{"kb1"}) if err != nil { - t.Fatalf("SearchKGTypeSamples failed: %v", err) + t.Fatalf("SearchTypeSamples failed: %v", err) } if samples == nil { samples = map[string][]string{} diff --git a/internal/service/kg/testutil_test.go b/internal/service/kg/testutil_test.go new file mode 100644 index 00000000000..015e1d41983 --- /dev/null +++ b/internal/service/kg/testutil_test.go @@ -0,0 +1,3 @@ +package kg + +func strPtr(s string) *string { return &s } diff --git a/internal/service/kg/types.go b/internal/service/kg/types.go new file mode 100644 index 00000000000..2c502c3e090 --- /dev/null +++ b/internal/service/kg/types.go @@ -0,0 +1,60 @@ +package kg + +// KGEntity represents a knowledge graph entity. +type KGEntity struct { + Name string // entity_kwd + Type string // entity_type_kwd + PageRank float64 // rank_flt + Similarity float64 // _score + Description string // content_with_weight + NhopEnts []NhopEntity // n_hop_with_weight (parsed JSON) +} + +// NhopEntity represents an N-hop neighbor path. +type NhopEntity struct { + Path []string // entity names along the path + Weights []float64 // pagerank weights per hop +} + +// KGRelation represents a relation between two entities. +type KGRelation struct { + From string // from_entity_kwd + To string // to_entity_kwd + Description string // content_with_weight + Sim float64 // score accumulated during pipeline scoring + PageRank float64 // rank_flt or weight_int as float64 +} + +// Edge represents a directed (from_entity, to_entity) pair. +type Edge struct { + From, To string +} + +// EdgeScore represents the accumulated score for an edge from N-hop analysis. +type EdgeScore struct { + Sim float64 + PageRank float64 +} + +// ScoredEntity is a scored entity ready for output. +type ScoredEntity struct { + Entity string + Score float64 + Description string +} + +// ScoredRelation is a scored relation ready for output. +type ScoredRelation struct { + From string + To string + Score float64 + Description string +} + +// KGCommunityReport represents a community report. +type KGCommunityReport struct { + Title string // docnm_kwd + Content string // content_with_weight + Weight float64 // weight_flt + Entities string // entities_kwd +} diff --git a/internal/service/kg_scoring_funcs_test.go b/internal/service/kg_scoring_funcs_test.go deleted file mode 100644 index 62ed85d6b0e..00000000000 --- a/internal/service/kg_scoring_funcs_test.go +++ /dev/null @@ -1,368 +0,0 @@ -// -// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package service - -import ( - "strings" - "testing" -) - -// --- AnalyzeNHopPaths --- - -func TestAnalyzeNHopPaths_Basic(t *testing.T) { - ents := map[string]*KGEntity{ - "A": { -Similarity: 0.9, - NhopEnts: []NhopEntity{ - {Path: []string{"A", "B", "C"}, Weights: []float64{0.8, 0.5}}, - }, - }, - } - result := AnalyzeNHopPaths(ents) - // A→B: 0.9 / (2+0) = 0.45 - // B→C: 0.9 / (2+1) = 0.3 - if len(result) != 2 { - t.Fatalf("expected 2 edges, got %d", len(result)) - } - if result[Edge{"A", "B"}].Sim != 0.45 { - t.Errorf("expected A→B sim=0.45, got %f", result[Edge{"A", "B"}].Sim) - } - if result[Edge{"B", "C"}].Sim != 0.3 { - t.Errorf("expected B→C sim=0.3, got %f", result[Edge{"B", "C"}].Sim) - } -} - -func TestAnalyzeNHopPaths_MultipleContributors(t *testing.T) { - ents := map[string]*KGEntity{ - "A": { -Similarity: 0.8, - NhopEnts: []NhopEntity{ - {Path: []string{"A", "B"}, Weights: []float64{0.7}}, - }, - }, - "X": { -Similarity: 0.6, - NhopEnts: []NhopEntity{ - {Path: []string{"X", "B"}, Weights: []float64{0.5}}, - }, - }, - } - result := AnalyzeNHopPaths(ents) - // A→B: 0.8 / 2 = 0.4 - // X→B: 0.6 / 2 = 0.3 - if result[Edge{"A", "B"}].Sim != 0.4 { - t.Errorf("expected A→B sim=0.4, got %f", result[Edge{"A", "B"}].Sim) - } - if result[Edge{"X", "B"}].Sim != 0.3 { - t.Errorf("expected X→B sim=0.3, got %f", result[Edge{"X", "B"}].Sim) - } -} - -func TestAnalyzeNHopPaths_Empty(t *testing.T) { - result := AnalyzeNHopPaths(nil) - if len(result) != 0 { - t.Errorf("expected empty, got %d", len(result)) - } -} - -// --- DoubleHitBoost --- - -func TestDoubleHitBoost(t *testing.T) { - ents := map[string]*KGEntity{ - "A": {Similarity: 0.5}, - "B": {Similarity: 0.3}, - } - types := map[string]struct{}{"A": {}} - DoubleHitBoost(ents, types) - if ents["A"].Similarity != 1.0 { - t.Errorf("expected A sim=1.0 after boost, got %f", ents["A"].Similarity) - } - if ents["B"].Similarity != 0.3 { - t.Errorf("expected B sim unchanged at 0.3, got %f", ents["B"].Similarity) - } -} - -func TestDoubleHitBoost_Empty(t *testing.T) { - ents := map[string]*KGEntity{"A": {Similarity: 0.5}} - DoubleHitBoost(ents, map[string]struct{}{}) - if ents["A"].Similarity != 0.5 { - t.Errorf("expected unchanged, got %f", ents["A"].Similarity) - } -} - -// --- FuseRelationScores --- - -func TestFuseRelationScores_NhopContribution(t *testing.T) { - rels := map[Edge]*KGRelation{ - {"A", "B"}: {Sim: 0.5, PageRank: 0.8}, - } - types := map[string]struct{}{} - nhop := map[Edge]EdgeScore{ - {"A", "B"}: {Sim: 0.3}, - } - FuseRelationScores(rels, types, nhop) - // sim = 0.5 * (0.3 + 1) = 0.65 - if rels[Edge{"A", "B"}].Sim != 0.65 { - t.Errorf("expected 0.65, got %f", rels[Edge{"A", "B"}].Sim) - } -} - -func TestFuseRelationScores_TypeBoost(t *testing.T) { - rels := map[Edge]*KGRelation{ - {"A", "B"}: {Sim: 0.5}, - } - types := map[string]struct{}{"A": {}, "B": {}} - nhop := map[Edge]EdgeScore{} - FuseRelationScores(rels, types, nhop) - // Both endpoints in types: s=2, sim = 0.5 * (2+1) = 1.5 - if rels[Edge{"A", "B"}].Sim != 1.5 { - t.Errorf("expected 1.5, got %f", rels[Edge{"A", "B"}].Sim) - } -} - -func TestFuseRelationScores_NhopNewEdge(t *testing.T) { - rels := map[Edge]*KGRelation{} - types := map[string]struct{}{} - nhop := map[Edge]EdgeScore{ - {"A", "B"}: {Sim: 0.4, PageRank: 0.7}, - } - FuseRelationScores(rels, types, nhop) - if _, ok := rels[Edge{"A", "B"}]; !ok { - t.Fatal("expected new edge from N-hop") - } - if rels[Edge{"A", "B"}].Sim != 0.4 { - t.Errorf("expected sim=0.4, got %f", rels[Edge{"A", "B"}].Sim) - } -} - -// --- SortAndTrim --- - -func TestSortAndTrimEntities(t *testing.T) { - ents := map[string]*KGEntity{ - "A": {Similarity: 0.5, PageRank: 0.9}, - "B": {Similarity: 0.8, PageRank: 0.3}, - "C": {Similarity: 0.9, PageRank: 0.1}, - } - result := SortAndTrimEntities(ents, 2) - if len(result) != 2 { - t.Fatalf("expected 2, got %d", len(result)) - } - // A: 0.45, B: 0.24, C: 0.09 → top 2 should be A, B - if result[0].Entity != "A" { - t.Errorf("expected A first (0.45), got %s (%f)", result[0].Entity, result[0].Score) - } -} - -func TestSortAndTrimEntities_DefaultTopN(t *testing.T) { - ents := map[string]*KGEntity{ - "A": {Similarity: 0.5, PageRank: 0.9}, - "B": {Similarity: 0.8, PageRank: 0.3}, - } - result := SortAndTrimEntities(ents, 0) - if len(result) != 2 { - t.Errorf("expected default topN to include all, got %d", len(result)) - } -} - -func TestSortAndTrimRelations(t *testing.T) { - rels := map[Edge]*KGRelation{ - {"A", "B"}: {Sim: 0.9, PageRank: 0.1}, - {"C", "D"}: {Sim: 0.3, PageRank: 0.8}, - } - result := SortAndTrimRelations(rels, 1) - if len(result) != 1 { - t.Fatalf("expected 1, got %d", len(result)) - } - // A→B: 0.09, C→D: 0.24 → C→D should be first - if result[0].From != "C" { - t.Errorf("expected C first (0.24), got %s (%f)", result[0].From, result[0].Score) - } -} - -// --- Format and Build --- - -func TestBuildKGContent_Basic(t *testing.T) { - entities := []ScoredEntity{ - {Entity: "A", Score: 0.45, Description: `{"description": "Entity A desc"}`}, - } - relations := []ScoredRelation{ - {From: "A", To: "B", Score: 0.3, Description: `{"description": "rel A-B"}`}, - } - result := BuildKGContent(entities, relations, 10000) - if !strings.Contains(result, "Entity A desc") { - t.Errorf("expected entity description in output, got: %s", result) - } - if !strings.Contains(result, "rel A-B") { - t.Errorf("expected relation description in output, got: %s", result) - } -} - -func TestBuildKGContent_TokenBudget(t *testing.T) { - longDesc := strings.Repeat("This is a very long description. ", 50) - entities := []ScoredEntity{ - {Entity: "LongEntityName", Score: 1.0, Description: longDesc}, - } - relations := []ScoredRelation{ - {From: "X", To: "Y", Score: 1.0, Description: "relation desc"}, - } - result := BuildKGContent(entities, relations, 50) - // Token budget is very small, should truncate and not include relations - if strings.Contains(result, "relation desc") { - t.Log("Note: relations included despite small budget (depending on token count)") - } -} - -func TestFormatEntitiesToCSV_HeaderExceedsBudget(t *testing.T) { - entities := []ScoredEntity{ - {Entity: "A", Score: 1.0, Description: "d"}, - } - result, remaining := FormatEntitiesToCSV(entities, 3) - tokens := NumTokensFromString(result) - // Header lines (---- Entities ----\n, Entity,Score,Description\n) are written - // before the token budget check. They consume ~11 tokens but are not deducted - // from maxToken. This is a known limitation shared with Python. - if tokens > 3 { - t.Logf("output %d tokens exceeds budget of %d (header not counted, remaining=%d)", tokens, 3, remaining) - } -} - -func TestFilterChunksByScore_AllPass(t *testing.T) { - chunks := []map[string]interface{}{ - {"entity_kwd": "A", "_score": 0.5}, - {"entity_kwd": "B", "_score": 0.8}, - } - result := FilterChunksByScore(chunks, 0.3) - if len(result) != 2 { - t.Errorf("expected all 2 chunks to pass, got %d", len(result)) - } -} - -func TestFilterChunksByScore_SomeFiltered(t *testing.T) { - chunks := []map[string]interface{}{ - {"entity_kwd": "A", "_score": 0.2}, - {"entity_kwd": "B", "_score": 0.9}, - } - result := FilterChunksByScore(chunks, 0.3) - if len(result) != 1 || result[0]["entity_kwd"] != "B" { - t.Errorf("expected only B to pass, got %v", result) - } -} - -func TestFilterChunksByScore_MissingScore(t *testing.T) { - chunks := []map[string]interface{}{ - {"entity_kwd": "A"}, // no _score → treated as 0 - {"entity_kwd": "B", "score": 0.5}, - } - result := FilterChunksByScore(chunks, 0.3) - if len(result) != 1 || result[0]["entity_kwd"] != "B" { - t.Errorf("expected only B (using 'score' field), got %v", result) - } -} - -func TestFilterChunksByScore_NilInput(t *testing.T) { - result := FilterChunksByScore(nil, 0.3) - if result != nil { - t.Errorf("expected nil, got %v", result) - } -} - -func TestFilterChunksByScore_ZeroThreshold(t *testing.T) { - chunks := []map[string]interface{}{ - {"entity_kwd": "A", "_score": 0.0}, - } - result := FilterChunksByScore(chunks, 0) - if len(result) != 1 { - t.Errorf("expected all pass when threshold=0, got %d", len(result)) - } -} - -func TestExtractDescription_JSON(t *testing.T) { - result := extractDescription(`{"description": "Entity A description", "other": "value"}`) - if result != "Entity A description" { - t.Errorf("expected 'Entity A description', got %q", result) - } -} - -func TestExtractDescription_Plain(t *testing.T) { - result := extractDescription("plain description") - if result != "plain description" { - t.Errorf("expected 'plain description', got %q", result) - } -} - -func TestExtractDescription_EscapedQuote(t *testing.T) { - result := extractDescription(`{"description": "has \"quote\" inside"}`) - if result != `has "quote" inside` { - t.Errorf("expected full description with quote, got %q", result) - } -} - -func TestExtractDescription_NonStringValue(t *testing.T) { - result := extractDescription(`{"description": null, "other": "val"}`) - if result != `{"description": null, "other": "val"}` { - t.Errorf("expected raw JSON when description is null, got %q", result) - } -} - -func TestExtractDescription_EmptyString(t *testing.T) { - result := extractDescription("") - if result != "" { - t.Errorf("expected empty, got %q", result) - } -} - -func TestFormatCSVLine_Normal(t *testing.T) { - result := formatCSVLine("Elon Musk", "0.85", "CEO of SpaceX") - // Normal values should not be quoted - if result != "Elon Musk,0.85,CEO of SpaceX\n" { - t.Errorf("expected unquoted CSV, got %q", result) - } -} - -func TestFormatCSVLine_CommaInField(t *testing.T) { - result := formatCSVLine("Musk, Elon", "0.85", "CEO, SpaceX") - // Values with commas should be quoted - expected := `"Musk, Elon",0.85,"CEO, SpaceX"` + "\n" - if result != expected { - t.Errorf("expected %q, got %q", expected, result) - } -} - -func TestFormatCSVLine_QuoteInField(t *testing.T) { - result := formatCSVLine("Elon Musk", "0.85", `CEO of "SpaceX"`) - // Values with quotes should have quotes escaped - expected := `Elon Musk,0.85,"CEO of ""SpaceX"""` + "\n" - if result != expected { - t.Errorf("expected %q, got %q", expected, result) - } -} - -func TestFormatCSVLine_EmptyField(t *testing.T) { - result := formatCSVLine("", "", "") - if result != ",,\n" { - t.Errorf("expected empty fields, got %q", result) - } -} - -func TestNumTokensFromString(t *testing.T) { - s := "This is a test string with multiple words" - tokens := NumTokensFromString(s) - if tokens <= 0 { - t.Errorf("expected positive token count, got %d", tokens) - } -} - diff --git a/internal/service/kg_search.go b/internal/service/kg_search.go deleted file mode 100644 index 3f93c98e6f0..00000000000 --- a/internal/service/kg_search.go +++ /dev/null @@ -1,397 +0,0 @@ -// -// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package service - -import ( - "context" - "encoding/json" - "fmt" - - "ragflow/internal/engine" - "ragflow/internal/engine/types" - modelModule "ragflow/internal/entity/models" -) - -// NhopEntity represents an N-hop neighbor path. -type NhopEntity struct { - Path []string // entity names along the path - Weights []float64 // pagerank weights per hop -} - -// KGEntity represents a knowledge graph entity. -type KGEntity struct { - Name string // entity_kwd - Type string // entity_type_kwd - PageRank float64 // rank_flt - Similarity float64 // _score - Description string // content_with_weight - NhopEnts []NhopEntity // n_hop_with_weight (parsed JSON) -} - -// Edge represents a directed (from_entity, to_entity) pair. -type Edge struct { - From, To string -} - -// EdgeScore represents the accumulated score for an edge from N-hop analysis. -type EdgeScore struct { - Sim float64 - PageRank float64 -} - -// ScoredEntity is a scored entity ready for output. -type ScoredEntity struct { - Entity string - Score float64 - Description string -} - -// ScoredRelation is a scored relation ready for output. -type ScoredRelation struct { - From string - To string - Score float64 - Description string -} - -// KGRelation represents a relation between two entities. -type KGRelation struct { - From string // from_entity_kwd - To string // to_entity_kwd - Weight int // weight_int - Description string // content_with_weight - Sim float64 // score accumulated during pipeline scoring - PageRank float64 // rank_flt or weight_int as float64 -} - -// KGCommunityReport represents a community report. -type KGCommunityReport struct { - Title string // docnm_kwd - Content string // content_with_weight - Weight float64 // weight_flt - Entities string // entities_kwd -} - -// buildKGDenseExpr computes the query vector and returns a MatchDenseExpr -// for KG hybrid search. Returns nil if embModel or question is empty. -func buildKGDenseExpr(embModel *modelModule.EmbeddingModel, question string, topN int) (*types.MatchDenseExpr, error) { - if embModel == nil || question == "" { - return nil, nil - } - embCfg := &modelModule.EmbeddingConfig{Dimension: 0} - embeddings, err := embModel.ModelDriver.Embed(embModel.ModelName, []string{question}, embModel.APIConfig, embCfg) - if err != nil { - return nil, fmt.Errorf("KG entity embed failed: %w", err) - } - if len(embeddings) == 0 || len(embeddings[0].Embedding) == 0 { - return nil, nil - } - vector := embeddings[0].Embedding - return &types.MatchDenseExpr{ - VectorColumnName: fmt.Sprintf("q_%d_vec", len(vector)), - EmbeddingData: vector, - EmbeddingDataType: "float", - DistanceType: "cosine", - TopN: topN, - ExtraOptions: map[string]interface{}{"similarity": 0.3}, - }, nil -} - -// buildHybridExpr returns MatchExprs for hybrid search (dense + text + fusion). -func buildHybridExpr(dense *types.MatchDenseExpr, text *types.MatchTextExpr, topN int) []interface{} { - return []interface{}{ - dense, - text, - &types.FusionExpr{ - Method: "weighted_sum", - TopN: topN, - FusionParams: map[string]interface{}{"weights": "0.05,0.95"}, - }, - } -} - -// buildEntitySearchRequest constructs a SearchRequest for KG entities. -// dense may be nil for text-only search. -func buildEntitySearchRequest(kbIDs []string, question string, dense *types.MatchDenseExpr, topN int) *types.SearchRequest { - req := &types.SearchRequest{ - KbIDs: kbIDs, - SelectFields: []string{"entity_kwd", "entity_type_kwd", "rank_flt", "content_with_weight"}, - Limit: topN, - Filter: map[string]interface{}{"knowledge_graph_kwd": "entity"}, - } - if question == "" { - return req - } - textExpr := &types.MatchTextExpr{ - Fields: []string{"entity_kwd^10", "content_ltks^2"}, - MatchingText: question, - TopN: topN, - } - if dense != nil { - req.MatchExprs = buildHybridExpr(dense, textExpr, topN) - req.RankFeature = map[string]float64{"pagerank_fea": 10.0} - } else { - req.MatchExprs = []interface{}{textExpr} - } - return req -} - -// buildEntityTypeSearchRequest constructs a SearchRequest for KG entities by type. -func buildEntityTypeSearchRequest(kbIDs []string, typeKeywords []string, topN int) *types.SearchRequest { - req := &types.SearchRequest{ - KbIDs: kbIDs, - SelectFields: []string{"entity_kwd", "entity_type_kwd", "rank_flt", "content_with_weight"}, - Limit: topN, - Filter: map[string]interface{}{ - "knowledge_graph_kwd": "entity", - }, - } - if len(typeKeywords) > 0 { - filters := make([]interface{}, len(typeKeywords)) - for i, t := range typeKeywords { - filters[i] = t - } - req.Filter["entity_type_kwd"] = filters - } - return req -} - -// buildRelationSearchRequest constructs a SearchRequest for KG relations. -// dense may be nil for text-only search. -func buildRelationSearchRequest(kbIDs []string, question string, dense *types.MatchDenseExpr, topN int) *types.SearchRequest { - req := &types.SearchRequest{ - KbIDs: kbIDs, - SelectFields: []string{"from_entity_kwd", "to_entity_kwd", "weight_int", "content_with_weight"}, - Limit: topN, - Filter: map[string]interface{}{"knowledge_graph_kwd": "relation"}, - } - if question != "" { - textExpr := &types.MatchTextExpr{ - Fields: []string{"content_ltks"}, - MatchingText: question, - TopN: topN, - } - if dense != nil { - req.MatchExprs = buildHybridExpr(dense, textExpr, topN) - } else { - req.MatchExprs = []interface{}{textExpr} - } - } - return req -} - -// buildCommunitySearchRequest constructs a SearchRequest for KG community reports. -// Matches community reports whose entities_kwd contains any of the given entity names. -func buildCommunitySearchRequest(kbIDs []string, entityNames []string, topN int) *types.SearchRequest { - req := &types.SearchRequest{ - KbIDs: kbIDs, - SelectFields: []string{"docnm_kwd", "content_with_weight", "weight_flt", "entities_kwd"}, - Limit: topN, - Filter: map[string]interface{}{ - "knowledge_graph_kwd": "community_report", - }, - OrderBy: (&types.OrderByExpr{}).Desc("weight_flt"), - } - if len(entityNames) > 0 { - filters := make([]interface{}, len(entityNames)) - for i, name := range entityNames { - filters[i] = name - } - req.Filter["entities_kwd"] = filters - } - return req -} - -// buildTypeSamplesSearchRequest constructs a SearchRequest for ty2ents data. -func buildTypeSamplesSearchRequest(kbIDs []string) *types.SearchRequest { - return &types.SearchRequest{ - KbIDs: kbIDs, - SelectFields: []string{"content_with_weight"}, - Limit: 10000, - Filter: map[string]interface{}{"knowledge_graph_kwd": "ty2ents"}, - } -} - -// ParseKGEntityChunks converts raw search result chunks into KGEntity slices. -func ParseKGEntityChunks(chunks []map[string]interface{}) []KGEntity { - var entities []KGEntity - for _, chunk := range chunks { - e := KGEntity{} - if v, ok := chunk["entity_kwd"].(string); ok { - e.Name = v - } else if list, ok := chunk["entity_kwd"].([]interface{}); ok && len(list) > 0 { - e.Name, _ = list[0].(string) - } - if e.Name == "" { - continue - } - e.Type, _ = chunk["entity_type_kwd"].(string) - e.Description, _ = chunk["content_with_weight"].(string) - if v, ok := chunk["rank_flt"].(float64); ok { - e.PageRank = v - } - if v, ok := chunk["_score"].(float64); ok { - e.Similarity = v - } else if v, ok := chunk["score"].(float64); ok { - e.Similarity = v - } - entities = append(entities, e) - } - return entities -} - -// ParseKGRelationChunks converts raw search result chunks into KGRelation slices. -func ParseKGRelationChunks(chunks []map[string]interface{}) []KGRelation { - var relations []KGRelation - for _, chunk := range chunks { - r := KGRelation{} - r.From, _ = chunk["from_entity_kwd"].(string) - r.To, _ = chunk["to_entity_kwd"].(string) - r.Description, _ = chunk["content_with_weight"].(string) - if v, ok := chunk["weight_int"].(float64); ok { - r.Weight = int(v) - } else if v, ok := chunk["weight_int"].(int); ok { - r.Weight = v - } - if r.From == "" || r.To == "" { - continue - } - relations = append(relations, r) - } - return relations -} - -// ParseKGCommunityReportChunks converts raw search result chunks into KGCommunityReport slices. -func ParseKGCommunityReportChunks(chunks []map[string]interface{}) []KGCommunityReport { - var reports []KGCommunityReport - for _, chunk := range chunks { - r := KGCommunityReport{} - r.Title, _ = chunk["docnm_kwd"].(string) - r.Content, _ = chunk["content_with_weight"].(string) - r.Entities, _ = chunk["entities_kwd"].(string) - if v, ok := chunk["weight_flt"].(float64); ok { - r.Weight = v - } - if r.Title == "" && r.Content == "" { - continue - } - reports = append(reports, r) - } - return reports -} - -// ParseKGTypeSamplesChunks converts raw search result chunks into a type→entities map. -func ParseKGTypeSamplesChunks(chunks []map[string]interface{}) map[string][]string { - result := make(map[string][]string) - for _, chunk := range chunks { - content, ok := chunk["content_with_weight"].(string) - if !ok || content == "" { - continue - } - var typeMap map[string][]string - if err := json.Unmarshal([]byte(content), &typeMap); err != nil { - continue - } - for typ, entities := range typeMap { - result[typ] = append(result[typ], entities...) - } - } - return result -} - -// NhopEntityNames extracts unique entity names from n_hop_with_weight JSON string. -// The JSON format is: [{"path": ["A", "B", "C"], "weights": [0.8, 0.5]}, ...] -// Returns entity names in order of first appearance, with duplicates removed. -func NhopEntityNames(nHopJSON string) []string { - type nhopItem struct { - Path []string `json:"path"` - Weights []float64 `json:"weights"` - } - var data []nhopItem - if err := json.Unmarshal([]byte(nHopJSON), &data); err != nil { - return nil - } - seen := make(map[string]struct{}) - var names []string - for _, item := range data { - for _, name := range item.Path { - if _, ok := seen[name]; !ok { - seen[name] = struct{}{} - names = append(names, name) - } - } - } - return names -} - -// SearchKGEntities searches for KG entities matching a question. -func SearchKGEntities(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, question string, embModel *modelModule.EmbeddingModel, topN int) ([]KGEntity, error) { - dense, err := buildKGDenseExpr(embModel, question, topN) - if err != nil { - return nil, err - } - req := buildEntitySearchRequest(kbIDs, question, dense, topN) - result, err := docEngine.Search(ctx, req) - if err != nil { - return nil, fmt.Errorf("KG entity search failed: %w", err) - } - return ParseKGEntityChunks(result.Chunks), nil -} - -// SearchKGEntitiesByTypes searches for KG entities by type keywords. -func SearchKGEntitiesByTypes(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, typeKeywords []string, topN int) ([]KGEntity, error) { - req := buildEntityTypeSearchRequest(kbIDs, typeKeywords, topN) - result, err := docEngine.Search(ctx, req) - if err != nil { - return nil, fmt.Errorf("KG entity type search failed: %w", err) - } - return ParseKGEntityChunks(result.Chunks), nil -} - -// SearchKGRelations searches for KG relations matching a question. -func SearchKGRelations(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, question string, embModel *modelModule.EmbeddingModel, topN int) ([]KGRelation, error) { - dense, err := buildKGDenseExpr(embModel, question, topN) - if err != nil { - return nil, err - } - req := buildRelationSearchRequest(kbIDs, question, dense, topN) - result, err := docEngine.Search(ctx, req) - if err != nil { - return nil, fmt.Errorf("KG relation search failed: %w", err) - } - return ParseKGRelationChunks(result.Chunks), nil -} - -// SearchKGCommunityReports searches for community reports related to given entities. -func SearchKGCommunityReports(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, entityNames []string, topN int) ([]KGCommunityReport, error) { - req := buildCommunitySearchRequest(kbIDs, entityNames, topN) - result, err := docEngine.Search(ctx, req) - if err != nil { - return nil, fmt.Errorf("KG community search failed: %w", err) - } - return ParseKGCommunityReportChunks(result.Chunks), nil -} - -// SearchKGTypeSamples retrieves the type→entities mapping from ES. -func SearchKGTypeSamples(ctx context.Context, docEngine engine.DocEngine, kbIDs []string) (map[string][]string, error) { - req := buildTypeSamplesSearchRequest(kbIDs) - result, err := docEngine.Search(ctx, req) - if err != nil { - return nil, fmt.Errorf("KG type samples search failed: %w", err) - } - return ParseKGTypeSamplesChunks(result.Chunks), nil -}