diff --git a/aworld/models/llm_http_handler.py b/aworld/models/llm_http_handler.py index 9e410fbe6..429418b74 100644 --- a/aworld/models/llm_http_handler.py +++ b/aworld/models/llm_http_handler.py @@ -56,6 +56,39 @@ def __init__( if headers: self.headers.update(headers) + # Shared aiohttp session to prevent memory leaks + # One session per handler instance, not per request + self._session: Optional[Any] = None + + async def _get_session(self): + """Get or create the shared aiohttp session. + + This method ensures we reuse a single session across all requests, + preventing memory leaks from creating/destroying sessions repeatedly. + + Returns: + aiohttp.ClientSession: The shared session instance. + """ + import aiohttp + if self._session is None or self._session.closed: + # Create session with connection pooling + connector = aiohttp.TCPConnector( + limit=100, # Max connections + limit_per_host=30, # Max connections per host + ) + self._session = aiohttp.ClientSession(connector=connector) + return self._session + + async def close(self): + """Close the shared aiohttp session. + + Call this method when the handler is no longer needed to properly + clean up resources. + """ + if self._session and not self._session.closed: + await self._session.close() + self._session = None + def _parse_sse_line(self, line: bytes) -> Optional[Dict[str, Any]]: """Parse a Server-Sent Events (SSE) line. @@ -178,8 +211,8 @@ async def _make_async_request_stream( if headers: request_headers.update(headers) - # Create an independent session and keep it open - session = aiohttp.ClientSession() + # Use the shared session instead of creating a new one + session = await self._get_session() try: response = await session.post( url, @@ -215,9 +248,7 @@ async def _make_async_request_stream( except Exception as e: logger.error(f"Error in stream: {str(e)}") raise - finally: - # Ensure the session is eventually closed - await session.close() + # Note: We don't close the session here as it's shared and reused async def _make_async_request( self, @@ -237,21 +268,21 @@ async def _make_async_request( Raises: aiohttp.ClientError: If the request fails. """ - import aiohttp url = f"{self.base_url}/{endpoint.lstrip('/')}" request_headers = self.headers.copy() if headers: request_headers.update(headers) - async with aiohttp.ClientSession() as session: - async with session.post( - url, - headers=request_headers, - json=data, - timeout=self.timeout, - ) as response: - response.raise_for_status() - return await response.json() + # Use the shared session instead of creating a new one + session = await self._get_session() + async with session.post( + url, + headers=request_headers, + json=data, + timeout=self.timeout, + ) as response: + response.raise_for_status() + return await response.json() def sync_call( self, diff --git a/aworld/models/openai_tokenizer.py b/aworld/models/openai_tokenizer.py index b1319de48..0f60947ba 100644 --- a/aworld/models/openai_tokenizer.py +++ b/aworld/models/openai_tokenizer.py @@ -34,15 +34,29 @@ ENDOFTEXT: 100256, } +# Global cache to prevent memory leaks from repeatedly loading BPE files +_BPE_CACHE = {} + def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]: - """Load tiktoken BPE file similar to qwen_tokenizer.""" + """Load tiktoken BPE file with caching to prevent memory leaks.""" + # Check cache first + if tiktoken_bpe_file in _BPE_CACHE: + return _BPE_CACHE[tiktoken_bpe_file] + + # Load and decode file with open(tiktoken_bpe_file, 'rb') as f: contents = f.read() - return { - base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line) + + result = { + base64.b64decode(token): int(rank) + for token, rank in (line.split() for line in contents.splitlines() if line) } + # Cache the result + _BPE_CACHE[tiktoken_bpe_file] = result + return result + class OpenAITokenizer: """OpenAI tokenizer using local tiktoken file.""" diff --git a/aworld/models/qwen_tokenizer.py b/aworld/models/qwen_tokenizer.py index 73c03afca..1d3775451 100644 --- a/aworld/models/qwen_tokenizer.py +++ b/aworld/models/qwen_tokenizer.py @@ -45,14 +45,29 @@ )) SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS) +# Global cache to prevent memory leaks from repeatedly loading BPE files +_BPE_CACHE = {} + def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]: + """Load tiktoken BPE file with caching to prevent memory leaks.""" + # Check cache first + if tiktoken_bpe_file in _BPE_CACHE: + return _BPE_CACHE[tiktoken_bpe_file] + + # Load and decode file with open(tiktoken_bpe_file, 'rb') as f: contents = f.read() - return { - base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line) + + result = { + base64.b64decode(token): int(rank) + for token, rank in (line.split() for line in contents.splitlines() if line) } + # Cache the result + _BPE_CACHE[tiktoken_bpe_file] = result + return result + class QWenTokenizer: """QWen tokenizer.""" diff --git a/aworld/models/utils.py b/aworld/models/utils.py index aa1678b93..ae98eed22 100644 --- a/aworld/models/utils.py +++ b/aworld/models/utils.py @@ -11,6 +11,34 @@ from aworld.models.openai_tokenizer import openai_tokenizer from aworld.utils import import_package +# Global cache for tiktoken encodings to prevent memory leaks +_TIKTOKEN_ENCODING_CACHE = {} + + +def _get_cached_tiktoken_encoding(model: str): + """ + Get cached tiktoken encoding to prevent memory leaks. + + Args: + model: Model name (e.g., 'gpt-4o', 'claude-3-opus') + + Returns: + Cached tiktoken encoding object + """ + if model not in _TIKTOKEN_ENCODING_CACHE: + import tiktoken + try: + _TIKTOKEN_ENCODING_CACHE[model] = tiktoken.encoding_for_model(model) + logger.debug(f"Created and cached tiktoken encoding for model: {model}") + except KeyError: + logger.debug(f"{model} model not found. Using cl100k_base encoding.") + # Cache cl100k_base if not already cached + if "cl100k_base" not in _TIKTOKEN_ENCODING_CACHE: + _TIKTOKEN_ENCODING_CACHE["cl100k_base"] = tiktoken.get_encoding("cl100k_base") + # Reuse cl100k_base for this model + _TIKTOKEN_ENCODING_CACHE[model] = _TIKTOKEN_ENCODING_CACHE["cl100k_base"] + return _TIKTOKEN_ENCODING_CACHE[model] + class ModelUtils: """Utility class for model-related operations""" @@ -265,37 +293,26 @@ def usage_process(usage: Dict[str, Union[int, Dict[str, int]]] = {}, context: Co def num_tokens_from_string(string: str, model: str = "openai"): """Return the number of tokens used by a list of messages.""" - import tiktoken - if model.lower() == "qwen": encoding = qwen_tokenizer elif model.lower() == "openai": encoding = openai_tokenizer else: - try: - encoding = tiktoken.encoding_for_model(model) - except KeyError: - logger.debug( - f"{model} model not found. Using cl100k_base encoding.") - encoding = tiktoken.get_encoding("cl100k_base") + # Use cached encoding to prevent memory leaks + encoding = _get_cached_tiktoken_encoding(model) return len(encoding.encode(string)) def num_tokens_from_messages(messages, model="openai"): """Return the number of tokens used by a list of messages.""" import_package("tiktoken") - import tiktoken if model.lower() == "qwen": encoding = qwen_tokenizer elif model.lower() == "openai": encoding = openai_tokenizer else: - try: - encoding = tiktoken.encoding_for_model(model) - except KeyError: - logger.warning( - f"{model} model not found. Using cl100k_base encoding.") - encoding = tiktoken.get_encoding("cl100k_base") + # Use cached encoding to prevent memory leaks + encoding = _get_cached_tiktoken_encoding(model) tokens_per_message = 3 tokens_per_name = 1 @@ -316,19 +333,14 @@ def num_tokens_from_messages(messages, model="openai"): def truncate_tokens_from_messages(messages: List[Dict[str, Any]], max_tokens: int, keep_both_sides: bool = False, model: str = "gpt-4o"): import_package("tiktoken") - import tiktoken if model.lower() == "qwen": return qwen_tokenizer.truncate(messages, max_tokens, keep_both_sides) elif model.lower() == "openai": return openai_tokenizer.truncate(messages, max_tokens, keep_both_sides) - try: - encoding = tiktoken.encoding_for_model(model) - except KeyError: - logger.warning(f"{model} model not found. Using cl100k_base encoding.") - encoding = tiktoken.get_encoding("cl100k_base") - + # Use cached encoding to prevent memory leaks + encoding = _get_cached_tiktoken_encoding(model) return encoding.truncate(messages, max_tokens, keep_both_sides)