Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions lib/crewai-tools/src/crewai_tools/security/safe_requests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""SSRF-safe HTTP fetching for crewai-tools.

:func:`validate_url` checks the URL it is handed, but it cannot protect a
fetch on its own: ``requests`` re-resolves DNS at connect time and follows
redirects automatically, so a public-looking host that 302-redirects to an
internal address (or that rebinds DNS between validation and connect) reaches
the internal target without ever being re-checked.

This module closes both gaps at the connection layer:

* :class:`SSRFProtectedAdapter` re-runs :func:`validate_url` for every request
it sends. ``requests.Session.send`` invokes the adapter once per redirect
hop, so each ``Location`` target is validated before it is followed.
* The adapter's connections validate the *actual* peer IP immediately after
the socket connects. The IP that was authorised is therefore the IP the
connection uses, removing the DNS time-of-check/time-of-use gap that
:func:`validate_url`'s own ``getaddrinfo`` call leaves open.

Use :func:`safe_get` (or :func:`create_safe_session`) instead of calling
``requests.get`` directly from any tool that fetches a user- or
LLM-controlled URL.
"""

from __future__ import annotations

from typing import Any

import requests
from requests.adapters import DEFAULT_POOLBLOCK, HTTPAdapter
from urllib3.connection import HTTPConnection, HTTPSConnection
from urllib3.connectionpool import HTTPConnectionPool, HTTPSConnectionPool
from urllib3.poolmanager import PoolManager

from crewai_tools.security.safe_path import (
_is_escape_hatch_enabled,
_is_private_or_reserved,
validate_url,
)


def _assert_safe_peer(sock: Any) -> None:
    """Raise if a connected socket's peer is a private/reserved address.

    Validating the real peer (rather than a separately resolved IP) is what
    defeats DNS rebinding: the address we connected to is the address we check.

    The guard fails closed: when the peer address cannot be read from the
    socket, the connection is rejected instead of being silently allowed,
    because an unverifiable peer gives no assurance the target is public.

    Args:
        sock: A connected socket-like object exposing ``getpeername()``.

    Raises:
        ValueError: If the peer IP is private/reserved, or if the peer
            address cannot be determined.
    """
    # The escape hatch disables the peer check entirely.
    if _is_escape_hatch_enabled():
        return
    try:
        peer = sock.getpeername()
    except OSError as exc:
        # Previously this returned silently, which let an unchecked
        # connection through. Refuse instead and keep the cause attached.
        raise ValueError(
            "Could not determine the peer address of the connection; "
            "refusing to proceed because the target cannot be verified "
            "(SSRF protection)."
        ) from exc
    # The first element of the peer tuple is the host address.
    ip_str = str(peer[0])
    if _is_private_or_reserved(ip_str):
        raise ValueError(
            f"Connection resolved to private/reserved IP {ip_str}. "
            f"Access to internal networks is not allowed (possible SSRF via "
            f"redirect or DNS rebinding)."
        )


class _SafeHTTPConnection(HTTPConnection):
    """Plain-HTTP connection that checks its peer address once connected."""

    def connect(self) -> None:
        """Connect via the base class, then run the peer guard on the socket.

        Raises:
            ValueError: Propagated from ``_assert_safe_peer`` when the peer
                is not allowed.
        """
        super().connect()
        connected_sock = self.sock
        _assert_safe_peer(connected_sock)


class _SafeHTTPSConnection(HTTPSConnection):
    """HTTPS connection that checks its peer address once connected."""

    def connect(self) -> None:
        """Connect via the base class, then run the peer guard on the socket.

        Raises:
            ValueError: Propagated from ``_assert_safe_peer`` when the peer
                is not allowed.
        """
        super().connect()
        connected_sock = self.sock
        _assert_safe_peer(connected_sock)


class _SafeHTTPConnectionPool(HTTPConnectionPool):
    """HTTP connection pool bound to the peer-checking connection class."""

    # Connection class used by this pool for plain-HTTP connections.
    ConnectionCls = _SafeHTTPConnection


class _SafeHTTPSConnectionPool(HTTPSConnectionPool):
    """HTTPS connection pool bound to the peer-checking connection class."""

    # Connection class used by this pool for HTTPS connections.
    ConnectionCls = _SafeHTTPSConnection


# URL scheme -> connection-pool class; installed on each _SafePoolManager.
_SAFE_POOL_CLASSES = dict(
    http=_SafeHTTPConnectionPool,
    https=_SafeHTTPSConnectionPool,
)


class _SafePoolManager(PoolManager):
    """Pool manager that uses the peer-checking pool classes per scheme."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialise like ``PoolManager``, then install the safe pool classes.

        All positional and keyword arguments are forwarded unchanged to the
        base initialiser.
        """
        super().__init__(*args, **kwargs)
        # Replace the scheme -> pool mapping after base initialisation so the
        # assignment is not overwritten by the parent constructor.
        safe_classes = _SAFE_POOL_CLASSES
        self.pool_classes_by_scheme = safe_classes


class SSRFProtectedAdapter(HTTPAdapter):
    """Transport adapter that re-validates every hop and checks the peer IP.

    ``validate_url`` runs on each ``send`` — including every redirect hop
    ``requests`` follows — and the underlying connections reject any socket
    that ends up connected to a private/reserved address.

    NOTE(review): only ``init_poolmanager`` is overridden here; the proxy
    manager path is untouched, so proxied requests presumably do not get the
    connect-time peer check — confirm whether that matters for callers.
    """

    def init_poolmanager(
        self,
        connections: int,
        maxsize: int,
        block: bool = DEFAULT_POOLBLOCK,
        **pool_kwargs: Any,
    ) -> None:
        """Build the adapter's pool manager from the peer-checking classes.

        Args:
            connections: Number of connection pools to cache.
            maxsize: Maximum number of connections kept per pool.
            block: Whether a pool blocks when it has no free connection.
            **pool_kwargs: Extra keyword arguments forwarded to the pool
                manager.
        """
        # Mirror the base HTTPAdapter.init_poolmanager, which records these
        # values so the adapter can rebuild its pool manager when unpickled.
        self._pool_connections = connections
        self._pool_maxsize = maxsize
        self._pool_block = block

        self.poolmanager = _SafePoolManager(
            num_pools=connections,
            maxsize=maxsize,
            block=block,
            **pool_kwargs,
        )

    def send(self, request: Any, *args: Any, **kwargs: Any) -> Any:
        """Validate ``request.url``, then delegate to ``HTTPAdapter.send``.

        Args:
            request: The prepared request about to be sent.
            *args: Forwarded unchanged to the base ``send``.
            **kwargs: Forwarded unchanged to the base ``send``.

        Returns:
            Whatever the base ``send`` returns.

        Raises:
            ValueError: Propagated from ``validate_url`` when the URL is
                rejected; the base ``send`` is then never reached.
        """
        # Re-validate the target of every request the session sends. Because
        # Session.send calls this once per redirect hop, each Location is
        # checked before it is followed.
        validate_url(request.url)
        return super().send(request, *args, **kwargs)


def create_safe_session() -> requests.Session:
    """Build a ``requests.Session`` hardened against SSRF.

    A single :class:`SSRFProtectedAdapter` is mounted for both ``http://``
    and ``https://``, so every request (and redirect hop) is validated and
    every connection has its peer address checked.

    Returns:
        The configured session.
    """
    protected_adapter = SSRFProtectedAdapter()
    safe_session = requests.Session()
    # One shared adapter instance serves both schemes.
    for prefix in ("http://", "https://"):
        safe_session.mount(prefix, protected_adapter)
    return safe_session


def safe_get(url: str, **kwargs: Any) -> requests.Response:
    """Perform an SSRF-safe ``GET``.

    Drop-in replacement for ``requests.get`` for tools that fetch a
    user- or LLM-controlled URL. Validates the initial URL and every redirect
    hop, and rejects connections that land on private/reserved addresses.

    Args:
        url: The URL to fetch.
        **kwargs: Forwarded to ``Session.get`` (``headers``, ``cookies``,
            ``timeout``, ...).

    Returns:
        The ``requests.Response``.

    Raises:
        ValueError: If the URL, a redirect target, or the connected peer is
            not allowed.
    """
    # Fetch the URL that validate_url returns rather than the raw argument.
    # The tool call sites did this before moving to safe_get, so keeping it
    # preserves whatever validate_url hands back as the URL to request.
    checked_url = validate_url(url)
    with create_safe_session() as session:
        return session.get(checked_url, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

from crewai.tools import BaseTool
from pydantic import BaseModel, Field
import requests

from crewai_tools.security.safe_path import validate_url
from crewai_tools.security.safe_requests import safe_get


try:
Expand Down Expand Up @@ -83,8 +82,7 @@ def _run(
if website_url is None or css_element is None:
raise ValueError("Both website_url and css_element must be provided.")

website_url = validate_url(website_url)
page = requests.get(
page = safe_get(
website_url,
headers=self.headers,
cookies=self.cookies if self.cookies else {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
from typing import Any

from pydantic import Field
import requests

from crewai_tools.security.safe_path import validate_url
from crewai_tools.security.safe_requests import safe_get


try:
Expand Down Expand Up @@ -75,8 +74,7 @@ def _run(
if website_url is None:
raise ValueError("Website URL must be provided.")

website_url = validate_url(website_url)
page = requests.get(
page = safe_get(
website_url,
timeout=15,
headers=self.headers,
Expand Down
124 changes: 124 additions & 0 deletions lib/crewai-tools/tests/utilities/test_safe_requests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""Tests for SSRF-safe HTTP fetching (redirect + DNS-rebinding protection)."""

