Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -1818,6 +1818,7 @@
"embedding_api_base": "",
"embedding_model": "",
"embedding_dimensions": 1024,
"embedding_dimensions_as_request_param": False,
"timeout": 20,
"proxy": "",
},
Expand Down Expand Up @@ -2190,6 +2191,12 @@
"hint": "嵌入向量的维度。根据模型不同,可能需要调整,请参考具体模型的文档。此配置项请务必填写正确,否则将导致向量数据库无法正常工作。",
"_special": "get_embedding_dim",
},
"embedding_dimensions_as_request_param": {
"description": "将嵌入维度作为请求参数发送",
"type": "bool",
"hint": "开启后会把嵌入维度作为 dimensions 参数发送给 OpenAI-compatible Embedding API。仅在模型明确支持该参数时开启。",
"invisible": True,
},
"embedding_model": {
"description": "嵌入模型",
"type": "string",
Expand Down
26 changes: 17 additions & 9 deletions astrbot/core/provider/sources/openai_embedding_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,23 @@ async def get_embeddings(self, text: list[str]) -> list[list[float]]:

def _embedding_kwargs(self) -> dict:
    """Build the optional keyword arguments for an embeddings request.

    The ``dimensions`` request parameter is included only when all of the
    following hold for ``self.provider_config``:

    * ``embedding_dimensions_as_request_param`` is truthy (explicit opt-in),
    * ``embedding_dimensions`` is present and converts to ``int``,
    * the converted value is strictly positive.

    Returns:
        ``{"dimensions": <int>}`` when the conditions above are met,
        otherwise an empty dict so that no extra parameter is sent.
    """
    config = self.provider_config

    # Opt-in gate: without the flag the configured dimension is never
    # forwarded as a request parameter.
    if not config.get("embedding_dimensions_as_request_param", False):
        return {}

    if "embedding_dimensions" not in config:
        return {}

    try:
        dimensions = int(config["embedding_dimensions"])
    except (ValueError, TypeError):
        logger.warning(
            f"embedding_dimensions in embedding configs is not a valid integer: '{config['embedding_dimensions']}', ignored."
        )
        return {}

    # Non-positive values are dropped instead of being sent to the API.
    if dimensions <= 0:
        return {}
    return {"dimensions": dimensions}

def get_dim(self) -> int:
"""获取向量的维度"""
Expand Down
2 changes: 2 additions & 0 deletions dashboard/src/components/shared/AstrBotConfig.vue
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
:plugin-name="pluginName"
:plugin-i18n="pluginI18n"
:config-key="getItemPath(key)"
:config-root="iterable"
:loading="loadingEmbeddingDim"
:show-fullscreen-btn="!!metadata[metadataKey].items[key]?.editor_mode"
@get-embedding-dim="getEmbeddingDimensions(iterable)"
Expand Down Expand Up @@ -322,6 +323,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
:plugin-name="pluginName"
:plugin-i18n="pluginI18n"
:config-key="getItemPath(metadataKey)"
:config-root="iterable"
/>
</v-col>
</v-row>
Expand Down
32 changes: 32 additions & 0 deletions dashboard/src/components/shared/ConfigItemRenderer.vue
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,25 @@
>
{{ t('core.common.autoDetect') }}
</v-btn>
<v-tooltip
v-if="hasEmbeddingDimensionsRequestParam"
:text="t('core.common.embeddingDimensionsRequestParamHint')"
location="top"
>
<template #activator="{ props: tooltipProps }">
<v-switch
v-bind="tooltipProps"
:model-value="Boolean(configRoot.embedding_dimensions_as_request_param)"
@update:model-value="setEmbeddingDimensionsRequestParam"
:label="t('core.common.embeddingDimensionsRequestParam')"
color="primary"
inset
density="compact"
hide-details
class="dimensions-param-switch ml-1"
></v-switch>
</template>
</v-tooltip>
</div>
</template>

Expand Down Expand Up @@ -307,11 +326,24 @@ const listSelectItems = computed(() =>
: []
)

// Truthy only when a config root was passed in and it owns the
// `embedding_dimensions_as_request_param` key; yields the falsy
// `configRoot` value itself when no root is available.
const hasEmbeddingDimensionsRequestParam = computed(() => {
  const root = props.configRoot
  if (!root) {
    return root
  }
  return Object.prototype.hasOwnProperty.call(
    root,
    'embedding_dimensions_as_request_param'
  )
})

// Parse `val` as a floating-point number, falling back to 0 when the
// result is NaN.
function toNumber(val) {
  const parsed = Number.parseFloat(val)
  if (Number.isNaN(parsed)) {
    return 0
  }
  return parsed
}

// Write the switch state back onto the config root as a strict boolean;
// does nothing when no config root was provided.
function setEmbeddingDimensionsRequestParam(val) {
  const root = props.configRoot
  if (root) {
    root.embedding_dimensions_as_request_param = !!val
  }
}

function getLabel(itemMeta, index, option) {
const labels = getTranslatedLabels(itemMeta)
return labels ? labels[index] : option
Expand Down
2 changes: 2 additions & 0 deletions dashboard/src/i18n/locales/en-US/core/common.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
"no": "No",
"imagePreview": "Image Preview",
"autoDetect": "Auto Detect",
"embeddingDimensionsRequestParam": "Request Param",
"embeddingDimensionsRequestParamHint": "When enabled, the embedding dimension is sent as the API dimensions parameter. Disabled by default; enable only for models that explicitly support it.",
"dialog": {
"confirmTitle": "Confirm Action",
"confirmMessage": "Are you sure you want to perform this action?",
Expand Down
4 changes: 3 additions & 1 deletion dashboard/src/i18n/locales/ru-RU/core/common.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
"no": "Нет",
"imagePreview": "Предпросмотр изображения",
"autoDetect": "Автоопределение",
"embeddingDimensionsRequestParam": "Параметр API",
"embeddingDimensionsRequestParamHint": "Если включено, размерность эмбеддинга отправляется как параметр dimensions. По умолчанию выключено; включайте только для моделей, которые явно поддерживают этот параметр.",
"dialog": {
"confirmTitle": "Подтверждение",
"confirmMessage": "Вы уверены, что хотите выполнить это действие?",
Expand Down Expand Up @@ -130,4 +132,4 @@
"subtitle": "Файл FIRST_NOTICE.md не найден или пуст."
}
}
}
}
2 changes: 2 additions & 0 deletions dashboard/src/i18n/locales/zh-CN/core/common.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
"no": "否",
"imagePreview": "图片预览",
"autoDetect": "自动检测",
"embeddingDimensionsRequestParam": "请求参数",
"embeddingDimensionsRequestParamHint": "开启后会把嵌入维度作为 dimensions 参数发送给 API。默认关闭;仅在模型明确支持时开启。",
"dialog": {
"confirmTitle": "确认操作",
"confirmMessage": "你确定要执行此操作吗?",
Expand Down
104 changes: 104 additions & 0 deletions tests/test_openai_embedding_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from types import SimpleNamespace

import pytest

from astrbot.core.provider.sources.openai_embedding_source import (
OpenAIEmbeddingProvider,
)


class FakeEmbeddingsClient:
    """In-memory stand-in for an embeddings client used by the tests.

    Every call to :meth:`create` is recorded in ``calls`` so tests can assert
    on the exact keyword arguments that would have been sent.
    """

    def __init__(self) -> None:
        # Keyword arguments of each ``create`` call, in call order.
        self.calls: list[dict] = []

    async def create(self, **kwargs) -> SimpleNamespace:
        """Record ``kwargs`` and return a canned embeddings response.

        A list ``input`` yields one ``[index, 0.0, 1.0]`` vector per item;
        any other ``input`` yields the single vector ``[1.0, 2.0, 3.0]``.
        The returned object exposes the vectors as ``.data[i].embedding``.
        """
        self.calls.append(kwargs)
        input_value = kwargs["input"]
        if isinstance(input_value, list):
            data = [
                SimpleNamespace(embedding=[float(index), 0.0, 1.0])
                for index, _ in enumerate(input_value)
            ]
        else:
            data = [SimpleNamespace(embedding=[1.0, 2.0, 3.0])]
        return SimpleNamespace(data=data)


@pytest.mark.asyncio
async def test_openai_embedding_does_not_send_dimensions_by_default():
    """Without the opt-in flag, requests carry only ``input`` and ``model``."""
    recorder = FakeEmbeddingsClient()
    provider_config = {
        "id": "openai-compatible-embedding",
        "embedding_api_key": "test-key",
        "embedding_api_base": "https://example.com/v1",
        "embedding_model": "BAAI/bge-m3",
        "embedding_dimensions": 1024,
    }
    provider = OpenAIEmbeddingProvider(provider_config, {})
    # Swap the real client for the recording fake.
    provider.client = SimpleNamespace(embeddings=recorder)

    single = await provider.get_embedding("hello")
    batch = await provider.get_embeddings(["hello", "world"])

    assert single == [1.0, 2.0, 3.0]
    assert batch == [[0.0, 0.0, 1.0], [1.0, 0.0, 1.0]]
    assert provider.get_dim() == 1024

    expected_calls = [
        {"input": "hello", "model": "BAAI/bge-m3"},
        {"input": ["hello", "world"], "model": "BAAI/bge-m3"},
    ]
    assert recorder.calls == expected_calls
Comment on lines +27 to +51
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Let's update the test to cover both cases:

  1. Non-OpenAI models or default dimensions (where dimensions should NOT be sent).
  2. OpenAI text-embedding-3 models with shortened dimensions (where dimensions SHOULD be sent).
@pytest.mark.asyncio
async def test_openai_embedding_dimensions_handling():
    """Cover both sides of the ``embedding_dimensions_as_request_param`` gate."""
    # Case 1: opt-in flag absent -> the dimensions parameter must NOT be sent,
    # even though embedding_dimensions is configured.
    provider_without_flag = OpenAIEmbeddingProvider(
        {
            "id": "openai-compatible-embedding",
            "embedding_api_key": "test-key",
            "embedding_api_base": "https://example.com/v1",
            "embedding_model": "BAAI/bge-m3",
            "embedding_dimensions": 1024,
        },
        {},
    )
    fake_embeddings_1 = FakeEmbeddingsClient()
    provider_without_flag.client = SimpleNamespace(embeddings=fake_embeddings_1)

    embedding = await provider_without_flag.get_embedding("hello")
    embeddings = await provider_without_flag.get_embeddings(["hello", "world"])

    assert embedding == [1.0, 2.0, 3.0]
    assert embeddings == [[0.0, 0.0, 1.0], [1.0, 0.0, 1.0]]
    assert provider_without_flag.get_dim() == 1024
    assert fake_embeddings_1.calls == [
        {"input": "hello", "model": "BAAI/bge-m3"},
        {"input": ["hello", "world"], "model": "BAAI/bge-m3"},
    ]

    # Case 2: opt-in flag enabled with a shortened dimension -> the dimensions
    # parameter SHOULD be sent. The flag is required because the provider only
    # forwards dimensions on explicit opt-in, regardless of the model name.
    provider_with_flag = OpenAIEmbeddingProvider(
        {
            "id": "openai-embedding",
            "embedding_api_key": "test-key",
            "embedding_api_base": "https://api.openai.com/v1",
            "embedding_model": "text-embedding-3-small",
            "embedding_dimensions": 512,
            "embedding_dimensions_as_request_param": True,
        },
        {},
    )
    fake_embeddings_2 = FakeEmbeddingsClient()
    provider_with_flag.client = SimpleNamespace(embeddings=fake_embeddings_2)

    await provider_with_flag.get_embedding("hello")
    assert fake_embeddings_2.calls == [
        {"input": "hello", "model": "text-embedding-3-small", "dimensions": 512},
    ]



@pytest.mark.asyncio
async def test_openai_embedding_sends_dimensions_when_explicitly_enabled():
    """With the opt-in flag set, the configured size is sent as ``dimensions``."""
    recorder = FakeEmbeddingsClient()
    provider_config = {
        "id": "openai-compatible-embedding",
        "embedding_api_key": "test-key",
        "embedding_api_base": "https://api.openai.com/v1",
        "embedding_model": "text-embedding-3-small",
        "embedding_dimensions": 512,
        "embedding_dimensions_as_request_param": True,
    }
    provider = OpenAIEmbeddingProvider(provider_config, {})
    # Swap the real client for the recording fake.
    provider.client = SimpleNamespace(embeddings=recorder)

    single = await provider.get_embedding("hello")

    assert single == [1.0, 2.0, 3.0]
    assert provider.get_dim() == 512

    expected_call = {
        "input": "hello",
        "model": "text-embedding-3-small",
        "dimensions": 512,
    }
    assert recorder.calls == [expected_call]


@pytest.mark.asyncio
async def test_openai_embedding_omits_dimensions_when_dimension_not_configured():
    """The opt-in flag alone sends nothing when no dimension is configured."""
    recorder = FakeEmbeddingsClient()
    provider_config = {
        "id": "openai-compatible-embedding",
        "embedding_api_key": "test-key",
        "embedding_api_base": "https://example.com/v1",
        "embedding_model": "BAAI/bge-m3",
        "embedding_dimensions_as_request_param": True,
    }
    provider = OpenAIEmbeddingProvider(provider_config, {})
    # Swap the real client for the recording fake.
    provider.client = SimpleNamespace(embeddings=recorder)

    batch = await provider.get_embeddings(["hello", "world"])

    assert batch == [[0.0, 0.0, 1.0], [1.0, 0.0, 1.0]]
    assert provider.get_dim() == 0

    expected_call = {"input": ["hello", "world"], "model": "BAAI/bge-m3"}
    assert recorder.calls == [expected_call]
Loading