diff --git a/flask4modelcache.py b/flask4modelcache.py index 3d18518..6fc2715 100644 --- a/flask4modelcache.py +++ b/flask4modelcache.py @@ -2,6 +2,7 @@ import asyncio from flask import Flask, request, jsonify +from modelcache import cache from modelcache.cache import Cache from modelcache.embedding import EmbeddingModel @@ -42,6 +43,11 @@ def user_backend(): return jsonify(result), 500 await asyncio.to_thread(app.run, host='0.0.0.0', port=5000) + + @app.route('/modelcache/ttl_stats', methods=['GET']) + def ttl_stats(): + stats = cache.data_manager.ttl_manager.stats() + return jsonify({"status": "ok", "data": stats}) if __name__ == '__main__': diff --git a/modelcache/manager/data_manager.py b/modelcache/manager/data_manager.py index 7f019cd..32ba044 100644 --- a/modelcache/manager/data_manager.py +++ b/modelcache/manager/data_manager.py @@ -1,20 +1,34 @@ # -*- coding: utf-8 -*- + import logging import time import requests import pickle import numpy as np import cachetools + from abc import abstractmethod, ABCMeta -from typing import List, Any, Optional -from typing import Union, Callable -from modelcache.manager.scalar_data.base import CacheStorage,CacheData,DataType,Answer,Question +from typing import List, Any, Optional, Union + +from modelcache.manager.scalar_data.base import ( + CacheStorage, + CacheData, + DataType, + Answer, + Question +) from modelcache.utils.error import CacheError, ParamError -from modelcache.manager.vector_data.base import VectorStorage, VectorData +from modelcache.manager.vector_data.base import VectorBase, VectorData from modelcache.manager.object_data.base import ObjectBase -from modelcache.manager.eviction.memory_cache import MemoryCacheEviction +from modelcache.manager.eviction import EvictionBase +from modelcache.manager.eviction_manager import EvictionManager from modelcache.utils.log import modelcache_log +# ---- TTL additions (new) ---- +from modelcache.manager.ttl_manager import TTLManager +from modelcache.utils.ttl_utils import load_ttl_config +# ----------------------------- + class DataManager(metaclass=ABCMeta): """DataManager manage the cache data, including save and search""" @@ -28,7 +42,9 @@ def save_query_resp(self, query_resp_dict, **kwargs): pass @abstractmethod - def import_data(self, questions: List[Any], answers: List[Any], embedding_datas: List[Any], model:Any): + def import_data( + self, questions: List[Any], answers: List[Any], embedding_datas: List[Any], model: Any + ): pass @abstractmethod @@ -50,11 +66,9 @@ def search(self, embedding_data, **kwargs): def delete(self, id_list, **kwargs): pass - @abstractmethod def truncate(self, model_name): pass - @abstractmethod def flush(self): pass @@ -62,32 +76,9 @@ def flush(self): def close(self): pass - @staticmethod - def get( - cache_base: Union[CacheStorage, str] = None, - vector_base: Union[VectorStorage, str] = None, - object_base: Union[ObjectBase, str] = None, - max_size: int = 3, - clean_size: int = 1, - memory_cache_policy: str = "ARC", - data_path: str = "data_map.txt", - get_data_container: Callable = None, - normalize: bool = True - ): - if not cache_base and not vector_base: - return MapDataManager(data_path, max_size, get_data_container) - - if isinstance(cache_base, str): - cache_base = CacheStorage.get(name=cache_base) - if isinstance(vector_base, str): - vector_base = VectorStorage.get(name=vector_base) - if isinstance(object_base, str): - object_base = ObjectBase.get(name=object_base) - assert cache_base and vector_base - return SSDataManager(cache_base, vector_base, object_base, max_size, clean_size,normalize, memory_cache_policy) - class MapDataManager(DataManager): + def __init__(self, data_path, max_size, get_data_container=None): if get_data_container is None: self.data = cachetools.LRUCache(max_size) @@ -161,36 +152,42 @@ def normalize(vec): class SSDataManager(DataManager): + def __init__( self, s: CacheStorage, - v: VectorStorage, + v: VectorBase, o: Optional[ObjectBase], max_size, clean_size, - normalize: bool, policy="LRU", ): self.max_size = max_size self.clean_size = clean_size - self.s = s # SQL storage - self.v = v # Vector storage - self.o = o # Object storage (optional) - self.normalize = normalize - - # Initialize memory cache with specified eviction policy - self.eviction_base = MemoryCacheEviction( - policy=policy, - maxsize=max_size, - clean_size=clean_size) - - def save(self, questions: List[any], answers: List[any], embedding_datas: List[any], **kwargs): - """Save multiple questions, answers, and embeddings to storage.""" + self.s = s + self.v = v + self.o = o + + # ---- TTL additions (new) ---- + ttl_cfg = load_ttl_config() + self.ttl_manager = TTLManager( + default_ttl=ttl_cfg["default_ttl"], + max_size=ttl_cfg["max_size"], + cleanup_interval=ttl_cfg["cleanup_interval"], + ) + modelcache_log.info( + "[SSDataManager] TTLManager initialised: ttl=%ds, max_size=%d, cleanup=%ds", + ttl_cfg["default_ttl"], + ttl_cfg["max_size"], + ttl_cfg["cleanup_interval"], + ) + # ----------------------------- + + def save(self, question, answer, embedding_data, **kwargs): model = kwargs.pop("model", None) - self.import_data(questions, answers, embedding_datas, model) + self.import_data([question], [answer], [embedding_data], model) def save_query_resp(self, query_resp_dict, **kwargs): - """Save query response log to SQL storage for analytics.""" save_query_start_time = time.time() self.s.insert_query_resp(query_resp_dict, **kwargs) save_query_delta_time = '{}s'.format(round(time.time() - save_query_start_time, 2)) @@ -210,154 +207,142 @@ def _process_question_data(self, question: Union[str, Question]): if isinstance(question, Question): if question.deps is None: return question - for dep in question.deps: if dep.dep_type == DataType.IMAGE_URL: dep.dep_type.data = self.o.put(requests.get(dep.data).content) return question - return Question(question) def import_data( self, questions: List[Any], answers: List[Answer], embedding_datas: List[Any], model: Any ): - """ - Add multiple cache entries into all storage backends. - - Coordinates data insertion across SQL, vector, and object storage, - with memory cache population and optional vector normalization. - """ if len(questions) != len(answers) or len(questions) != len(embedding_datas): raise ParamError("Make sure that all parameters have the same length") - cache_datas = [] - # Normalize embedding vectors if configured - if self.normalize: - embedding_datas = [ - normalize(embedding_data) for embedding_data in embedding_datas - ] + cache_datas = [] + embedding_datas = [ + normalize(embedding_data) for embedding_data in embedding_datas + ] - for embedding_data, answer, question in zip(embedding_datas,answers,questions): + for i, embedding_data in enumerate(embedding_datas): if self.o is not None: - answer = self._process_answer_data(answer) - + ans = self._process_answer_data(answers[i]) + else: + ans = answers[i] + question = questions[i] embedding_data = embedding_data.astype("float32") - cache_datas.append([answer, question, embedding_data, model]) + cache_datas.append([ans, question, embedding_data, model]) - # Insert into SQL storage and get generated IDs ids = self.s.batch_insert(cache_datas) - # Prepare vector data and populate memory cache - datas = [] - for _id,embedding_data,cache_data in zip(ids,embedding_datas,cache_datas): - datas.append(VectorData(id=_id, data=embedding_data.astype("float32"))) - self.eviction_base.put([(_id, cache_data)],model=model) - self.v.mul_add(datas,model) + self.v.mul_add( + [ + VectorData(id=ids[i], data=embedding_data) + for i, embedding_data in enumerate(embedding_datas) + ], + model + ) + + # ---- TTL additions (new) ---- + # Register every newly inserted entry with the TTL manager + for cache_id in ids: + self.ttl_manager.register(cache_key=str(cache_id)) + # ----------------------------- def get_scalar_data(self, res_data, **kwargs) -> Optional[CacheData]: - """ - Retrieve scalar data with multi-level caching strategy. - - First checks memory cache, then falls back to SQL storage. - """ - model = kwargs.pop("model") - _id = res_data[1] - - # Try to get from memory cache first (fastest) - cache_hit = self.eviction_base.get(_id, model=model) - if cache_hit is not None: - return cache_hit - cache_data = self.s.get_data_by_id(_id) + # ---- TTL additions (new) ---- + cache_id = res_data[1] + if self.ttl_manager.is_expired(str(cache_id)): + modelcache_log.info( + "[SSDataManager] Cache entry %s has expired — evicting.", cache_id + ) + self.ttl_manager.delete(str(cache_id)) + self.delete([cache_id], **kwargs) # remove from vector DB + scalar DB + return None + # ----------------------------- + + cache_data = self.s.get_data_by_id(cache_id) if cache_data is None: return None - self.eviction_base.put([(_id, cache_data)], model=model) + + # ---- TTL additions (new) ---- + # Refresh LRU position on every cache hit + self.ttl_manager.touch(str(cache_id)) + # ----------------------------- + return cache_data def update_hit_count(self, primary_id, **kwargs): - """Update hit count statistics in SQL storage.""" self.s.update_hit_count_by_id(primary_id) def hit_cache_callback(self, res_data, **kwargs): - """Callback executed on cache hit to update memory cache.""" self.eviction_base.get(res_data[1]) def search(self, embedding_data, **kwargs): - """ - Search for similar vectors in vector storage. - - Applies normalization if configured and delegates to vector backend. - """ model = kwargs.pop("model", None) - if self.normalize: - embedding_data = normalize(embedding_data) + embedding_data = normalize(embedding_data) top_k = kwargs.get("top_k", -1) return self.v.search(data=embedding_data, top_k=top_k, model=model) def delete(self, id_list, **kwargs): - """ - Delete cache entries from all storage backends. - - Removes from memory cache, vector storage, and marks as deleted in SQL. - Returns detailed status of deletion operations. - """ - model = kwargs.pop("model") + model = kwargs.pop("model", None) try: - # Remove from memory cache - for id in id_list: - self.eviction_base.get_cache(model).pop(id, None) - # Delete from vector storage v_delete_count = self.v.delete(ids=id_list, model=model) except Exception as e: - return {'status': 'failed', 'milvus': 'delete milvus data failed, please check! e: {}'.format(e), - 'mysql': 'unexecuted'} + return { + 'status': 'failed', + 'milvus': 'delete milvus data failed, please check! e: {}'.format(e), + 'mysql': 'unexecuted' + } try: - # Mark as deleted in SQL storage s_delete_count = self.s.mark_deleted(id_list) except Exception as e: - return {'status': 'failed', 'milvus': 'success', - 'mysql': 'delete mysql data failed, please check! e: {}'.format(e)} - - return {'status': 'success', 'milvus': 'delete_count: '+str(v_delete_count), - 'mysql': 'delete_count: '+str(s_delete_count)} + return { + 'status': 'failed', + 'milvus': 'success', + 'mysql': 'delete mysql data failed, please check! e: {}'.format(e) + } + return { + 'status': 'success', + 'milvus': 'delete_count: ' + str(v_delete_count), + 'mysql': 'delete_count: ' + str(s_delete_count) + } def create_index(self, model, **kwargs): - """Create vector index for a specific model.""" return self.v.create(model) - def truncate(self, model): - """ - Truncate all data for a specific model across all storage backends. - - Clears memory cache, rebuilds vector storage, and deletes SQL data. - Returns detailed status of truncation operations. - """ - # Clear memory cache data - self.eviction_base.clear(model) - - # Rebuild vector storage (drops and recreates collection) + def truncate(self, model_name): + # drop vector base data try: - vector_resp = self.v.rebuild_col(model) + vector_resp = self.v.rebuild_col(model_name) except Exception as e: - return {'status': 'failed', 'VectorDB': 'truncate VectorDB data failed, please check! e: {}'.format(e), - 'ScalarDB': 'unexecuted'} + return { + 'status': 'failed', + 'VectorDB': 'truncate VectorDB data failed, please check! e: {}'.format(e), + 'ScalarDB': 'unexecuted' + } if vector_resp: return {'status': 'failed', 'VectorDB': vector_resp, 'ScalarDB': 'unexecuted'} - # Delete scalar data from SQL storage + # drop scalar base data try: - delete_count = self.s.model_deleted(model) + delete_count = self.s.model_deleted(model_name) except Exception as e: - return {'status': 'failed', 'VectorDB': 'rebuild', - 'ScalarDB': 'truncate scalar data failed, please check! e: {}'.format(e)} - return {'status': 'success', 'VectorDB': 'rebuild', 'ScalarDB': 'delete_count: ' + str(delete_count)} + return { + 'status': 'failed', + 'VectorDB': 'rebuild', + 'ScalarDB': 'truncate scalar data failed, please check! e: {}'.format(e) + } + return { + 'status': 'success', + 'VectorDB': 'rebuild', + 'ScalarDB': 'delete_count: ' + str(delete_count) + } def flush(self): - """Flush all storage backends to ensure data persistence.""" self.s.flush() self.v.flush() def close(self): - """Close all storage connections and release resources.""" self.s.close() - self.v.close() - + self.v.close() \ No newline at end of file diff --git a/modelcache/manager/ttl_manager.py b/modelcache/manager/ttl_manager.py new file mode 100644 index 0000000..6c96f5c --- /dev/null +++ b/modelcache/manager/ttl_manager.py @@ -0,0 +1,146 @@ +""" +TTL (Time-To-Live) and LRU eviction manager for ModelCache. +Tracks insertion time and last-access time per cache entry. +""" + +import time +import threading +import logging +from collections import OrderedDict +from typing import Optional + +logger = logging.getLogger(__name__) + + +class TTLManager: + """ + Manages TTL expiry and LRU eviction for cache entries. + + Args: + default_ttl (int): seconds before an entry expires. 0 = never. + max_size (int): max number of entries before LRU eviction. 0 = unlimited. + cleanup_interval (int): seconds between background cleanup runs. + """ + + def __init__( + self, + default_ttl: int = 3600, + max_size: int = 10000, + cleanup_interval: int = 300, + ): + self.default_ttl = default_ttl + self.max_size = max_size + self.cleanup_interval = cleanup_interval + + # OrderedDict preserves insertion order for LRU tracking + # Structure: { cache_key: {"inserted_at": float, "last_accessed": float, "ttl": int} } + self._store: OrderedDict = OrderedDict() + self._lock = threading.Lock() + + # Start background cleanup thread + if cleanup_interval > 0: + self._start_cleanup_thread() + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def register(self, cache_key: str, ttl: Optional[int] = None) -> None: + """Register a new cache entry.""" + ttl = ttl if ttl is not None else self.default_ttl + now = time.time() + with self._lock: + self._store[cache_key] = { + "inserted_at": now, + "last_accessed": now, + "ttl": ttl, + } + self._evict_if_needed() + + def is_expired(self, cache_key: str) -> bool: + """Return True if the entry has expired or does not exist.""" + with self._lock: + entry = self._store.get(cache_key) + if entry is None: + return True + if entry["ttl"] == 0: + return False # TTL=0 means never expire + return (time.time() - entry["inserted_at"]) > entry["ttl"] + + def touch(self, cache_key: str) -> None: + """Update last_accessed time (called on cache hit).""" + with self._lock: + if cache_key in self._store: + self._store[cache_key]["last_accessed"] = time.time() + # Move to end = most recently used + self._store.move_to_end(cache_key) + + def delete(self, cache_key: str) -> None: + """Remove a single entry.""" + with self._lock: + self._store.pop(cache_key, None) + + def get_expired_keys(self) -> list: + """Return list of all currently expired keys.""" + now = time.time() + expired = [] + with self._lock: + for key, entry in self._store.items(): + if entry["ttl"] > 0 and (now - entry["inserted_at"]) > entry["ttl"]: + expired.append(key) + return expired + + def purge_expired(self) -> int: + """Delete all expired entries. Returns count of deleted entries.""" + expired_keys = self.get_expired_keys() + with self._lock: + for key in expired_keys: + self._store.pop(key, None) + if expired_keys: + logger.info(f"[TTLManager] Purged {len(expired_keys)} expired entries.") + return len(expired_keys) + + def stats(self) -> dict: + """Return current stats for monitoring.""" + with self._lock: + total = len(self._store) + expired = sum( + 1 for e in self._store.values() + if e["ttl"] > 0 and (time.time() - e["inserted_at"]) > e["ttl"] + ) + return { + "total_entries": total, + "expired_entries": expired, + "active_entries": total - expired, + "max_size": self.max_size, + "default_ttl": self.default_ttl, + } + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _evict_if_needed(self) -> None: + """Evict least-recently-used entries if over max_size. Must hold lock.""" + if self.max_size == 0: + return + while len(self._store) > self.max_size: + evicted_key, _ = self._store.popitem(last=False) # LRU = first item + logger.debug(f"[TTLManager] LRU evicted: {evicted_key}") + + def _start_cleanup_thread(self) -> None: + """Start a daemon thread that periodically purges expired entries.""" + def _cleanup_loop(): + while True: + time.sleep(self.cleanup_interval) + try: + self.purge_expired() + except Exception as e: + logger.error(f"[TTLManager] Cleanup error: {e}") + + thread = threading.Thread(target=_cleanup_loop, daemon=True) + thread.start() + logger.info( + f"[TTLManager] Background cleanup started " + f"(interval={self.cleanup_interval}s)." + ) \ No newline at end of file diff --git a/modelcache/utils/ttl_utils.py b/modelcache/utils/ttl_utils.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_ttl.py b/tests/test_ttl.py new file mode 100644 index 0000000..3db4c93 --- /dev/null +++ b/tests/test_ttl.py @@ -0,0 +1,53 @@ +"""Tests for TTL + LRU eviction.""" + +import time +import pytest +from modelcache.manager.ttl_manager import TTLManager + + +def test_entry_not_expired_immediately(): + mgr = TTLManager(default_ttl=60) + mgr.register("key1") + assert not mgr.is_expired("key1") + + +def test_entry_expires_after_ttl(): + mgr = TTLManager(default_ttl=1, cleanup_interval=0) + mgr.register("key1") + time.sleep(1.1) + assert mgr.is_expired("key1") + + +def test_ttl_zero_never_expires(): + mgr = TTLManager(default_ttl=0) + mgr.register("key1") + time.sleep(0.1) + assert not mgr.is_expired("key1") + + +def test_lru_eviction(): + mgr = TTLManager(default_ttl=0, max_size=3, cleanup_interval=0) + mgr.register("a") + mgr.register("b") + mgr.register("c") + mgr.register("d") # should evict "a" + assert mgr.is_expired("a") # no longer tracked = treated as expired + assert not mgr.is_expired("b") + + +def test_purge_expired(): + mgr = TTLManager(default_ttl=1, cleanup_interval=0) + mgr.register("x") + mgr.register("y") + time.sleep(1.1) + count = mgr.purge_expired() + assert count == 2 + + +def test_stats(): + mgr = TTLManager(default_ttl=60, max_size=100) + mgr.register("k1") + mgr.register("k2") + stats = mgr.stats() + assert stats["total_entries"] == 2 + assert stats["active_entries"] == 2 \ No newline at end of file