Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions mem0/configs/embeddings/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from abc import ABC
from typing import Dict, Optional, Union
from typing import Any, Dict, Optional, Union

import httpx

Expand All @@ -25,7 +25,7 @@ def __init__(
model_kwargs: Optional[dict] = None,
huggingface_base_url: Optional[str] = None,
# AzureOpenAI specific
azure_kwargs: Optional[AzureConfig] = {},
azure_kwargs: Optional[Union[AzureConfig, Dict[str, Any]]] = None,
http_client_proxies: Optional[Union[Dict, str]] = None,
# VertexAI specific
vertex_credentials_json: Optional[str] = None,
Expand Down Expand Up @@ -89,7 +89,7 @@ def __init__(
self.model_kwargs = model_kwargs or {}
self.huggingface_base_url = huggingface_base_url
# AzureOpenAI specific
self.azure_kwargs = AzureConfig(**azure_kwargs) or {}
self.azure_kwargs = azure_kwargs if isinstance(azure_kwargs, AzureConfig) else AzureConfig(**(azure_kwargs or {}))

# VertexAI specific
self.vertex_credentials_json = vertex_credentials_json
Expand All @@ -107,4 +107,3 @@ def __init__(
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_region = aws_region or os.environ.get("AWS_REGION") or "us-west-2"

3 changes: 2 additions & 1 deletion mem0/proxy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, mem0_client):
def create(
self,
model: str,
messages: List = [],
messages: Optional[List] = None,
# Mem0 arguments
user_id: Optional[str] = None,
agent_id: Optional[str] = None,
Expand Down Expand Up @@ -100,6 +100,7 @@ def create(
f"Model '{model}' does not support function calling. Please use a model that supports function calling."
)

messages = messages or []
prepared_messages = self._prepare_messages(messages)
if prepared_messages[-1]["role"] == "user":
self._async_add_to_memory(messages, user_id, agent_id, run_id, metadata, filters)
Expand Down
7 changes: 7 additions & 0 deletions tests/embeddings/test_azure_openai_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ def test_embed_text(mock_openai_client):
assert embedding == [0.1, 0.2, 0.3]


def test_base_embedder_config_azure_kwargs_default_is_isolated():
first = BaseEmbedderConfig()
second = BaseEmbedderConfig()

assert first.azure_kwargs is not second.azure_kwargs


@pytest.mark.parametrize(
"default_headers, expected_header",
[(None, None), ({"Test": "test_value"}, "test_value"), ({}, None)],
Expand Down
7 changes: 7 additions & 0 deletions tests/test_proxy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from unittest.mock import Mock, patch

import pytest
Expand Down Expand Up @@ -60,6 +61,12 @@ def test_chat_initialization(mock_memory_client):
assert isinstance(chat.completions, Completions)


def test_completions_create_messages_default_is_not_mutable():
signature = inspect.signature(Completions.create)

assert signature.parameters["messages"].default is None


def test_completions_create(mock_memory_client, mock_litellm):
completions = Completions(mock_memory_client)

Expand Down