From 65b2a4ed4816998ef328482652f212faa0a78f24 Mon Sep 17 00:00:00 2001 From: Rat0323 Date: Wed, 3 Jun 2026 12:42:37 +0800 Subject: [PATCH 1/3] fix(openai-embedding): support both .cn and .com domains case-insensitively for SiliconFlow embedding models --- astrbot/core/provider/sources/openai_embedding_source.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py index 2be0165bb3..4006da3b82 100644 --- a/astrbot/core/provider/sources/openai_embedding_source.py +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -79,7 +79,7 @@ def _embedding_kwargs(self) -> dict: if ( provider_api_base # Hard-code SiliconFlow API Base Prefix and Model Name, as it's just a temporary workaround. - and provider_api_base.strip().startswith("https://api.siliconflow.cn") + and "siliconflow" in provider_api_base.lower() and not self.model.lower().startswith("qwen") ): # For SiliconFlow and Non-Qwen models, dimensions parameter is not supported. so remove it. From 668c868a6d180fd544ffed37fd3a71931cc74c64 Mon Sep 17 00:00:00 2001 From: Rat0323 Date: Wed, 3 Jun 2026 13:13:36 +0800 Subject: [PATCH 2/3] refactor(openai-embedding): use urlparse to match SiliconFlow official domains strictly and preserve strip() --- .../sources/openai_embedding_source.py | 36 +++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py index 4006da3b82..ddd94ac356 100644 --- a/astrbot/core/provider/sources/openai_embedding_source.py +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -1,3 +1,5 @@ +from urllib.parse import urlparse + import httpx from openai import AsyncOpenAI @@ -74,21 +76,27 @@ def _embedding_kwargs(self) -> dict: ) # Fix: SiliconFlow provider does not support dimensions parameter, except for Qwen models. - provider_api_base = self.provider_config.get("embedding_api_base") + provider_api_base = self.provider_config.get("embedding_api_base", "") provider_id = self.provider_config.get("id", "unknown_id") - if ( - provider_api_base - # Hard-code SiliconFlow API Base Prefix and Model Name, as it's just a temporary workaround. - and "siliconflow" in provider_api_base.lower() - and not self.model.lower().startswith("qwen") - ): - # For SiliconFlow and Non-Qwen models, dimensions parameter is not supported. so remove it. - removed_dimensions = kwargs.pop("dimensions", None) - if removed_dimensions is not None: - # Log a warning message if dimensions parameter is removed. - logger.warning( - f"dimensions not supported for model '{self.model}' of provider '{provider_id}' as SiliconFlow does not support this parameter for non-Qwen models: '{removed_dimensions}'." - ) + if provider_api_base: + api_base = provider_api_base.strip().lower() + # 兼容不带 http:// 或 https:// 头的 api_base,确保 urlparse 能正常解析 hostname + if not api_base.startswith(("http://", "https://")): + api_base = "https://" + api_base + hostname = urlparse(api_base).hostname or "" + + if hostname in { + "api.siliconflow.cn", + "api.siliconflow.com", + } and not self.model.lower().startswith("qwen"): + # For SiliconFlow and Non-Qwen models, dimensions parameter is not supported, so remove it. + removed_dimensions = kwargs.pop("dimensions", None) + if removed_dimensions is not None: + # Log a warning message if dimensions parameter is removed. + logger.warning( + f"dimensions not supported for model '{self.model}' of provider '{provider_id}' " + f"as SiliconFlow does not support this parameter for non-Qwen models: '{removed_dimensions}'." + ) return kwargs def get_dim(self) -> int: From 67b1e093ac93edadf15d63cf508067a6816a2c28 Mon Sep 17 00:00:00 2001 From: Rat0323 Date: Wed, 3 Jun 2026 13:20:58 +0800 Subject: [PATCH 3/3] refactor(openai-embedding): apply code quality improvements, exception handling and empty validation from 1.md --- .../sources/openai_embedding_source.py | 73 +++++++++++++++---- 1 file changed, 57 insertions(+), 16 deletions(-) diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py index ddd94ac356..db0dfefe99 100644 --- a/astrbot/core/provider/sources/openai_embedding_source.py +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -1,6 +1,7 @@ from urllib.parse import urlparse import httpx +import openai from openai import AsyncOpenAI from astrbot import logger @@ -20,12 +21,29 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config, provider_settings) self.provider_config = provider_config self.provider_settings = provider_settings - proxy = provider_config.get("proxy", "") + provider_id = provider_config.get("id", "unknown_id") + + # 1. 强制校验 API Key (Fail-Fast) + api_key: str = provider_config.get("embedding_api_key", "") + if not api_key: + raise ValueError( + f"OpenAI embedding provider [{provider_id}] 配置错误: 缺少必需成 'embedding_api_key'" + ) + + # 2. 安全获取并转换 timeout 避免空字符串导致 int() 崩溃 + raw_timeout = provider_config.get("timeout", 20) + try: + timeout_val = int(raw_timeout) if raw_timeout else 20 + except (ValueError, TypeError): + timeout_val = 20 + + proxy = provider_config.get("proxy", "") http_client = None if proxy: logger.info(f"[OpenAI Embedding] {provider_id} Using proxy: {proxy}") http_client = httpx.AsyncClient(proxy=proxy) + api_base = ( provider_config.get("embedding_api_base", "https://api.openai.com/v1") .strip() @@ -35,34 +53,56 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: if api_base and not api_base.endswith("/v1") and not api_base.endswith("/v4"): # /v4 see #5699 api_base = api_base + "/v1" + logger.info(f"[OpenAI Embedding] {provider_id} Using API Base: {api_base}") + self.client = AsyncOpenAI( - api_key=provider_config.get("embedding_api_key"), + api_key=api_key, base_url=api_base, - timeout=int(provider_config.get("timeout", 20)), + timeout=timeout_val, http_client=http_client, ) self.model = provider_config.get("embedding_model", "text-embedding-3-small") async def get_embedding(self, text: str) -> list[float]: """获取文本的嵌入""" + # 3. 拦截空文本防 400 报错 + if not text or not text.strip(): + raise ValueError("输入文本不能为空") + kwargs = self._embedding_kwargs() - embedding = await self.client.embeddings.create( - input=text, - model=self.model, - **kwargs, - ) - return embedding.data[0].embedding + + try: + embedding = await self.client.embeddings.create( + input=text, + model=self.model, + **kwargs, + ) + return embedding.data[0].embedding + except openai.OpenAIError as e: + # 4. 包装规范异常,使用 from e 保留原始调用栈 + raise Exception(f"OpenAI Embedding API 请求失败: {e}") from e async def get_embeddings(self, text: list[str]) -> list[list[float]]: """批量获取文本的嵌入""" + # 5. 拦截空列表和内部脏数据,与 get_embedding 逻辑保持一致 + if not text: + raise ValueError("批量输入列表不能为空") + + if any(not s or not s.strip() for s in text): + raise ValueError("批量输入文本列表中不能包含空文本") + kwargs = self._embedding_kwargs() - embeddings = await self.client.embeddings.create( - input=text, - model=self.model, - **kwargs, - ) - return [item.embedding for item in embeddings.data] + + try: + embeddings = await self.client.embeddings.create( + input=text, + model=self.model, + **kwargs, + ) + return [item.embedding for item in embeddings.data] + except openai.OpenAIError as e: + raise Exception(f"OpenAI Embedding API 批量请求失败: {e}") from e def _embedding_kwargs(self) -> dict: """构建嵌入请求的可选参数""" @@ -111,5 +151,6 @@ def get_dim(self) -> int: return 0 async def terminate(self): - if self.client: + """释放资源""" + if getattr(self, "client", None): await self.client.close()