from __future__ import annotations

import http.server
import socketserver
import threading

import pytest
import requests

from crewai_tools.security import safe_requests
from crewai_tools.security.safe_requests import (
SSRFProtectedAdapter,
create_safe_session,
safe_get,
)


# Response body written by _InternalHandler.do_GET.
INTERNAL_BODY = b"INTERNAL-ONLY-SECRET"


class _InternalHandler(http.server.BaseHTTPRequestHandler):
    """Request handler that answers every GET with ``INTERNAL_BODY``."""

    def do_GET(self):
        """Reply 200 with a ``text/plain`` body of ``INTERNAL_BODY``."""
        status_code = 200
        self.send_response(status_code)
        self.send_header("Content-Type", "text/plain")
        self.end_headers()
        payload = INTERNAL_BODY
        self.wfile.write(payload)

    def log_message(self, *args):
        """Discard log output so test runs stay quiet."""
        return None


def _serve(handler):
    """Start a localhost server on an ephemeral port; return (server, port).

    The server runs ``serve_forever`` on a daemon thread.
    """
    # Port 0 lets the OS choose a free port.
    bind_address = ("127.0.0.1", 0)
    tcp_server = socketserver.TCPServer(bind_address, handler)
    chosen_port = tcp_server.server_address[1]
    worker = threading.Thread(target=tcp_server.serve_forever, daemon=True)
    worker.start()
    return tcp_server, chosen_port


