Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mem0-ts/src/oss/src/embeddings/lmstudio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ export class LMStudioEmbedder implements Embedder {
input: normalized,
encoding_format: "float",
});
return response.data.map((item) => item.embedding);
return response.data
.sort((a, b) => a.index - b.index)
.map((item) => item.embedding);
} catch (err) {
const message = err instanceof Error ? err.message : String(err);
throw new Error(`LM Studio embedder failed: ${message}`);
Expand Down
39 changes: 27 additions & 12 deletions mem0-ts/src/oss/src/memory/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1169,25 +1169,42 @@ export class Memory {

if (deduped.length > 0) {
const entityStore = await this.getEntityStore();
const entitySearchFilters: Record<string, any> = {};
for (const k of ["user_id", "agent_id", "run_id"] as const) {
if (effectiveFilters[k])
entitySearchFilters[k] = effectiveFilters[k];
}
const entityTexts = deduped.map((e) => e.text);
const embeddings = await this.embedder.embedBatch(entityTexts);

for (const entity of deduped) {
try {
const entityEmbedding = await this.embedder.embed(entity.text);
const matches = await entityStore.search(
entityEmbedding,
500,
effectiveFilters,
);
if (embeddings.length !== entityTexts.length) {
console.warn(
`embedBatch returned ${embeddings.length} vectors for ${entityTexts.length} texts — skipping entity boost`,
);
} else {
const searchResults = await Promise.allSettled(
deduped.map((_, i) =>
entityStore.search(embeddings[i], 500, entitySearchFilters),
),
);

for (const result of searchResults) {
if (result.status === "rejected") {
console.warn(
"Entity boost search failed for one entity:",
result.reason,
);
continue;
}

for (const match of matches) {
for (const match of result.value) {
const similarity = match.score ?? 0;
if (similarity < 0.5) continue;

const payload = match.payload || {};
const linkedMemoryIds = payload.linkedMemoryIds ?? [];
if (!Array.isArray(linkedMemoryIds)) continue;

// Spread-attenuated boost
const numLinked = Math.max(linkedMemoryIds.length, 1);
const memoryCountWeight =
1.0 / (1.0 + 0.001 * (numLinked - 1) ** 2);
Expand All @@ -1204,8 +1221,6 @@ export class Memory {
}
}
}
} catch (e) {
// Individual entity boost failed — continue
}
}
}
Expand Down
276 changes: 276 additions & 0 deletions mem0-ts/src/oss/tests/memory.entity-boost.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
/**
* Entity boost parallelism tests (#5214).
*
* Verifies that entity boost searches run concurrently via Promise.allSettled,
* scoring is preserved, and individual entity failures don't abort others.
*/
/// <reference types="jest" />
import { Memory } from "../src/memory";
import { ENTITY_BOOST_WEIGHT } from "../src/utils/scoring";
import type { VectorStoreResult } from "../src/types";

// Raise the per-test timeout to 15 seconds for every test in this file.
jest.setTimeout(15000);

// Replace the Google embedder and LLM modules with bare jest.fn() stand-ins.
jest.mock("../src/embeddings/google", () => ({
GoogleEmbedder: jest.fn(),
}));
jest.mock("../src/llms/google", () => ({
GoogleLLM: jest.fn(),
}));

// Replace the OpenAI LLM with a stub whose generateResponse always resolves to
// the same JSON string: a single memory entry with text "fact".
jest.mock("../src/llms/openai", () => ({
OpenAILLM: jest.fn().mockImplementation(() => ({
generateResponse: jest.fn().mockResolvedValue(
JSON.stringify({
memory: [{ id: "0", text: "fact", attributed_to: "user" }],
}),
),
})),
}));

// Shared 1536-element vector (every component 0.1) returned by all embedder
// stubs. NOTE(review): the `mock` name prefix is presumably what lets the
// hoisted jest.mock factory below reference this variable — confirm before
// renaming it.
const mockEmbedding = new Array(1536).fill(0.1);
// Replace the OpenAI embedder with a stub: embed() resolves to mockEmbedding,
// embedBatch() resolves to one mockEmbedding per input text, and
// embeddingDims matches the vector length.
jest.mock("../src/embeddings/openai", () => ({
OpenAIEmbedder: jest.fn().mockImplementation(() => ({
embed: jest.fn().mockResolvedValue(mockEmbedding),
embedBatch: jest
.fn()
.mockImplementation((texts: string[]) =>
Promise.resolve(texts.map(() => mockEmbedding)),
),
embeddingDims: 1536,
})),
}));

/**
 * Builds a fake entity-store search hit for the tests in this file.
 *
 * @param id - Identifier placed on the hit.
 * @param score - Score placed on the hit.
 * @param linkedMemoryIds - Memory ids stored under `payload.linkedMemoryIds`;
 *   the array is stored by reference, not copied.
 * @returns An object with `id`, `score` and a `payload` holding the linked ids.
 */
function makeMatch(
  id: string,
  score: number,
  linkedMemoryIds: string[],
): VectorStoreResult {
  const payload = { linkedMemoryIds };
  return { id, score, payload };
}

/**
 * Constructs a fresh `Memory` instance for a single test.
 *
 * The embedder and LLM both use the "openai" provider with a placeholder API
 * key, the vector store uses the "memory" provider, and both database paths
 * are ":memory:". The collection name combines the current timestamp with a
 * random number, so each call produces a different name.
 *
 * @returns The newly constructed `Memory`.
 */
function createMemory(): Memory {
  const apiKey = "test-key";
  const inMemoryPath = ":memory:";
  const dimension = 1536;
  const collectionName = `test-entity-${Date.now()}-${Math.random()}`;

  return new Memory({
    version: "v1.1",
    embedder: {
      provider: "openai",
      config: { apiKey, model: "text-embedding-3-small" },
    },
    vectorStore: {
      provider: "memory",
      config: { collectionName, dimension, dbPath: inMemoryPath },
    },
    llm: {
      provider: "openai",
      config: { apiKey, model: "gpt-5-mini" },
    },
    historyDbPath: inMemoryPath,
  });
}

describe("Entity boost parallelism (#5214)", () => {
  let memory: Memory;

  /**
   * Builds an embedder stub: `embed` resolves to the shared mock vector and
   * `embedBatch` resolves to one mock vector per input text.
   */
  function makeEmbedderStub() {
    return {
      embed: jest.fn().mockResolvedValue(mockEmbedding),
      embedBatch: jest
        .fn()
        .mockImplementation((texts: string[]) =>
          Promise.resolve(texts.map(() => mockEmbedding)),
        ),
    };
  }

  /**
   * Awaits `m._ensureInitialized()` and then overwrites the instance's
   * internals with test doubles.
   *
   * @param m - The `Memory` under test, cast to `any` to reach its internals.
   * @param entityStore - Stub assigned to `m._entityStore`.
   * @param semanticHits - Value that the stubbed `m.vectorStore.search`
   *   resolves to; `m.vectorStore.keywordSearch` is stubbed to resolve to null.
   */
  async function installStubs(
    m: any,
    entityStore: unknown,
    semanticHits: unknown[],
  ): Promise<void> {
    await m._ensureInitialized();
    m._entityStore = entityStore;
    m.embedder = makeEmbedderStub();
    m.vectorStore.search = jest.fn().mockResolvedValue(semanticHits);
    m.vectorStore.keywordSearch = jest.fn().mockResolvedValue(null);
  }

  beforeEach(() => {
    memory = createMemory();
  });

  afterEach(async () => {
    await memory.reset();
  });

  it("should use Promise.allSettled for concurrent entity searches", async () => {
    const m = memory as any;
    const mockEntityStore = {
      search: jest.fn().mockResolvedValue([makeMatch("e1", 0.9, ["mem-1"])]),
      initialize: jest.fn().mockResolvedValue(undefined),
    };
    await installStubs(m, mockEntityStore, [
      { id: "mem-1", score: 0.8, payload: { data: "test" } },
    ]);

    // Install the spy only now, so calls made during setup cannot satisfy the
    // assertion; only calls made while search() runs are observed.
    const allSettledSpy = jest.spyOn(Promise, "allSettled");
    try {
      await m.search("alice and bob", { filters: { user_id: "u1" } });

      expect(allSettledSpy).toHaveBeenCalled();
    } finally {
      // Always restore, even when the assertion throws, so a failure here
      // does not leave Promise.allSettled patched for the remaining tests.
      allSettledSpy.mockRestore();
    }
  });

  it("should preserve scoring math with parallel execution", async () => {
    const m = memory as any;

    // Two entities: "alice" links to mem-1, "bob" links to mem-1 and mem-2.
    // All embeddings are identical mocks, so the two searches are told apart
    // by call order instead of by embedding content.
    let searchCalls = 0;
    const mockEntityStore = {
      search: jest.fn().mockImplementation(() => {
        searchCalls += 1;
        if (searchCalls <= 1) {
          // First entity: "alice"
          return Promise.resolve([makeMatch("e-alice", 0.9, ["mem-1"])]);
        }
        // Second entity: "bob"
        return Promise.resolve([makeMatch("e-bob", 0.6, ["mem-1", "mem-2"])]);
      }),
      initialize: jest.fn().mockResolvedValue(undefined),
    };

    // Semantic results include mem-1 and mem-2
    await installStubs(m, mockEntityStore, [
      { id: "mem-1", score: 0.85, payload: { data: "alice memory" } },
      { id: "mem-2", score: 0.75, payload: { data: "bob memory" } },
    ]);

    const result = await m.search("alice and bob", {
      filters: { user_id: "u1" },
    });

    // Entity store was called
    expect(mockEntityStore.search).toHaveBeenCalled();

    // Results should exist and have positive numeric scores
    expect(result.results.length).toBeGreaterThan(0);
    for (const item of result.results) {
      expect(typeof item.score).toBe("number");
      expect(item.score).toBeGreaterThan(0);
    }
  });

  it("should survive one entity search failure without losing other boosts", async () => {
    const m = memory as any;

    // The first entity search rejects; every later one succeeds.
    let callIndex = 0;
    const mockEntityStore = {
      search: jest.fn().mockImplementation(() => {
        callIndex++;
        if (callIndex === 1) {
          return Promise.reject(new Error("provider timeout"));
        }
        return Promise.resolve([makeMatch("e-ok", 0.8, ["mem-9"])]);
      }),
      initialize: jest.fn().mockResolvedValue(undefined),
    };
    await installStubs(m, mockEntityStore, [
      { id: "mem-9", score: 0.85, payload: { data: "surviving memory" } },
    ]);

    const warnSpy = jest.spyOn(console, "warn").mockImplementation(() => {});
    try {
      // "John Smith met Jane Doe" extracts two proper entities
      const result = await m.search("John Smith met Jane Doe", {
        filters: { user_id: "u1" },
      });

      expect(result.results.length).toBeGreaterThan(0);
      expect(result.results[0].id).toBe("mem-9");
      // Should log the failure like Python does
      expect(warnSpy).toHaveBeenCalledWith(
        "Entity boost search failed for one entity:",
        expect.any(Error),
      );
    } finally {
      // Restore console.warn even when an assertion above fails, so later
      // tests do not run with warnings silenced.
      warnSpy.mockRestore();
    }
  });

  it("should call entity searches concurrently, not sequentially", async () => {
    const m = memory as any;

    // Tracks how many stubbed searches are in flight at once.
    const concurrency = { current: 0, peak: 0 };

    const mockEntityStore = {
      search: jest.fn().mockImplementation(() => {
        concurrency.current++;
        concurrency.peak = Math.max(concurrency.peak, concurrency.current);
        return new Promise<VectorStoreResult[]>((resolve) => {
          setTimeout(() => {
            concurrency.current--;
            resolve([makeMatch("e1", 0.7, ["mem-1"])]);
          }, 100);
        });
      }),
      initialize: jest.fn().mockResolvedValue(undefined),
    };
    await installStubs(m, mockEntityStore, [
      { id: "mem-1", score: 0.8, payload: { data: "test" } },
    ]);

    const start = performance.now();
    await m.search("entity1 and entity2 and entity3 and entity4", {
      filters: { user_id: "u1" },
    });
    const elapsed = performance.now() - start;

    // Coarse sanity bound only: four sequential 100ms searches would take
    // about 400ms, which also fits under this limit, so the timing cannot by
    // itself tell parallel from sequential execution.
    expect(elapsed).toBeLessThan(500);
    // The overlap counter is the real check: a sequential loop never has
    // more than one search in flight, so a peak of 2+ proves concurrency.
    expect(concurrency.peak).toBeGreaterThanOrEqual(2);
  });
});
Loading
Loading