diff --git a/mem0-ts/src/oss/src/embeddings/lmstudio.ts b/mem0-ts/src/oss/src/embeddings/lmstudio.ts index ff33733c48..509a0a0dae 100644 --- a/mem0-ts/src/oss/src/embeddings/lmstudio.ts +++ b/mem0-ts/src/oss/src/embeddings/lmstudio.ts @@ -44,7 +44,9 @@ export class LMStudioEmbedder implements Embedder { input: normalized, encoding_format: "float", }); - return response.data.map((item) => item.embedding); + return response.data + .sort((a, b) => a.index - b.index) + .map((item) => item.embedding); } catch (err) { const message = err instanceof Error ? err.message : String(err); throw new Error(`LM Studio embedder failed: ${message}`); diff --git a/mem0-ts/src/oss/src/memory/index.ts b/mem0-ts/src/oss/src/memory/index.ts index e175e71ff4..ab36593a87 100644 --- a/mem0-ts/src/oss/src/memory/index.ts +++ b/mem0-ts/src/oss/src/memory/index.ts @@ -1169,17 +1169,35 @@ export class Memory { if (deduped.length > 0) { const entityStore = await this.getEntityStore(); + const entitySearchFilters: Record = {}; + for (const k of ["user_id", "agent_id", "run_id"] as const) { + if (effectiveFilters[k]) + entitySearchFilters[k] = effectiveFilters[k]; + } + const entityTexts = deduped.map((e) => e.text); + const embeddings = await this.embedder.embedBatch(entityTexts); - for (const entity of deduped) { - try { - const entityEmbedding = await this.embedder.embed(entity.text); - const matches = await entityStore.search( - entityEmbedding, - 500, - effectiveFilters, - ); + if (embeddings.length !== entityTexts.length) { + console.warn( + `embedBatch returned ${embeddings.length} vectors for ${entityTexts.length} texts — skipping entity boost`, + ); + } else { + const searchResults = await Promise.allSettled( + deduped.map((_, i) => + entityStore.search(embeddings[i], 500, entitySearchFilters), + ), + ); + + for (const result of searchResults) { + if (result.status === "rejected") { + console.warn( + "Entity boost search failed for one entity:", + result.reason, + ); + continue; + } - for (const match of matches) { + for (const match of result.value) { const similarity = match.score ?? 0; if (similarity < 0.5) continue; @@ -1187,7 +1205,6 @@ export class Memory { const linkedMemoryIds = payload.linkedMemoryIds ?? []; if (!Array.isArray(linkedMemoryIds)) continue; - // Spread-attenuated boost const numLinked = Math.max(linkedMemoryIds.length, 1); const memoryCountWeight = 1.0 / (1.0 + 0.001 * (numLinked - 1) ** 2); @@ -1204,8 +1221,6 @@ export class Memory { } } } - } catch (e) { - // Individual entity boost failed — continue } } } diff --git a/mem0-ts/src/oss/tests/memory.entity-boost.test.ts b/mem0-ts/src/oss/tests/memory.entity-boost.test.ts new file mode 100644 index 0000000000..ce10c276a3 --- /dev/null +++ b/mem0-ts/src/oss/tests/memory.entity-boost.test.ts @@ -0,0 +1,276 @@ +/** + * Entity boost parallelism tests (#5214). + * + * Verifies that entity boost searches run concurrently via Promise.allSettled, + * scoring is preserved, and individual entity failures don't abort others. + */ +/// +import { Memory } from "../src/memory"; +import { ENTITY_BOOST_WEIGHT } from "../src/utils/scoring"; +import type { VectorStoreResult } from "../src/types"; + +jest.setTimeout(15000); + +jest.mock("../src/embeddings/google", () => ({ + GoogleEmbedder: jest.fn(), +})); +jest.mock("../src/llms/google", () => ({ + GoogleLLM: jest.fn(), +})); + +jest.mock("../src/llms/openai", () => ({ + OpenAILLM: jest.fn().mockImplementation(() => ({ + generateResponse: jest.fn().mockResolvedValue( + JSON.stringify({ + memory: [{ id: "0", text: "fact", attributed_to: "user" }], + }), + ), + })), +})); + +const mockEmbedding = new Array(1536).fill(0.1); +jest.mock("../src/embeddings/openai", () => ({ + OpenAIEmbedder: jest.fn().mockImplementation(() => ({ + embed: jest.fn().mockResolvedValue(mockEmbedding), + embedBatch: jest + .fn() + .mockImplementation((texts: string[]) => + Promise.resolve(texts.map(() => mockEmbedding)), + ), + embeddingDims: 1536, + })), +})); + +function makeMatch( + id: string, + score: number, + linkedMemoryIds: string[], +): VectorStoreResult { + return { id, score, payload: { linkedMemoryIds } }; +} + +function createMemory(): Memory { + return new Memory({ + version: "v1.1", + embedder: { + provider: "openai", + config: { apiKey: "test-key", model: "text-embedding-3-small" }, + }, + vectorStore: { + provider: "memory", + config: { + collectionName: `test-entity-${Date.now()}-${Math.random()}`, + dimension: 1536, + dbPath: ":memory:", + }, + }, + llm: { + provider: "openai", + config: { apiKey: "test-key", model: "gpt-5-mini" }, + }, + historyDbPath: ":memory:", + }); +} + +describe("Entity boost parallelism (#5214)", () => { + let memory: Memory; + + beforeEach(() => { + memory = createMemory(); + }); + + afterEach(async () => { + await memory.reset(); + }); + + it("should use Promise.allSettled for concurrent entity searches", async () => { + // Spy on Promise.allSettled to confirm it's being used + const allSettledSpy = jest.spyOn(Promise, "allSettled"); + + // Access internals to inject a mock entity store + const m = memory as any; + await m._ensureInitialized(); + + const mockEntityStore = { + search: jest.fn().mockResolvedValue([makeMatch("e1", 0.9, ["mem-1"])]), + initialize: jest.fn().mockResolvedValue(undefined), + }; + m._entityStore = mockEntityStore; + + m.embedder = { + embed: jest.fn().mockResolvedValue(mockEmbedding), + embedBatch: jest + .fn() + .mockImplementation((texts: string[]) => + Promise.resolve(texts.map(() => mockEmbedding)), + ), + }; + + // Mock the vector store to return a semantic result + m.vectorStore.search = jest + .fn() + .mockResolvedValue([ + { id: "mem-1", score: 0.8, payload: { data: "test" } }, + ]); + m.vectorStore.keywordSearch = jest.fn().mockResolvedValue(null); + + await m.search("alice and bob", { filters: { user_id: "u1" } }); + + expect(allSettledSpy).toHaveBeenCalled(); + allSettledSpy.mockRestore(); + }); + + it("should preserve scoring math with parallel execution", async () => { + const m = memory as any; + await m._ensureInitialized(); + + // Two entities: "alice" links to mem-1, "bob" links to mem-1 and mem-2 + // mem-1 should get max(alice_boost, bob_boost) + const mockEntityStore = { + search: jest + .fn() + .mockImplementation( + (_embedding: number[], _topK: number, _filters: any) => { + // We need to differentiate by embedding content — but since all + // embeddings are identical mocks, we'll use call order + const callCount = mockEntityStore.search.mock.calls.length; + if (callCount <= 1) { + // First entity: "alice" + return Promise.resolve([makeMatch("e-alice", 0.9, ["mem-1"])]); + } + // Second entity: "bob" + return Promise.resolve([ + makeMatch("e-bob", 0.6, ["mem-1", "mem-2"]), + ]); + }, + ), + initialize: jest.fn().mockResolvedValue(undefined), + }; + m._entityStore = mockEntityStore; + m.embedder = { + embed: jest.fn().mockResolvedValue(mockEmbedding), + embedBatch: jest + .fn() + .mockImplementation((texts: string[]) => + Promise.resolve(texts.map(() => mockEmbedding)), + ), + }; + + // Semantic results include mem-1 and mem-2 + m.vectorStore.search = jest.fn().mockResolvedValue([ + { id: "mem-1", score: 0.85, payload: { data: "alice memory" } }, + { id: "mem-2", score: 0.75, payload: { data: "bob memory" } }, + ]); + m.vectorStore.keywordSearch = jest.fn().mockResolvedValue(null); + + const result = await m.search("alice and bob", { + filters: { user_id: "u1" }, + }); + + // Entity store was called (parallelized via Promise.allSettled) + expect(mockEntityStore.search).toHaveBeenCalled(); + + // Results should exist and have scores + expect(result.results.length).toBeGreaterThan(0); + for (const item of result.results) { + expect(typeof item.score).toBe("number"); + expect(item.score).toBeGreaterThan(0); + } + }); + + it("should survive one entity search failure without losing other boosts", async () => { + const m = memory as any; + await m._ensureInitialized(); + + let callIndex = 0; + const mockEntityStore = { + search: jest.fn().mockImplementation(() => { + callIndex++; + if (callIndex === 1) { + return Promise.reject(new Error("provider timeout")); + } + return Promise.resolve([makeMatch("e-ok", 0.8, ["mem-9"])]); + }), + initialize: jest.fn().mockResolvedValue(undefined), + }; + m._entityStore = mockEntityStore; + m.embedder = { + embed: jest.fn().mockResolvedValue(mockEmbedding), + embedBatch: jest + .fn() + .mockImplementation((texts: string[]) => + Promise.resolve(texts.map(() => mockEmbedding)), + ), + }; + m.vectorStore.search = jest + .fn() + .mockResolvedValue([ + { id: "mem-9", score: 0.85, payload: { data: "surviving memory" } }, + ]); + m.vectorStore.keywordSearch = jest.fn().mockResolvedValue(null); + + const warnSpy = jest.spyOn(console, "warn").mockImplementation(() => {}); + + // "John Smith met Jane Doe" extracts two proper entities + const result = await m.search("John Smith met Jane Doe", { + filters: { user_id: "u1" }, + }); + + expect(result.results.length).toBeGreaterThan(0); + expect(result.results[0].id).toBe("mem-9"); + // Should log the failure like Python does + expect(warnSpy).toHaveBeenCalledWith( + "Entity boost search failed for one entity:", + expect.any(Error), + ); + warnSpy.mockRestore(); + }); + + it("should call entity searches concurrently, not sequentially", async () => { + const m = memory as any; + await m._ensureInitialized(); + + const concurrency = { current: 0, peak: 0 }; + + const mockEntityStore = { + search: jest.fn().mockImplementation(() => { + concurrency.current++; + concurrency.peak = Math.max(concurrency.peak, concurrency.current); + return new Promise((resolve) => { + setTimeout(() => { + concurrency.current--; + resolve([makeMatch("e1", 0.7, ["mem-1"])]); + }, 100); + }); + }), + initialize: jest.fn().mockResolvedValue(undefined), + }; + m._entityStore = mockEntityStore; + m.embedder = { + embed: jest.fn().mockResolvedValue(mockEmbedding), + embedBatch: jest + .fn() + .mockImplementation((texts: string[]) => + Promise.resolve(texts.map(() => mockEmbedding)), + ), + }; + m.vectorStore.search = jest + .fn() + .mockResolvedValue([ + { id: "mem-1", score: 0.8, payload: { data: "test" } }, + ]); + m.vectorStore.keywordSearch = jest.fn().mockResolvedValue(null); + + const start = performance.now(); + await m.search("entity1 and entity2 and entity3 and entity4", { + filters: { user_id: "u1" }, + }); + const elapsed = performance.now() - start; + + // With 4 entities at 100ms each, sequential would be ~400ms+. + // Parallel should be well under that. Use generous bound for CI. + expect(elapsed).toBeLessThan(500); + // At least 2 searches should have overlapped + expect(concurrency.peak).toBeGreaterThanOrEqual(2); + }); +}); diff --git a/mem0/memory/main.py b/mem0/memory/main.py index ec9dd202e4..5a999cd9a3 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -1,4 +1,5 @@ import asyncio +import concurrent.futures import gc import hashlib import json @@ -1464,34 +1465,55 @@ def _compute_entity_boosts(self, query_entities, filters): memory_boosts = {} try: - for _, entity_text in deduped: - entity_embedding = self.embedding_model.embed(entity_text, "search") - matches = self.entity_store.search( - query=entity_text, - vectors=entity_embedding, - top_k=500, - filters=search_filters, + entity_texts = [text for _, text in deduped] + embeddings = self.embedding_model.embed_batch(entity_texts, "search") + + if len(embeddings) != len(entity_texts): + logger.warning( + "embed_batch returned %d vectors for %d texts — skipping entity boost", + len(embeddings), + len(entity_texts), ) + return memory_boosts - for match in matches: - similarity = match.score if hasattr(match, 'score') else 0.0 - if similarity < 0.5: - continue + entity_store = self.entity_store - payload = match.payload if hasattr(match, 'payload') else {} - linked_memory_ids = payload.get("linked_memory_ids", []) - if not isinstance(linked_memory_ids, list): + def _search_entity(entity_text, embedding): + return entity_store.search( + query=entity_text, vectors=embedding, top_k=500, filters=search_filters + ) + + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool: + futures = { + pool.submit(_search_entity, text, emb): text + for text, emb in zip(entity_texts, embeddings) + } + + for future in concurrent.futures.as_completed(futures): + try: + matches = future.result() + except Exception as e: + logger.warning("Entity boost search failed for one entity: %s", e) continue - # Spread-attenuated boost: entities linking to many memories get attenuated - num_linked = max(len(linked_memory_ids), 1) - memory_count_weight = 1.0 / (1.0 + 0.001 * ((num_linked - 1) ** 2)) - boost = similarity * ENTITY_BOOST_WEIGHT * memory_count_weight + for match in matches: + similarity = match.score if hasattr(match, 'score') else 0.0 + if similarity < 0.5: + continue - for memory_id in linked_memory_ids: - if memory_id: - memory_key = str(memory_id) - memory_boosts[memory_key] = max(memory_boosts.get(memory_key, 0.0), boost) + payload = match.payload if hasattr(match, 'payload') else {} + linked_memory_ids = payload.get("linked_memory_ids", []) + if not isinstance(linked_memory_ids, list): + continue + + num_linked = max(len(linked_memory_ids), 1) + memory_count_weight = 1.0 / (1.0 + 0.001 * ((num_linked - 1) ** 2)) + boost = similarity * ENTITY_BOOST_WEIGHT * memory_count_weight + + for memory_id in linked_memory_ids: + if memory_id: + memory_key = str(memory_id) + memory_boosts[memory_key] = max(memory_boosts.get(memory_key, 0.0), boost) except Exception as e: logger.warning(f"Entity boost computation failed: {e}") @@ -2872,15 +2894,38 @@ async def _compute_entity_boosts_async(self, query_entities, filters): memory_boosts = {} try: - for _, entity_text in deduped: - entity_embedding = await asyncio.to_thread(self.embedding_model.embed, entity_text, "search") - matches = await asyncio.to_thread( - self.entity_store.search, - query=entity_text, - vectors=entity_embedding, - top_k=500, - filters=search_filters, + entity_texts = [text for _, text in deduped] + embeddings = await asyncio.to_thread(self.embedding_model.embed_batch, entity_texts, "search") + + if len(embeddings) != len(entity_texts): + logger.warning( + "embed_batch returned %d vectors for %d texts — skipping entity boost", + len(embeddings), + len(entity_texts), ) + return memory_boosts + + sem = asyncio.Semaphore(4) + + async def _search_entity(entity_text, embedding): + async with sem: + return await asyncio.to_thread( + self.entity_store.search, + query=entity_text, + vectors=embedding, + top_k=500, + filters=search_filters, + ) + + results = await asyncio.gather( + *(_search_entity(text, emb) for text, emb in zip(entity_texts, embeddings)), + return_exceptions=True, + ) + + for matches in results: + if isinstance(matches, BaseException): + logger.warning("Entity boost search failed for one entity: %s", matches) + continue for match in matches: similarity = match.score if hasattr(match, 'score') else 0.0 diff --git a/tests/memory/test_main.py b/tests/memory/test_main.py index 3631162a33..677e7fe4c1 100644 --- a/tests/memory/test_main.py +++ b/tests/memory/test_main.py @@ -1,5 +1,7 @@ import logging +import time from datetime import datetime, timezone +from types import SimpleNamespace from unittest.mock import MagicMock, Mock import pytest @@ -663,3 +665,195 @@ async def test_async_update_preserves_actor_id_when_different_actor_updates(mock assert stored["actor_id"] == "Alice" +def _make_match(score, linked_memory_ids): + return SimpleNamespace(score=score, payload={"linked_memory_ids": linked_memory_ids}) + + +class TestEntityBoostParallelism: + """Tests for parallelized entity boost searches (#5214).""" + + @pytest.fixture + def mock_memory(self, mocker): + _setup_mocks(mocker) + return Memory() + + @pytest.fixture + def mock_async_memory(self, mocker): + _setup_mocks(mocker) + return AsyncMemory() + + def test_sync_boosts_preserve_scoring(self, mock_memory): + from mem0.utils.scoring import ENTITY_BOOST_WEIGHT + + mock_memory.embedding_model = Mock() + mock_memory.embedding_model.embed_batch = Mock(return_value=[[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]) + + results_by_query = { + "alice": [_make_match(0.9, ["mem-1"])], + "bob": [_make_match(0.6, ["mem-1", "mem-2"])], + } + + def fake_search(query, vectors, top_k, filters): + return results_by_query[query] + + mock_memory._entity_store = Mock() + mock_memory._entity_store.search = Mock(side_effect=fake_search) + + boosts = mock_memory._compute_entity_boosts( + [("person", "alice"), ("person", "bob")], + {"user_id": "u1"}, + ) + + boost_alice = 0.9 * ENTITY_BOOST_WEIGHT * (1.0 / (1.0 + 0.001 * (0**2))) + boost_bob = 0.6 * ENTITY_BOOST_WEIGHT * (1.0 / (1.0 + 0.001 * (1**2))) + assert boosts["mem-1"] == pytest.approx(max(boost_alice, boost_bob)) + assert boosts["mem-2"] == pytest.approx(boost_bob) + + def test_sync_embed_batch_called_once(self, mock_memory): + mock_memory.embedding_model = Mock() + mock_memory.embedding_model.embed_batch = Mock(return_value=[[0.1], [0.1], [0.1]]) + mock_memory._entity_store = Mock() + mock_memory._entity_store.search = Mock(return_value=[_make_match(0.7, ["mem-1"])]) + + mock_memory._compute_entity_boosts( + [("person", "alice"), ("person", "bob"), ("person", "carol")], + {"user_id": "u1"}, + ) + + mock_memory.embedding_model.embed_batch.assert_called_once_with(["alice", "bob", "carol"], "search") + + @pytest.mark.asyncio + async def test_async_boosts_preserve_scoring(self, mock_async_memory): + from mem0.utils.scoring import ENTITY_BOOST_WEIGHT + + mock_async_memory.embedding_model = Mock() + mock_async_memory.embedding_model.embed_batch = Mock(return_value=[[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]) + + results_by_query = { + "alice": [_make_match(0.9, ["mem-1"])], + "bob": [_make_match(0.6, ["mem-1", "mem-2"])], + } + + def fake_search(query, vectors, top_k, filters): + return results_by_query[query] + + mock_async_memory._entity_store = Mock() + mock_async_memory._entity_store.search = Mock(side_effect=fake_search) + + boosts = await mock_async_memory._compute_entity_boosts_async( + [("person", "alice"), ("person", "bob")], + {"user_id": "u1"}, + ) + + boost_alice = 0.9 * ENTITY_BOOST_WEIGHT * (1.0 / (1.0 + 0.001 * (0**2))) + boost_bob = 0.6 * ENTITY_BOOST_WEIGHT * (1.0 / (1.0 + 0.001 * (1**2))) + assert boosts["mem-1"] == pytest.approx(max(boost_alice, boost_bob)) + assert boosts["mem-2"] == pytest.approx(boost_bob) + + @pytest.mark.asyncio + async def test_async_embed_batch_called_once(self, mock_async_memory): + mock_async_memory.embedding_model = Mock() + mock_async_memory.embedding_model.embed_batch = Mock(return_value=[[0.1], [0.1], [0.1]]) + mock_async_memory._entity_store = Mock() + mock_async_memory._entity_store.search = Mock(return_value=[_make_match(0.7, ["mem-1"])]) + + await mock_async_memory._compute_entity_boosts_async( + [("person", "alice"), ("person", "bob"), ("person", "carol")], + {"user_id": "u1"}, + ) + + mock_async_memory.embedding_model.embed_batch.assert_called_once_with(["alice", "bob", "carol"], "search") + + def test_sync_one_entity_failure_does_not_abort_others(self, mock_memory, caplog): + mock_memory.embedding_model = Mock() + mock_memory.embedding_model.embed_batch = Mock(return_value=[[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]) + + def fake_search(query, vectors, top_k, filters): + if query == "boom": + raise RuntimeError("provider timeout") + return [_make_match(0.8, ["mem-9"])] + + mock_memory._entity_store = Mock() + mock_memory._entity_store.search = Mock(side_effect=fake_search) + + with caplog.at_level(logging.WARNING): + boosts = mock_memory._compute_entity_boosts( + [("person", "boom"), ("person", "ok")], + {"user_id": "u1"}, + ) + + assert "mem-9" in boosts + assert any("Entity boost search failed" in r.message for r in caplog.records) + + @pytest.mark.asyncio + async def test_async_one_entity_failure_does_not_abort_others(self, mock_async_memory, caplog): + mock_async_memory.embedding_model = Mock() + mock_async_memory.embedding_model.embed_batch = Mock(return_value=[[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]) + + def fake_search(query, vectors, top_k, filters): + if query == "boom": + raise RuntimeError("provider timeout") + return [_make_match(0.8, ["mem-9"])] + + mock_async_memory._entity_store = Mock() + mock_async_memory._entity_store.search = Mock(side_effect=fake_search) + + with caplog.at_level(logging.WARNING): + boosts = await mock_async_memory._compute_entity_boosts_async( + [("person", "boom"), ("person", "ok")], + {"user_id": "u1"}, + ) + + assert "mem-9" in boosts + assert any("Entity boost search failed" in r.message for r in caplog.records) + + def test_sync_searches_run_concurrently(self, mock_memory): + mock_memory.embedding_model = Mock() + mock_memory.embedding_model.embed_batch = Mock(return_value=[[0.1]] * 4) + + concurrent_count = {"current": 0, "peak": 0} + + def blocking_search(query, vectors, top_k, filters): + concurrent_count["current"] += 1 + concurrent_count["peak"] = max(concurrent_count["peak"], concurrent_count["current"]) + time.sleep(0.2) + concurrent_count["current"] -= 1 + return [_make_match(0.7, [f"mem-{query}"])] + + mock_memory._entity_store = Mock() + mock_memory._entity_store.search = Mock(side_effect=blocking_search) + + entities = [("person", f"e{i}") for i in range(4)] + start = time.perf_counter() + boosts = mock_memory._compute_entity_boosts(entities, {"user_id": "u1"}) + elapsed = time.perf_counter() - start + + assert elapsed < 0.6, f"searches did not run concurrently (took {elapsed:.2f}s)" + assert concurrent_count["peak"] >= 2, "no overlap observed between entity searches" + assert len(boosts) == 4 + + @pytest.mark.asyncio + async def test_async_searches_run_concurrently(self, mock_async_memory): + mock_async_memory.embedding_model = Mock() + mock_async_memory.embedding_model.embed_batch = Mock(return_value=[[0.1]] * 4) + + concurrent_count = {"current": 0, "peak": 0} + + def blocking_search(query, vectors, top_k, filters): + concurrent_count["current"] += 1 + concurrent_count["peak"] = max(concurrent_count["peak"], concurrent_count["current"]) + time.sleep(0.2) + concurrent_count["current"] -= 1 + return [_make_match(0.7, [f"mem-{query}"])] + + mock_async_memory._entity_store = Mock() + mock_async_memory._entity_store.search = Mock(side_effect=blocking_search) + + entities = [("person", f"e{i}") for i in range(4)] + start = time.perf_counter() + boosts = await mock_async_memory._compute_entity_boosts_async(entities, {"user_id": "u1"}) + elapsed = time.perf_counter() - start + + assert elapsed < 0.6, f"searches did not run concurrently (took {elapsed:.2f}s)" + assert concurrent_count["peak"] >= 2, "no overlap observed between entity searches" + assert len(boosts) == 4