class TestRedirectRevalidation:
    """Layer 1: validate_url runs on every send, including each redirect hop.

    ``requests.Session.send`` calls ``adapter.send`` once per redirect hop, so
    re-validating in ``send`` is what blocks a 302 to an internal target.
    """

    def test_adapter_revalidates_before_any_network_call(self, monkeypatch):
        """``send`` must call validate_url and stop when it rejects the URL."""
        seen_urls: list[str] = []

        def recording_validator(url: str) -> str:
            # Record every URL, and reject only the internal one.
            seen_urls.append(url)
            if "internal.target" not in url:
                return url
            raise ValueError("URL resolves to private/reserved IP")

        monkeypatch.setattr(safe_requests, "validate_url", recording_validator)

        # Internal redirect target: send() must reject it before ever calling
        # the real transport (super().send is never reached).
        prepared = requests.Request("GET", "http://internal.target/").prepare()
        protected_adapter = SSRFProtectedAdapter()
        with pytest.raises(ValueError, match="private/reserved"):
            protected_adapter.send(prepared)
        assert seen_urls == ["http://internal.target/"]

    def test_session_mounts_protected_adapter(self):
        """Both schemes must resolve to the protected adapter."""
        session = create_safe_session()
        for probe_url in ("http://x", "https://x"):
            mounted = session.get_adapter(probe_url)
            assert isinstance(mounted, SSRFProtectedAdapter)


class _FakeSock:
    """Minimal socket stand-in that reports a fixed peer address."""

    def __init__(self, peer):
        """Remember ``peer`` so ``getpeername`` can return it."""
        self._peername = peer

    def getpeername(self):
        """Return the peer value given at construction."""
        return self._peername


class TestConnectionPeerGuard:
    """Layer 2: the connection rejects an internal peer IP at connect time.

    This is what closes the validate-then-connect DNS-rebinding gap — the IP
    the socket actually connected to is the IP that gets checked, so a host
    that resolved public at validation time but connects internal is blocked.
    """

    def test_safe_get_blocks_direct_internal(self):
        """A loopback URL is refused by ``safe_get``."""
        # No network: validate_url rejects 127.0.0.1 at the URL layer first.
        loopback_url = "http://127.0.0.1:9/"
        with pytest.raises(ValueError, match="private/reserved"):
            safe_get(loopback_url, timeout=10)

    def test_assert_safe_peer_blocks_private(self):
        """A loopback peer is rejected."""
        loopback_sock = _FakeSock(("127.0.0.1", 80))
        with pytest.raises(ValueError, match="private/reserved"):
            safe_requests._assert_safe_peer(loopback_sock)

    def test_assert_safe_peer_blocks_metadata(self):
        """A 169.254.169.254 peer is rejected."""
        metadata_sock = _FakeSock(("169.254.169.254", 80))
        with pytest.raises(ValueError, match="private/reserved"):
            safe_requests._assert_safe_peer(metadata_sock)

    def test_assert_safe_peer_allows_public(self):
        """A public peer passes the guard."""
        # A public IP must not raise.
        public_sock = _FakeSock(("93.184.216.34", 80))
        safe_requests._assert_safe_peer(public_sock)

    def test_assert_safe_peer_respects_escape_hatch(self, monkeypatch):
        """With the escape-hatch env var set, a private peer is tolerated."""
        monkeypatch.setenv("CREWAI_TOOLS_ALLOW_UNSAFE_PATHS", "true")
        # No raise even for a private peer when the escape hatch is on.
        private_sock = _FakeSock(("127.0.0.1", 80))
        safe_requests._assert_safe_peer(private_sock)

    def test_connection_validates_peer_after_connect(self, monkeypatch):
        """_SafeHTTPConnection.connect runs the peer guard after connecting."""

        def rebound_connect(self):
            # Simulate a rebind: we connected to an internal address.
            self.sock = _FakeSock(("127.0.0.1", 80))

        monkeypatch.setattr(
            safe_requests.HTTPConnection, "connect", rebound_connect
        )
        connection = safe_requests._SafeHTTPConnection("example.com")
        with pytest.raises(ValueError, match="private/reserved"):
            connection.connect()
Loading