diff --git a/internal/handler/chunk.go b/internal/handler/chunk.go index e49c44281bc..a87ffde2718 100644 --- a/internal/handler/chunk.go +++ b/internal/handler/chunk.go @@ -19,12 +19,13 @@ package handler import ( "encoding/json" "net/http" + "strconv" "strings" - "ragflow/internal/common" "github.com/gin-gonic/gin" "go.uber.org/zap" + "ragflow/internal/common" "ragflow/internal/service" ) @@ -35,6 +36,10 @@ type chunkService interface { List(req *service.ListChunksRequest, userID string) (*service.ListChunksResponse, error) UpdateChunk(req *service.UpdateChunkRequest, userID string) error RemoveChunks(req *service.RemoveChunksRequest, userID string) (int64, error) + ListChunksREST(datasetID, documentID, userID string, page, pageSize int, keywords string, available *bool) (*service.ListChunksResponse, error) + AddChunk(datasetID, documentID, userID string, req *service.AddChunkRequest) (map[string]interface{}, error) + UpdateChunkREST(datasetID, documentID, chunkID, userID string, req *service.UpdateChunkRESTRequest) error + SwitchChunks(datasetID, documentID, userID string, chunkIDs []string, available bool) error } // ChunkHandler chunk handler @@ -387,6 +392,197 @@ func (h *ChunkHandler) UpdateChunk(c *gin.Context) { }) } +// ListChunksREST lists chunks for a document inside a dataset. +// @Summary List Chunks +// @Description List chunks for a document (dataset_id and document_id from path). 
+// @Tags chunks +// @Produce json +// @Param dataset_id path string true "Dataset ID" +// @Param document_id path string true "Document ID" +// @Param page query int false "Page number (default 1)" +// @Param page_size query int false "Items per page (default 30)" +// @Param keywords query string false "Keyword filter" +// @Param available query bool false "Filter by available status" +// @Success 200 {object} map[string]interface{} +// @Router /api/v1/datasets/{dataset_id}/documents/{document_id}/chunks [get] +func (h *ChunkHandler) ListChunksREST(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + datasetID := c.Param("dataset_id") + documentID := c.Param("document_id") + if datasetID == "" || documentID == "" { + c.JSON(http.StatusOK, gin.H{"code": common.CodeDataError, "data": false, "message": "dataset_id and document_id are required"}) + return + } + + page := 1 + if v := c.Query("page"); v != "" { + if p, err := strconv.Atoi(v); err == nil && p > 0 { + page = p + } + } + pageSize := 30 + if v := c.Query("page_size"); v != "" { + if ps, err := strconv.Atoi(v); err == nil && ps > 0 { + if ps > 100 { + ps = 100 + } + pageSize = ps + } + } + keywords := c.Query("keywords") + + var available *bool + if v := c.Query("available"); v != "" { + b := v == "true" || v == "1" + available = &b + } + + resp, err := h.chunkService.ListChunksREST(datasetID, documentID, user.ID, page, pageSize, keywords, available) + if err != nil { + c.JSON(http.StatusOK, gin.H{"code": common.CodeDataError, "data": false, "message": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"code": common.CodeSuccess, "data": resp, "message": "success"}) +} + +// AddChunk adds a manually created chunk to a document. +// @Summary Add Chunk +// @Description Create a new chunk for a document with content, keywords, and questions. 
+// @Tags chunks +// @Accept json +// @Produce json +// @Param dataset_id path string true "Dataset ID" +// @Param document_id path string true "Document ID" +// @Param request body service.AddChunkRequest true "chunk content" +// @Success 200 {object} map[string]interface{} +// @Router /api/v1/datasets/{dataset_id}/documents/{document_id}/chunks [post] +func (h *ChunkHandler) AddChunk(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + datasetID := c.Param("dataset_id") + documentID := c.Param("document_id") + if datasetID == "" || documentID == "" { + c.JSON(http.StatusOK, gin.H{"code": common.CodeDataError, "data": false, "message": "dataset_id and document_id are required"}) + return + } + + var req service.AddChunkRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{"code": common.CodeDataError, "data": false, "message": err.Error()}) + return + } + + data, err := h.chunkService.AddChunk(datasetID, documentID, user.ID, &req) + if err != nil { + c.JSON(http.StatusOK, gin.H{"code": common.CodeDataError, "data": false, "message": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"code": common.CodeSuccess, "data": data, "message": "success"}) +} + +// UpdateChunkREST updates a chunk's content, keywords, and availability. +// Re-embeds the chunk when content or questions change. +// @Summary Update Chunk (REST) +// @Description Partially update a chunk by ID, re-embedding on content/question changes. 
+// @Tags chunks +// @Accept json +// @Produce json +// @Param dataset_id path string true "Dataset ID" +// @Param document_id path string true "Document ID" +// @Param chunk_id path string true "Chunk ID" +// @Param request body service.UpdateChunkRESTRequest true "fields to update" +// @Success 200 {object} map[string]interface{} +// @Router /api/v1/datasets/{dataset_id}/documents/{document_id}/chunks/{chunk_id} [patch] +func (h *ChunkHandler) UpdateChunkREST(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + datasetID := c.Param("dataset_id") + documentID := c.Param("document_id") + chunkID := c.Param("chunk_id") + if datasetID == "" || documentID == "" || chunkID == "" { + c.JSON(http.StatusOK, gin.H{"code": common.CodeDataError, "data": false, "message": "dataset_id, document_id and chunk_id are required"}) + return + } + + var req service.UpdateChunkRESTRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{"code": common.CodeDataError, "data": false, "message": err.Error()}) + return + } + + if err := h.chunkService.UpdateChunkREST(datasetID, documentID, chunkID, user.ID, &req); err != nil { + c.JSON(http.StatusOK, gin.H{"code": common.CodeDataError, "data": false, "message": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"code": common.CodeSuccess, "data": true, "message": "success"}) +} + +// SwitchChunks bulk-toggles the available status for a list of chunks. +// @Summary Switch Chunks Availability +// @Description Toggle available_int for a set of chunk IDs. 
+// @Tags chunks +// @Accept json +// @Produce json +// @Param dataset_id path string true "Dataset ID" +// @Param document_id path string true "Document ID" +// @Param request body object true "chunk_ids + available" +// @Success 200 {object} map[string]interface{} +// @Router /api/v1/datasets/{dataset_id}/documents/{document_id}/chunks [patch] +func (h *ChunkHandler) SwitchChunks(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + datasetID := c.Param("dataset_id") + documentID := c.Param("document_id") + if datasetID == "" || documentID == "" { + c.JSON(http.StatusOK, gin.H{"code": common.CodeDataError, "data": false, "message": "dataset_id and document_id are required"}) + return + } + + var body struct { + ChunkIDs []string `json:"chunk_ids"` + Available *bool `json:"available"` + } + if err := c.ShouldBindJSON(&body); err != nil { + c.JSON(http.StatusOK, gin.H{"code": common.CodeDataError, "data": false, "message": err.Error()}) + return + } + if len(body.ChunkIDs) == 0 { + c.JSON(http.StatusOK, gin.H{"code": common.CodeDataError, "data": false, "message": "`chunk_ids` is required."}) + return + } + if body.Available == nil { + c.JSON(http.StatusOK, gin.H{"code": common.CodeDataError, "data": false, "message": "`available` is required."}) + return + } + + if err := h.chunkService.SwitchChunks(datasetID, documentID, user.ID, body.ChunkIDs, *body.Available); err != nil { + c.JSON(http.StatusOK, gin.H{"code": common.CodeDataError, "data": false, "message": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"code": common.CodeSuccess, "data": true, "message": "success"}) +} + // RemoveChunks handles chunk removal requests // @Summary Remove Chunks // @Description Remove chunks from a document diff --git a/internal/handler/chunk_test.go b/internal/handler/chunk_test.go index 3b6efe2a824..d63774e3956 100644 --- a/internal/handler/chunk_test.go +++ 
b/internal/handler/chunk_test.go @@ -42,6 +42,18 @@ func (m *mockChunkSvc) UpdateChunk(*service.UpdateChunkRequest, string) error { func (m *mockChunkSvc) RemoveChunks(*service.RemoveChunksRequest, string) (int64, error) { panic("not implemented") } +func (m *mockChunkSvc) ListChunksREST(datasetID, documentID, userID string, page, pageSize int, keywords string, available *bool) (*service.ListChunksResponse, error) { + panic("not implemented") +} +func (m *mockChunkSvc) AddChunk(datasetID, documentID, userID string, req *service.AddChunkRequest) (map[string]interface{}, error) { + panic("not implemented") +} +func (m *mockChunkSvc) UpdateChunkREST(datasetID, documentID, chunkID, userID string, req *service.UpdateChunkRESTRequest) error { + panic("not implemented") +} +func (m *mockChunkSvc) SwitchChunks(datasetID, documentID, userID string, chunkIDs []string, available bool) error { + panic("not implemented") +} func setupChunkRetrievalTest(userID string) (*gin.Engine, *mockChunkSvc) { mock := &mockChunkSvc{} diff --git a/internal/router/router.go b/internal/router/router.go index 0940b08ea46..a89599eedc1 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -260,10 +260,14 @@ func (r *Router) Setup(engine *gin.Engine) { datasets.DELETE("/:dataset_id/documents", r.documentHandler.DeleteDocuments) // Dataset document chunk + datasets.GET("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.ListChunksREST) + datasets.POST("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.AddChunk) datasets.GET("/:dataset_id/documents/:document_id/chunks/:chunk_id", r.chunkHandler.Get) + datasets.PATCH("/:dataset_id/documents/:document_id/chunks/:chunk_id", r.chunkHandler.UpdateChunkREST) + datasets.PATCH("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.SwitchChunks) + datasets.DELETE("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.RemoveChunks) datasets.POST("/:dataset_id/documents/parse", 
r.documentHandler.ParseDocuments) datasets.POST("/:dataset_id/documents/stop", r.documentHandler.StopParseDocuments) - datasets.DELETE("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.RemoveChunks) } // Search routes diff --git a/internal/service/chunk.go b/internal/service/chunk.go index e97625e1646..a4c1efc89ab 100644 --- a/internal/service/chunk.go +++ b/internal/service/chunk.go @@ -19,18 +19,21 @@ package service import ( "context" "fmt" - "ragflow/internal/common" - "ragflow/internal/entity" - "ragflow/internal/entity/models" - "ragflow/internal/server" "strconv" "strings" + "time" + "github.com/cespare/xxhash/v2" "go.uber.org/zap" + "gorm.io/gorm" + "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/engine" "ragflow/internal/engine/types" + "ragflow/internal/entity" + "ragflow/internal/entity/models" + "ragflow/internal/server" "ragflow/internal/service/nlp" "ragflow/internal/tokenizer" "ragflow/internal/utility" @@ -973,3 +976,580 @@ func (s *ChunkService) RemoveChunks(req *RemoveChunksRequest, userID string) (in return deletedCount, nil } + +// ── REST chunk helpers ──────────────────────────────────────────────────────── + +// resolveDatasetAccess validates user access to a dataset and returns the +// dataset's owning tenantID plus the Knowledgebase record. +// It does not check document membership; callers verify that the document belongs to the dataset themselves. +func (s *ChunkService) resolveDatasetAccess(userID, datasetID string) (tenantID string, kb *entity.Knowledgebase, err error) { + tenants, err := s.userTenantDAO.GetByUserID(userID) + if err != nil { + return "", nil, fmt.Errorf("failed to get user tenants: %w", err) + } + for _, t := range tenants { + kb, err = s.kbDAO.GetByIDAndTenantID(datasetID, t.TenantID) + if err == nil && kb != nil { + return t.TenantID, kb, nil + } + } + return "", nil, fmt.Errorf("You don't own the dataset %s.", datasetID) +} + +// getEmbeddingModelForKB resolves the embedding model for a knowledge base. 
+func (s *ChunkService) getEmbeddingModelForKB(kb *entity.Knowledgebase, tenantID string) (*models.EmbeddingModel, error) { + tenantLLMDAO := dao.NewTenantLLMDAO() + modelProviderSvc := NewModelProviderService() + + var embdID string + var err error + if kb.TenantEmbdID != nil && *kb.TenantEmbdID > 0 { + _, embdID, err = dao.LookupTenantLLMByID(tenantLLMDAO, *kb.TenantEmbdID) + } else if kb.EmbdID != "" { + // Mirror Python TenantLLMService.split_model_name_and_factory: the factory + // is the segment after the LAST "@", so model names that themselves contain + // "@" (e.g. "Qwen/Qwen3-Embedding-8B@test@SILICONFLOW") resolve correctly. + name, factory := kb.EmbdID, "" + if idx := strings.LastIndex(kb.EmbdID, "@"); idx >= 0 { + name, factory = kb.EmbdID[:idx], kb.EmbdID[idx+1:] + } + if factory != "" { + _, embdID, err = dao.LookupTenantLLMByFactory(tenantLLMDAO, tenantID, factory, name, entity.ModelTypeEmbedding) + } else { + _, embdID, err = dao.LookupTenantLLMByName(tenantLLMDAO, tenantID, name, entity.ModelTypeEmbedding) + } + } + if err != nil { + return nil, fmt.Errorf("failed to resolve embedding model: %w", err) + } + if embdID == "" { + return nil, fmt.Errorf("no embedding model configured for dataset") + } + return modelProviderSvc.GetEmbeddingModel(tenantID, embdID) +} + +// embedTexts calls the embedding model and returns the raw float64 slices plus +// an approximate token count for the embedded input. +// +// The embedding driver does not surface provider token usage (EmbeddingData has +// no token field), so the count is derived from the input text via the project +// tokenizer rather than from the embedding vector dimensionality (which is a +// fixed model property unrelated to consumption). 
+func embedTexts(em *models.EmbeddingModel, texts []string) ([][]float64, int, error) { + resp, err := em.ModelDriver.Embed(em.ModelName, texts, em.APIConfig, nil) + if err != nil { + return nil, 0, err + } + vecs := make([][]float64, len(resp)) + for i, d := range resp { + vecs[i] = d.Embedding + } + tokenCount := 0 + for _, t := range texts { + tokenCount += estimateTokenCount(t) + } + return vecs, tokenCount, nil +} + +// estimateTokenCount approximates the number of tokens in text. It tokenizes the +// text with the project tokenizer (which segments CJK and splits terms) and +// counts the resulting terms; on failure it falls back to a rune-based estimate +// of roughly one token per four characters. +func estimateTokenCount(text string) int { + if strings.TrimSpace(text) == "" { + return 0 + } + if toks, err := tokenizer.Tokenize(text); err == nil && toks != "" { + return len(strings.Fields(toks)) + } + return (len([]rune(text)) + 3) / 4 +} + +// derefString safely dereferences a *string, returning "" when nil. +func derefString(p *string) string { + if p == nil { + return "" + } + return *p +} + +// mapDocRun maps a document's run-status code to its label, mirroring Python's +// run_mapping in chunk_api._map_doc. Unknown/nil codes map to an empty string. +func mapDocRun(run *string) string { + if run == nil { + return "" + } + switch strings.TrimSpace(*run) { + case "0": + return "UNSTART" + case "1": + return "RUNNING" + case "2": + return "CANCEL" + case "3": + return "DONE" + case "4": + return "FAIL" + default: + return "" + } +} + +// weightedVec returns 0.1*a + 0.9*b (doc_name weight vs content weight). 
+func weightedVec(a, b []float64) []float64 { + n := len(a) + if len(b) < n { + n = len(b) + } + out := make([]float64, n) + for i := range out { + out[i] = 0.1*a[i] + 0.9*b[i] + } + return out +} + +// ── ListChunksREST ──────────────────────────────────────────────────────────── + +// ListChunksREST mirrors Python GET /datasets/:dataset_id/documents/:document_id/chunks. +// dataset_id and document_id are path params; validation is ownership-based. +func (s *ChunkService) ListChunksREST(datasetID, documentID, userID string, page, pageSize int, keywords string, available *bool) (*ListChunksResponse, error) { + if s.docEngine == nil { + return nil, fmt.Errorf("doc engine not initialized") + } + + tenantID, _, err := s.resolveDatasetAccess(userID, datasetID) + if err != nil { + return nil, err + } + + // Verify document belongs to dataset. + doc, err := s.documentDAO.GetByID(documentID) + if err != nil || doc == nil { + return nil, fmt.Errorf("You don't own the document %s.", documentID) + } + if doc.KbID != datasetID { + return nil, fmt.Errorf("You don't own the document %s.", documentID) + } + + ctx := context.Background() + indexName := fmt.Sprintf("ragflow_%s", tenantID) + + searchReq := &types.SearchRequest{ + IndexNames: []string{indexName}, + MatchExprs: []interface{}{keywords}, + KbIDs: []string{datasetID}, + Offset: (page - 1) * pageSize, + Limit: pageSize, + Filter: map[string]interface{}{"doc_id": documentID}, + } + if available != nil { + avInt := 0 + if *available { + avInt = 1 + } + searchReq.Filter["available_int"] = avInt + } + + searchResp, err := s.docEngine.Search(ctx, searchReq) + if err != nil { + return nil, fmt.Errorf("search failed: %w", err) + } + + chunks := make([]map[string]interface{}, 0, len(searchResp.Chunks)) + for _, chunk := range searchResp.Chunks { + result := map[string]interface{}{ + "id": chunk["id"], + "content": chunk["content_with_weight"], + "document_id": chunk["doc_id"], + "docnm_kwd": chunk["docnm_kwd"], + 
"important_keywords": orSlice(chunk["important_kwd"]), + "questions": orSlice(chunk["question_kwd"]), + "tag_kwd": orSlice(chunk["tag_kwd"]), + "dataset_id": datasetID, + "image_id": orStr(chunk["img_id"]), + "available": intToBool(chunk["available_int"]), + "positions": orSlice(chunk["position_int"]), + } + chunks = append(chunks, result) + } + + timeFormat := "2006-01-02T15:04:05" + // Mirror Python chunk_api._map_doc: return the full document with the SDK key + // renames (kb_id→dataset_id, chunk_num→chunk_count, token_num→token_count, + // parser_id→chunk_method) and the run-status label mapping, so the frontend + // receives every field it expects. + docInfo := map[string]interface{}{ + "id": doc.ID, + "thumbnail": doc.Thumbnail, + "dataset_id": doc.KbID, + "chunk_method": doc.ParserID, + "pipeline_id": doc.PipelineID, + "parser_config": doc.ParserConfig, + "source_type": doc.SourceType, + "type": doc.Type, + "created_by": doc.CreatedBy, + "name": doc.Name, + "location": doc.Location, + "size": doc.Size, + "token_count": doc.TokenNum, + "chunk_count": doc.ChunkNum, + "progress": utility.JSONFloat64(doc.Progress), + "progress_msg": doc.ProgressMsg, + "process_begin_at": utility.FormatTimeToString(doc.ProcessBeginAt, timeFormat), + "process_duration": doc.ProcessDuration, + "content_hash": doc.ContentHash, + "meta_fields": doc.MetaFields, + "suffix": doc.Suffix, + "run": mapDocRun(doc.Run), + "status": doc.Status, + "create_time": doc.CreateTime, + "create_date": utility.FormatTimeToString(doc.CreateDate, timeFormat), + "update_time": doc.UpdateTime, + "update_date": utility.FormatTimeToString(doc.UpdateDate, timeFormat), + } + + return &ListChunksResponse{ + Total: searchResp.Total, + Chunks: chunks, + Doc: docInfo, + }, nil +} + +func orSlice(v interface{}) interface{} { + if v == nil { + return []interface{}{} + } + return v +} + +func orStr(v interface{}) string { + if s, ok := v.(string); ok { + return s + } + return "" +} + +func intToBool(v interface{}) 
bool { + switch t := v.(type) { + case int: + return t != 0 + case int64: + return t != 0 + case float64: + return t != 0 + case string: + return t != "0" && t != "" + } + return true // default available +} + +// ── AddChunk ───────────────────────────────────────────────────────────────── + +// AddChunkRequest mirrors the Python add_chunk body. +type AddChunkRequest struct { + Content string `json:"content"` + ImportantKeywords []string `json:"important_keywords"` + Questions []string `json:"questions"` + TagKwd []string `json:"tag_kwd"` + TagFeas map[string]interface{} `json:"tag_feas"` +} + +// AddChunk mirrors Python POST /datasets/:dataset_id/documents/:document_id/chunks. +func (s *ChunkService) AddChunk(datasetID, documentID, userID string, req *AddChunkRequest) (map[string]interface{}, error) { + if s.docEngine == nil { + return nil, fmt.Errorf("doc engine not initialized") + } + if strings.TrimSpace(req.Content) == "" { + return nil, fmt.Errorf("`content` is required") + } + + tenantID, kb, err := s.resolveDatasetAccess(userID, datasetID) + if err != nil { + return nil, err + } + + doc, err := s.documentDAO.GetByID(documentID) + if err != nil || doc == nil || doc.KbID != datasetID { + return nil, fmt.Errorf("You don't own the document %s.", documentID) + } + docName := derefString(doc.Name) + + // Deterministic chunk ID: xxhash64(content + document_id). + chunkID := fmt.Sprintf("%x", xxhash.Sum64String(req.Content+documentID)) + + // Tokenize content. + contentLtks, _ := tokenizer.Tokenize(req.Content) + contentSmLtks, _ := tokenizer.FineGrainedTokenize(contentLtks) + + // Build questions list (trimmed, non-empty). 
+ questions := make([]string, 0, len(req.Questions)) + for _, q := range req.Questions { + if q = strings.TrimSpace(q); q != "" { + questions = append(questions, q) + } + } + importantKwd := req.ImportantKeywords + if importantKwd == nil { + importantKwd = []string{} + } + + now := time.Now() + d := map[string]interface{}{ + "id": chunkID, + "content_ltks": contentLtks, + "content_sm_ltks": contentSmLtks, + "content_with_weight": req.Content, + "important_kwd": importantKwd, + "important_tks": strings.Join(importantKwd, " "), + "question_kwd": questions, + "question_tks": strings.Join(questions, "\n"), + "create_time": now.Format("2006-01-02 15:04:05"), + "create_timestamp_flt": float64(now.Unix()), + "kb_id": datasetID, + "docnm_kwd": docName, + "doc_id": documentID, + "available_int": 1, + } + if len(req.TagKwd) > 0 { + d["tag_kwd"] = req.TagKwd + } + if req.TagFeas != nil { + d["tag_feas"] = req.TagFeas + } + + // Compute embedding: 0.1 * embed(doc.name) + 0.9 * embed(content). + em, err := s.getEmbeddingModelForKB(kb, tenantID) + if err != nil { + return nil, fmt.Errorf("failed to get embedding model: %w", err) + } + embedInput := req.Content + if len(questions) > 0 { + embedInput = strings.Join(questions, "\n") + } + vecs, tokenCount, err := embedTexts(em, []string{docName, embedInput}) + if err != nil { + return nil, fmt.Errorf("embedding failed: %w", err) + } + if len(vecs) >= 2 { + vec := weightedVec(vecs[0], vecs[1]) + d[fmt.Sprintf("q_%d_vec", len(vec))] = vec + } + + // Insert into document store. + indexName := fmt.Sprintf("ragflow_%s", tenantID) + if _, err := s.docEngine.InsertChunks(ctx_bg(), []map[string]interface{}{d}, indexName, datasetID); err != nil { + return nil, fmt.Errorf("failed to insert chunk: %w", err) + } + + // Increment document chunk_num and token_num. 
+ _ = s.documentDAO.UpdateByID(documentID, map[string]interface{}{ + "chunk_num": gorm.Expr("chunk_num + 1"), + "token_num": gorm.Expr("token_num + ?", tokenCount), + }) + + // Build response matching Python key_mapping. + renamed := map[string]interface{}{ + "id": chunkID, + "content": req.Content, + "document_id": documentID, + "important_keywords": importantKwd, + "questions": questions, + "dataset_id": datasetID, + "create_timestamp": float64(now.Unix()), + "create_time": d["create_time"], + } + if len(req.TagKwd) > 0 { + renamed["tag_kwd"] = req.TagKwd + } + return map[string]interface{}{"chunk": renamed}, nil +} + +// ctx_bg returns a background context (avoids referencing context.Background everywhere). +func ctx_bg() context.Context { return context.Background() } + +// ── UpdateChunkREST ─────────────────────────────────────────────────────────── + +// UpdateChunkRESTRequest mirrors the Python PATCH body. +type UpdateChunkRESTRequest struct { + Content *string `json:"content"` + ImportantKeywords []string `json:"important_keywords"` + Questions []string `json:"questions"` + Available *bool `json:"available"` + Positions []interface{} `json:"positions"` + TagKwd []string `json:"tag_kwd"` + TagFeas map[string]interface{} `json:"tag_feas"` +} + +// UpdateChunkREST mirrors Python PATCH /datasets/:dataset_id/documents/:document_id/chunks/:chunk_id. +// Like the existing UpdateChunk but with re-embedding on content change. 
+func (s *ChunkService) UpdateChunkREST(datasetID, documentID, chunkID, userID string, req *UpdateChunkRESTRequest) error { + if s.docEngine == nil { + return fmt.Errorf("doc engine not initialized") + } + + tenantID, kb, err := s.resolveDatasetAccess(userID, datasetID) + if err != nil { + return err + } + + doc, err := s.documentDAO.GetByID(documentID) + if err != nil || doc == nil || doc.KbID != datasetID { + return fmt.Errorf("You don't own the document %s.", documentID) + } + docName := derefString(doc.Name) + + ctx := context.Background() + indexName := fmt.Sprintf("ragflow_%s", tenantID) + + // Get existing chunk. + rawChunk, err := s.docEngine.GetChunk(ctx, indexName, chunkID, []string{datasetID}) + if err != nil || rawChunk == nil { + return fmt.Errorf("Can't find this chunk %s", chunkID) + } + existing, ok := rawChunk.(map[string]interface{}) + if !ok { + return fmt.Errorf("Can't find this chunk %s", chunkID) + } + existingDocID := "" + if v, ok := existing["doc_id"].(string); ok { + existingDocID = v + } else if v, ok := existing["document_id"].(string); ok { + existingDocID = v + } + if existingDocID != documentID { + return fmt.Errorf("Can't find this chunk %s", chunkID) + } + + // Determine content. + content := "" + if req.Content != nil { + if strings.TrimSpace(*req.Content) == "" { + return fmt.Errorf("`content` is required") + } + content = *req.Content + } else { + if v, ok := existing["content_with_weight"].(string); ok { + content = v + } else if v, ok := existing["content"].(string); ok { + content = v + } + } + + // Tokenize. 
+ contentLtks, _ := tokenizer.Tokenize(content) + contentSmLtks, _ := tokenizer.FineGrainedTokenize(contentLtks) + + d := map[string]interface{}{ + "id": chunkID, + "content_with_weight": content, + "content_ltks": contentLtks, + "content_sm_ltks": contentSmLtks, + } + + if req.ImportantKeywords != nil { + d["important_kwd"] = req.ImportantKeywords + d["important_tks"] = strings.Join(req.ImportantKeywords, " ") + } + + questions := []string{} + if req.Questions != nil { + for _, q := range req.Questions { + if q = strings.TrimSpace(q); q != "" { + questions = append(questions, q) + } + } + d["question_kwd"] = questions + d["question_tks"] = strings.Join(questions, "\n") + } + + if req.Available != nil { + avInt := 0 + if *req.Available { + avInt = 1 + } + d["available_int"] = avInt + } + if req.Positions != nil { + d["position_int"] = req.Positions + } + if req.TagKwd != nil { + d["tag_kwd"] = req.TagKwd + } + if req.TagFeas != nil { + d["tag_feas"] = req.TagFeas + } + + // Re-embed when content or questions changed. + if req.Content != nil || req.Questions != nil { + em, err := s.getEmbeddingModelForKB(kb, tenantID) + if err != nil { + return fmt.Errorf("failed to get embedding model: %w", err) + } + embedInput := content + if len(questions) > 0 { + embedInput = strings.Join(questions, "\n") + } + vecs, _, err := embedTexts(em, []string{docName, embedInput}) + if err != nil { + return fmt.Errorf("embedding failed: %w", err) + } + if len(vecs) >= 2 { + vec := weightedVec(vecs[0], vecs[1]) + d[fmt.Sprintf("q_%d_vec", len(vec))] = vec + } + } + + condition := map[string]interface{}{"id": chunkID} + return s.docEngine.UpdateChunks(ctx, condition, d, indexName, datasetID) +} + +// ── SwitchChunks ───────────────────────────────────────────────────────────── + +// SwitchChunks mirrors Python PATCH /datasets/:dataset_id/documents/:document_id/chunks +// (without chunk_id) — bulk toggle of available_int. 
+func (s *ChunkService) SwitchChunks(datasetID, documentID, userID string, chunkIDs []string, available bool) error { + if s.docEngine == nil { + return fmt.Errorf("doc engine not initialized") + } + if len(chunkIDs) == 0 { + return fmt.Errorf("`chunk_ids` is required.") + } + + tenantID, _, err := s.resolveDatasetAccess(userID, datasetID) + if err != nil { + return err + } + + // Mirror Python: verify the document belongs to the dataset before touching + // the index. + doc, err := s.documentDAO.GetByID(documentID) + if err != nil || doc == nil || doc.KbID != datasetID { + return fmt.Errorf("Document not found!") + } + + ctx := context.Background() + indexName := fmt.Sprintf("ragflow_%s", tenantID) + availInt := 0 + if available { + availInt = 1 + } + + // Update each chunk's available_int. Python's docStoreConn.update returns False + // for a non-existent chunk id, surfacing as "Index updating failure" (code 102); + // a blind update would otherwise report a false-positive success. Confirm the + // chunk exists, then update. + for _, chunkID := range chunkIDs { + existing, gerr := s.docEngine.GetChunk(ctx, indexName, chunkID, []string{datasetID}) + if gerr != nil || existing == nil { + return fmt.Errorf("Index updating failure") + } + condition := map[string]interface{}{"id": chunkID} + update := map[string]interface{}{"available_int": availInt} + if err := s.docEngine.UpdateChunks(ctx, condition, update, indexName, datasetID); err != nil { + common.Warn("SwitchChunks: failed to update chunk", zap.String("chunkID", chunkID), zap.Error(err)) + return fmt.Errorf("Index updating failure") + } + } + return nil +} +