diff --git a/server/.env.example b/server/.env.example index 114db1239b..29a788681f 100644 --- a/server/.env.example +++ b/server/.env.example @@ -10,6 +10,8 @@ POSTGRES_DB= POSTGRES_USER= POSTGRES_PASSWORD= POSTGRES_COLLECTION_NAME= +# Optional: if your Postgres connection requires sslmode +# POSTGRES_SSLMODE=require ADMIN_API_KEY= JWT_SECRET= diff --git a/server/db.py b/server/db.py index a142374b08..94a6dd0a3f 100644 --- a/server/db.py +++ b/server/db.py @@ -10,7 +10,12 @@ def _build_database_url() -> str: user = os.environ.get("POSTGRES_USER", "postgres") password = os.environ.get("POSTGRES_PASSWORD", "postgres") db = os.environ.get("APP_DB_NAME", "mem0_app") - return f"postgresql+psycopg://{user}:{password}@{host}:{port}/{db}" + sslmode = os.environ.get("POSTGRES_SSLMODE") + + url = f"postgresql+psycopg://{user}:{password}@{host}:{port}/{db}" + if sslmode: + url += f"?sslmode={sslmode}" + return url engine = create_engine(_build_database_url(), pool_pre_ping=True) diff --git a/server/main.py b/server/main.py index 098712bf11..91d1eb19d7 100644 --- a/server/main.py +++ b/server/main.py @@ -109,6 +109,7 @@ def _warn_if_unconfigured() -> None: POSTGRES_USER = os.environ.get("POSTGRES_USER", "postgres") POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD", "postgres") POSTGRES_COLLECTION_NAME = os.environ.get("POSTGRES_COLLECTION_NAME", "memories") +POSTGRES_SSLMODE = os.environ.get("POSTGRES_SSLMODE") OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") HISTORY_DB_PATH = os.environ.get("HISTORY_DB_PATH", "/app/history/history.db") @@ -126,6 +127,7 @@ def _warn_if_unconfigured() -> None: "user": POSTGRES_USER, "password": POSTGRES_PASSWORD, "collection_name": POSTGRES_COLLECTION_NAME, + "sslmode": POSTGRES_SSLMODE, }, }, "llm": { diff --git a/tests/test_server_config.py b/tests/test_server_config.py new file mode 100644 index 0000000000..60b1b6d522 --- /dev/null +++ b/tests/test_server_config.py @@ -0,0 +1,68 @@ +import importlib +import os +from unittest.mock import patch + +import pytest + +pytest.importorskip("fastapi", reason="fastapi not installed") + + +@pytest.fixture(autouse=True) +def mock_db_dependencies(): + """Mock sqlalchemy create_engine globally to avoid loading psycopg dialect.""" + with patch("sqlalchemy.create_engine") as mock_create: + yield mock_create + + +def test_build_database_url_without_sslmode(): + env_mock = { + "POSTGRES_HOST": "localhost", + "POSTGRES_PORT": "5432", + "POSTGRES_USER": "testuser", + "POSTGRES_PASSWORD": "testpassword", + "APP_DB_NAME": "testdb", + } + with patch.dict(os.environ, env_mock, clear=True): + import server.db as server_db + importlib.reload(server_db) + url = server_db._build_database_url() + assert "sslmode" not in url + assert url == "postgresql+psycopg://testuser:testpassword@localhost:5432/testdb" + + +def test_build_database_url_with_sslmode(): + env_mock = { + "POSTGRES_HOST": "localhost", + "POSTGRES_PORT": "5432", + "POSTGRES_USER": "testuser", + "POSTGRES_PASSWORD": "testpassword", + "APP_DB_NAME": "testdb", + "POSTGRES_SSLMODE": "require", + } + with patch.dict(os.environ, env_mock, clear=True): + import server.db as server_db + importlib.reload(server_db) + url = server_db._build_database_url() + assert "?sslmode=require" in url + assert url == "postgresql+psycopg://testuser:testpassword@localhost:5432/testdb?sslmode=require" + + +def test_server_main_config_sslmode(): + env_mock = { + "OPENAI_API_KEY": "fake-key", + "ADMIN_API_KEY": "admin-key", + "POSTGRES_SSLMODE": "prefer", + "AUTH_DISABLED": "true", + } + # Mock Memory.from_config so server/main doesn't try to instantiate a real memory object + with patch("mem0.Memory.from_config"): + with patch.dict(os.environ, env_mock): + import server.main as server_main + importlib.reload(server_main) + + # Verify the global variable is set + assert server_main.POSTGRES_SSLMODE == "prefer" + + # Verify the DEFAULT_CONFIG contains sslmode + pg_config = server_main.DEFAULT_CONFIG["vector_store"]["config"] + assert pg_config["sslmode"] == "prefer"