Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 49 additions & 3 deletions src/locus/memory/backends/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,19 @@ def _callable_has_parameter(fn: Any, parameter_name: str) -> bool:
return parameter_name in sig.parameters


def _async_backend_op(backend: Any, name: str) -> Any | None:
"""Return ``backend.<name>`` only when it is a real async method.

Used to detect optional backend capabilities (e.g. the atomic checkpoint
index, #301). Guards against ``MagicMock`` test doubles, whose attribute
access auto-creates truthy children that are not coroutine functions.
"""
import inspect

op = getattr(backend, name, None)
return op if inspect.iscoroutinefunction(op) else None


class StorageBackendAdapter(BaseCheckpointer):
"""
Adapter that wraps simple storage backends to implement BaseCheckpointer.
Expand Down Expand Up @@ -166,8 +179,23 @@ async def _update_checkpoint_index(
"""Update the persistent checkpoint index."""
index_key = f"{thread_id}:_checkpoints"

# If the backend offers an atomic index append, use it: the append is
# serialized in the store itself, so it is safe across processes — the
# per-instance lock below only covers a single process (see #301).
index_add = _async_backend_op(self._backend, "index_add")
if index_add is not None:
await index_add(
index_key,
{
"checkpoint_id": checkpoint_id,
"timestamp": timestamp.isoformat(),
"metadata": metadata or {},
},
)
return

# Serialize the read-modify-write so concurrent saves to the same
# thread cannot clobber each other's index entries.
# thread cannot clobber each other's index entries (in-process only).
async with self._index_lock(thread_id):
# Load existing index
existing = await self._backend.load(index_key)
Expand Down Expand Up @@ -237,8 +265,20 @@ async def list_checkpoints(
if existing is None:
return []

checkpoints = existing.get("checkpoints", [])
return [cp.get("checkpoint_id") for cp in checkpoints[:limit] if cp.get("checkpoint_id")]
# De-duplicate by checkpoint_id, preserving order (newest first). The
# atomic-append index path (index_add) does not de-dup on write, so a
# re-saved checkpoint_id can appear more than once; keep the first.
seen: set[str] = set()
result: list[str] = []
for cp in existing.get("checkpoints", []):
cid = cp.get("checkpoint_id")
if not cid or cid in seen:
continue
seen.add(cid)
result.append(cid)
if len(result) >= limit:
break
return result

async def delete(
self,
Expand Down Expand Up @@ -287,6 +327,12 @@ async def _remove_from_index(
"""Remove checkpoint from index."""
index_key = f"{thread_id}:_checkpoints"

# Prefer the backend's atomic remove (cross-process safe). See #301.
index_remove = _async_backend_op(self._backend, "index_remove")
if index_remove is not None:
await index_remove(index_key, checkpoint_id)
return

async with self._index_lock(thread_id):
existing = await self._backend.load(index_key)
if existing is None:
Expand Down
75 changes: 75 additions & 0 deletions src/locus/memory/backends/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,81 @@ async def save(

return checkpoint_id

async def index_add(self, index_key: str, entry: dict[str, Any]) -> None:
    """Append an entry to a checkpoint-index row in a single statement.

    Cross-process safe replacement for the adapter's load-modify-save of
    the ``{thread}:_checkpoints`` blob (see #301). The whole append is one
    ``INSERT ... ON DUPLICATE KEY UPDATE``, so there is no separate read
    step for concurrent writers to race on; serialization is left to the
    store (InnoDB's row lock). The row keeps the adapter's
    ``{"checkpoints": [...]}`` shape and new entries are inserted at
    ``$.checkpoints[0]``, i.e. newest first.

    No de-duplication happens on write: re-adding an existing
    ``checkpoint_id`` inserts a second entry.

    Args:
        index_key: Key of the index row (stored in the ``thread_id`` column).
        entry: JSON-serializable index entry; its ``checkpoint_id`` value is
            also written to the row's ``checkpoint_id`` column on first insert.

    Raises:
        Exception: Whatever the driver raises; the transaction is rolled
            back before the error is re-raised.
    """
    await self._ensure_table()
    pool = await self._get_pool()

    now = datetime.now(UTC).replace(tzinfo=None)
    entry_json = json.dumps(entry)

    async with pool.acquire() as conn:
        try:
            async with await conn.cursor() as cur:
                # ``updated_at`` is bound as an explicit parameter in the
                # UPDATE branch rather than via VALUES(updated_at), which
                # MySQL deprecates from 8.0.20 onwards.
                await cur.execute(
                    f"""
                    INSERT INTO {self._quoted_table_name}
                        (thread_id, checkpoint_id, data, created_at, updated_at, metadata)
                    VALUES (%s, %s, JSON_OBJECT('checkpoints', JSON_ARRAY(CAST(%s AS JSON))),
                            %s, %s, %s)
                    ON DUPLICATE KEY UPDATE
                        data = JSON_ARRAY_INSERT(data, '$.checkpoints[0]', CAST(%s AS JSON)),
                        updated_at = %s
                    """,
                    (
                        index_key,
                        entry.get("checkpoint_id"),
                        entry_json,
                        now,
                        now,
                        json.dumps({}),
                        entry_json,
                        now,
                    ),
                )
                await conn.commit()
        except Exception:
            # Do not hand a connection with an open transaction back to the pool.
            await conn.rollback()
            raise

async def index_remove(self, index_key: str, checkpoint_id: str) -> None:
    """Remove a checkpoint from an index row under a row lock.

    Companion to :meth:`index_add`. The index row is read with
    ``SELECT ... FOR UPDATE``, filtered in Python, and written back in the
    same transaction, so the read-modify-write is serialized across
    processes by the row lock.
    NOTE(review): this relies on the connection not running in autocommit
    mode, otherwise the lock is released right after the SELECT — confirm
    against the pool configuration.

    A missing row is a no-op; the transaction is still committed so the
    locking read is closed out. All entries whose ``checkpoint_id`` matches
    are dropped, which also clears duplicates.

    Args:
        index_key: Key of the index row (stored in the ``thread_id`` column).
        checkpoint_id: Identifier of the entry to remove.

    Raises:
        Exception: Whatever the driver or JSON decoding raises; the
            transaction is rolled back first so the row lock taken by the
            locking read is not left held on a pooled connection.
    """
    await self._ensure_table()
    pool = await self._get_pool()

    now = datetime.now(UTC).replace(tzinfo=None)

    async with pool.acquire() as conn:
        try:
            async with await conn.cursor() as cur:
                await cur.execute(
                    f"SELECT data FROM {self._quoted_table_name} WHERE thread_id = %s FOR UPDATE",
                    (index_key,),
                )
                row = await cur.fetchone()
                if row is not None:
                    existing = _decode_json(row[0]) or {"checkpoints": []}
                    existing["checkpoints"] = [
                        cp
                        for cp in existing.get("checkpoints", [])
                        if cp.get("checkpoint_id") != checkpoint_id
                    ]
                    await cur.execute(
                        f"UPDATE {self._quoted_table_name} SET data = %s, updated_at = %s "
                        "WHERE thread_id = %s",
                        (json.dumps(existing), now, index_key),
                    )
                # Commit in both cases: it ends the transaction opened by the
                # locking read and thereby releases the row lock.
                await conn.commit()
        except Exception:
            # Without this, a failure after SELECT ... FOR UPDATE would leave
            # the row locked until the connection is reset or closed.
            await conn.rollback()
            raise

async def load(self, thread_id: str) -> dict[str, Any] | None:
"""Load checkpoint from MySQL."""
await self._ensure_table()
Expand Down
28 changes: 28 additions & 0 deletions tests/integration/test_checkpointer_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,34 @@ async def save_checkpoint(i: int) -> str:
await self._drop_adapter_table(adapter)
await adapter.close()

@pytest.mark.asyncio
async def test_mysql_adapter_cross_process_index_no_loss(self, sample_state):
    """Concurrent saves through two adapters must all reach the index (#301).

    Two adapter instances over the same table stand in for two processes:
    each has its own in-process index lock, so those locks cannot protect
    against the other instance. Every checkpoint must still be listed,
    which requires the MySQL backend's atomic ``index_add`` to serialize
    the appends in the store itself.
    """
    table = f"test_adapter_xproc_{uuid4().hex[:12]}"
    first = self._mysql_adapter(table)
    second = self._mysql_adapter(table)  # separate instance => separate lock
    thread_id = "mysql-xproc-thread"
    total = 20

    async def save_checkpoint(adapter, i: int) -> str:
        state = sample_state.model_copy(update={"agent_id": f"agent-{i}"})
        return await adapter.save(state, thread_id, checkpoint_id=f"cp-{i}")

    try:
        # Even indices go through the first adapter, odd ones through the second.
        pending = [
            save_checkpoint(first if i % 2 == 0 else second, i) for i in range(total)
        ]
        await asyncio.gather(*pending)

        listed = await first.list_checkpoints(thread_id, limit=1000)
        expected = [f"cp-{i}" for i in range(total)]
        assert sorted(listed) == sorted(expected)
    finally:
        await self._drop_adapter_table(first)
        await first.close()
        await second.close()


# =============================================================================
# OpenSearch Backend Tests (requires OpenSearch)
Expand Down
71 changes: 71 additions & 0 deletions tests/unit/test_memory_backends_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,77 @@ def test_uninspectable_callable_returns_false(self) -> None:
assert _callable_has_parameter(42, "metadata") is False


class _AtomicIndexBackend:
"""Backend that advertises atomic index ops (like MySQLBackend, #301)."""

def __init__(self) -> None:
self.store: dict[str, dict[str, Any]] = {}
self.index_added: list[tuple[str, dict[str, Any]]] = []
self.index_removed: list[tuple[str, str]] = []
# Sentinel so we can prove the in-process RMW lock path was NOT used.
self.saved_index = False

async def save(self, key: str, data: dict[str, Any]) -> None:
if key.endswith(":_checkpoints"):
self.saved_index = True
self.store[key] = data

async def load(self, key: str) -> dict[str, Any] | None:
return self.store.get(key)

async def delete(self, key: str) -> bool:
return self.store.pop(key, None) is not None

async def exists(self, key: str) -> bool:
return key in self.store

async def index_add(self, index_key: str, entry: dict[str, Any]) -> None:
self.index_added.append((index_key, entry))

async def index_remove(self, index_key: str, checkpoint_id: str) -> None:
self.index_removed.append((index_key, checkpoint_id))


class TestAtomicIndexDelegation:
    """The adapter must route index updates to a backend's atomic ops (#301)."""

    @pytest.mark.asyncio
    async def test_save_delegates_to_backend_index_add(self) -> None:
        backend = _AtomicIndexBackend()
        adapter = StorageBackendAdapter(backend)
        await adapter.save(_FakeState(), thread_id="t1", checkpoint_id="cp-1")

        # Exactly one append, addressed to the thread's index key. Unpack the
        # recorded call instead of comparing the entry with itself.
        assert len(backend.index_added) == 1
        index_key, entry = backend.index_added[0]
        assert index_key == "t1:_checkpoints"
        assert entry["checkpoint_id"] == "cp-1"
        assert {"timestamp", "metadata"} <= entry.keys()
        # The blob read-modify-write index path must be bypassed entirely.
        assert backend.saved_index is False

    @pytest.mark.asyncio
    async def test_delete_delegates_to_backend_index_remove(self) -> None:
        backend = _AtomicIndexBackend()
        adapter = StorageBackendAdapter(backend)
        await adapter.save(_FakeState(), thread_id="t1", checkpoint_id="cp-1")

        await adapter.delete("t1", checkpoint_id="cp-1")
        assert backend.index_removed == [("t1:_checkpoints", "cp-1")]

    @pytest.mark.asyncio
    async def test_magic_mock_attr_is_not_mistaken_for_capability(self) -> None:
        # getattr on a MagicMock auto-creates truthy children; the adapter must
        # only treat real async methods as the atomic-index capability.
        backend = MagicMock()
        backend.save = AsyncMock()
        backend.load = AsyncMock(return_value=None)
        adapter = StorageBackendAdapter(backend)

        await adapter.save(_FakeState(), thread_id="t1", checkpoint_id="cp-1")

        # Fell back to the blob path: it saved the index via save(), not index_add().
        assert any(
            call.args and call.args[0] == "t1:_checkpoints"
            for call in backend.save.await_args_list
        )
        # The auto-created child mock must never have been invoked.
        backend.index_add.assert_not_called()


class TestUpdateCheckpointIndex:
@pytest.mark.asyncio
async def test_existing_index_replaces_duplicate_and_sorts_newest_first(self) -> None:
Expand Down
61 changes: 61 additions & 0 deletions tests/unit/test_memory_backends_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,67 @@ async def test_exists_false(self, monkeypatch: pytest.MonkeyPatch) -> None:
assert await backend.exists("t") is False


# ---------------------------------------------------------------------------
# Atomic checkpoint index (#301)
# ---------------------------------------------------------------------------


class TestIndexOps:
    """MySQLBackend atomic index operations, run against the stubbed connector."""

    @pytest.mark.asyncio
    async def test_index_add_uses_atomic_json_array_insert(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        cursor = _StubCursor(fetchone=(1,))
        _stub_mysql_connector(monkeypatch, _StubConn(cursor))
        backend = MySQLBackend()
        entry = {"checkpoint_id": "cp-1", "timestamp": "2026-01-01T00:00:00", "metadata": {}}

        await backend.index_add("t:_checkpoints", entry)

        statement, params = cursor.execute_calls[-1]
        # Single atomic statement (no separate read), newest-first append.
        for fragment in ("ON DUPLICATE KEY UPDATE", "JSON_ARRAY_INSERT", "$.checkpoints[0]"):
            assert fragment in statement
        assert params[0] == "t:_checkpoints"
        assert json.loads(params[2]) == entry

    @pytest.mark.asyncio
    async def test_index_remove_uses_select_for_update(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        stored_index = {"checkpoints": [{"checkpoint_id": "cp-1"}, {"checkpoint_id": "cp-2"}]}
        cursor = _StubCursor(fetchone=(json.dumps(stored_index),))
        _stub_mysql_connector(monkeypatch, _StubConn(cursor))
        backend = MySQLBackend()

        await backend.index_remove("t:_checkpoints", "cp-1")

        # The last two statements are the locking read and the write-back.
        select_call, update_call = cursor.execute_calls[-2], cursor.execute_calls[-1]
        select_sql = select_call[0]
        update_sql, update_args = update_call
        assert "FOR UPDATE" in select_sql
        assert update_sql.strip().startswith("UPDATE")
        # cp-1 removed, cp-2 kept.
        assert json.loads(update_args[0]) == {"checkpoints": [{"checkpoint_id": "cp-2"}]}

    @pytest.mark.asyncio
    async def test_index_remove_is_noop_when_row_missing(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        cursor = _StubCursor(fetchone=None)
        conn = _StubConn(cursor)
        _stub_mysql_connector(monkeypatch, conn)
        backend = MySQLBackend()
        backend._initialized = True  # skip CREATE TABLE so we observe only index_remove

        await backend.index_remove("t:_checkpoints", "cp-1")

        # Only the SELECT ... FOR UPDATE ran; no UPDATE issued, but still commit
        # to release the row lock.
        executed = [sql.strip() for sql, _ in cursor.execute_calls]
        assert all(not sql.startswith("UPDATE") for sql in executed)
        assert conn.commits == 1


# ---------------------------------------------------------------------------
# Query operations
# ---------------------------------------------------------------------------
Expand Down