diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 60253454e2..70f7e4512d 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -19,6 +19,7 @@ from datetime import datetime import functools import hashlib +import importlib import json import logging import os @@ -26,6 +27,7 @@ import sys import tempfile import textwrap +from typing import Any from typing import Optional import click @@ -35,13 +37,13 @@ from . import cli_create from . import cli_deploy +from . import fast_api from .. import version from ..agents.run_config import StreamingMode from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE from ..features import FeatureName from ..features import override_feature_enabled from .cli import run_cli -from .fast_api import get_fast_api_app from .utils import envs from .utils import evals from .utils import logs @@ -1811,7 +1813,7 @@ async def _lifespan(app: FastAPI): fg="green", ) - app = get_fast_api_app( + app = fast_api.get_fast_api_app( agents_dir=agents_dir, session_service_uri=session_service_uri, artifact_service_uri=artifact_service_uri, @@ -1844,6 +1846,18 @@ async def _lifespan(app: FastAPI): server.run() +def _load_lifespan_handler(lifespan_path: str) -> Any: + """Dynamically import a lifespan handler from a string path.""" + try: + module_name, func_name = lifespan_path.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, func_name) + except Exception as e: + raise click.ClickException( + f"Failed to load lifespan handler '{lifespan_path}': {e}" + ) from e + + @main.command("api_server") @feature_options() # The directory of agents, where each subdirectory is a single agent. @@ -1866,6 +1880,15 @@ async def _lifespan(app: FastAPI): "Automatically create a session if it doesn't exist when calling /run." ), ) +@click.option( + "--lifespan", + type=str, + default=None, + help=( + "Optional. The import path to a lifespan context manager (e.g.," + " 'path.to.module.lifespan_handler')." + ), +) def cli_api_server( agents_dir: str, eval_storage_uri: Optional[str] = None, @@ -1888,6 +1911,7 @@ def cli_api_server( extra_plugins: Optional[list[str]] = None, auto_create_session: bool = False, trigger_sources: Optional[list[str]] = None, + lifespan: str | None = None, ): """Starts a FastAPI server for agents. @@ -1902,8 +1926,13 @@ def cli_api_server( artifact_service_uri = artifact_service_uri or artifact_storage_uri logs.setup_adk_logger(getattr(logging, log_level.upper())) + if agents_dir and agents_dir not in sys.path: + sys.path.insert(0, agents_dir) + + lifespan_handler = _load_lifespan_handler(lifespan) if lifespan else None + config = uvicorn.Config( - get_fast_api_app( + fast_api.get_fast_api_app( agents_dir=agents_dir, session_service_uri=session_service_uri, artifact_service_uri=artifact_service_uri, @@ -1922,6 +1951,7 @@ def cli_api_server( extra_plugins=extra_plugins, auto_create_session=auto_create_session, trigger_sources=trigger_sources, + lifespan=lifespan_handler, ), host=host, port=port, @@ -2343,7 +2373,6 @@ def cli_migrate_session( " It can only be `root_agent` or `app`. (default: `root_agent`)" ), ) - @click.option( "--env_file", type=str, @@ -2419,7 +2448,6 @@ def cli_deploy_agent_engine( adk_app: str, adk_app_object: Optional[str], temp_folder: Optional[str], - env_file: str, requirements_file: str, absolutize_imports: bool, @@ -2460,7 +2488,6 @@ def cli_deploy_agent_engine( description=description, adk_app=adk_app, temp_folder=temp_folder, - env_file=env_file, requirements_file=requirements_file, absolutize_imports=absolutize_imports, diff --git a/tests/unittests/cli/utils/test_cli_tools_click.py b/tests/unittests/cli/utils/test_cli_tools_click.py index 0406442b80..d7af06ffdd 100644 --- a/tests/unittests/cli/utils/test_cli_tools_click.py +++ b/tests/unittests/cli/utils/test_cli_tools_click.py @@ -1138,7 +1138,7 @@ def test_cli_web_invokes_uvicorn( agents_dir = tmp_path / "agents" agents_dir.mkdir() monkeypatch.setattr( - cli_tools_click, "get_fast_api_app", lambda **_k: object() + "google.adk.cli.fast_api.get_fast_api_app", lambda **_k: object() ) runner = CliRunner() result = runner.invoke(cli_tools_click.main, ["web", str(agents_dir)]) @@ -1153,7 +1153,7 @@ def test_cli_api_server_invokes_uvicorn( agents_dir = tmp_path / "agents_api" agents_dir.mkdir() monkeypatch.setattr( - cli_tools_click, "get_fast_api_app", lambda **_k: object() + "google.adk.cli.fast_api.get_fast_api_app", lambda **_k: object() ) runner = CliRunner() result = runner.invoke(cli_tools_click.main, ["api_server", str(agents_dir)]) @@ -1161,6 +1161,38 @@ def test_cli_api_server_invokes_uvicorn( assert _patch_uvicorn.calls, "uvicorn.Server.run must be called" +def test_cli_api_server_passes_lifespan( + tmp_path: Path, _patch_uvicorn: _Recorder, monkeypatch: pytest.MonkeyPatch +) -> None: + """api_server should pass lifespan handler to get_fast_api_app.""" + agents_dir = tmp_path / "agents_api_lifespan" + agents_dir.mkdir() + lifespan_file = agents_dir / "dummy_lifespan.py" + lifespan_file.write_text(""" +from contextlib import asynccontextmanager +@asynccontextmanager +async def dummy_handler(app): + yield +""") + mock_get_app = _Recorder() + monkeypatch.setattr("google.adk.cli.fast_api.get_fast_api_app", mock_get_app) + runner = CliRunner() + result = runner.invoke( + cli_tools_click.main, + [ + "api_server", + str(agents_dir), + "--lifespan", + "dummy_lifespan.dummy_handler", + ], + ) + assert result.exit_code == 0, f"Output: {result.output}" + assert mock_get_app.calls + called_kwargs = mock_get_app.calls[0][1] + assert called_kwargs.get("lifespan") is not None + assert called_kwargs.get("lifespan").__name__ == "dummy_handler" + + def test_cli_web_passes_service_uris( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, _patch_uvicorn: _Recorder ) -> None: @@ -1169,7 +1201,7 @@ def test_cli_web_passes_service_uris( agents_dir.mkdir() mock_get_app = _Recorder() - monkeypatch.setattr(cli_tools_click, "get_fast_api_app", mock_get_app) + monkeypatch.setattr("google.adk.cli.fast_api.get_fast_api_app", mock_get_app) runner = CliRunner() result = runner.invoke( @@ -1204,7 +1236,7 @@ def test_cli_web_warns_and_maps_deprecated_uris( agents_dir.mkdir() mock_get_app = _Recorder() - monkeypatch.setattr(cli_tools_click, "get_fast_api_app", mock_get_app) + monkeypatch.setattr("google.adk.cli.fast_api.get_fast_api_app", mock_get_app) runner = CliRunner() result = runner.invoke(