Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Literal, Optional, Union

import pyathena
from pydantic import Field, SecretStr
from pydantic import Field, SecretStr, model_validator
from soda_core.common.aws_credentials import AwsCredentials
from soda_core.common.data_source_connection import DataSourceConnection
from soda_core.common.logging_constants import soda_logger
Expand All @@ -21,10 +21,12 @@
DEFAULT_CATALOG = "awsdatacatalog"


# Currently the only supported authentication method.
class AthenaConnectionProperties(DataSourceConnectionProperties, ABC):
access_key_id: str = Field(..., description="AWS access key ID")
secret_access_key: SecretStr = Field(..., description="AWS secret access key")
# Static AWS credentials are optional. When omitted, the AWS SDK resolves
# credentials from its default provider chain (e.g. IAM Roles for Service
# Accounts / web identity federation, instance profiles, profile_name).
access_key_id: Optional[str] = Field(None, description="AWS access key ID")
secret_access_key: Optional[SecretStr] = Field(None, description="AWS secret access key")
region_name: str = Field(..., description="AWS region name")
staging_dir: str = Field(..., description="S3 staging directory")

Expand All @@ -36,6 +38,19 @@ class AthenaConnectionProperties(DataSourceConnectionProperties, ABC):
session_token: Optional[str] = Field(None, description="AWS session token")
profile_name: Optional[str] = Field(None, description="AWS profile name")

@model_validator(mode="after")
def _validate_credentials(self) -> "AthenaConnectionProperties":
# access_key_id and secret_access_key must be provided together. Providing
# only one is almost always a misconfiguration rather than an intentional
# fall-through to the AWS default credential provider chain.
if bool(self.access_key_id) != bool(self.secret_access_key):
raise ValueError(
"access_key_id and secret_access_key must be provided together. "
"Omit both to use the AWS default credential provider chain "
"(e.g. IAM Roles for Service Accounts / web identity federation)."
)
return self


class AthenaDataSource(DataSourceBase, ABC):
type: Literal["athena"] = Field("athena")
Expand All @@ -58,7 +73,7 @@ def _create_connection(
self.work_group = config.work_group
self.aws_credentials = AwsCredentials(
access_key_id=config.access_key_id,
secret_access_key=config.secret_access_key.get_secret_value(),
secret_access_key=(config.secret_access_key.get_secret_value() if config.secret_access_key else None),
role_arn=config.role_arn,
session_token=config.session_token,
region_name=config.region_name,
Expand Down
35 changes: 33 additions & 2 deletions soda-athena/tests/data_sources/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
from helpers.test_connection import TestConnection
from soda_core.common.data_source_impl import DataSourceImpl

# define environment variables used in test cases
ATHENA_ACCESS_KEY_ID = os.getenv("ATHENA_ACCESS_KEY_ID", None)
Expand Down Expand Up @@ -30,7 +31,9 @@
""",
),
TestConnection(
test_name="missing_credentials",
# access_key_id without secret_access_key is a misconfiguration: the two
# static credentials must be provided together (or both omitted).
test_name="partial_credentials",
connection_yaml_str=f"""
type: athena
name: ATHENA_TEST_DS
Expand All @@ -43,7 +46,7 @@
{f"work_group: {ATHENA_WORKGROUP}" if ATHENA_WORKGROUP else ""}
""",
valid_yaml=False,
expected_yaml_error="Field required [type=missing,",
expected_yaml_error="access_key_id and secret_access_key must be provided together",
),
TestConnection(
test_name="incorrect_secret_access_key",
Expand Down Expand Up @@ -83,3 +86,31 @@
@pytest.mark.parametrize("test_connection", test_connections, ids=[tc.test_name for tc in test_connections])
def test_athena_connections(test_connection: TestConnection, monkeypatch: pytest.MonkeyPatch):
test_connection.test(monkeypatch=monkeypatch)


def test_athena_connection_without_static_credentials_parses():
"""
Omitting access_key_id/secret_access_key must be valid: the AWS SDK then resolves
credentials via its default provider chain (e.g. IAM Roles for Service Accounts /
web identity federation). This is parsing-only and does not open a live connection,
so it runs without AWS access. Regression test for V4 onboarding requiring keys.
"""
from soda_core.common.logs import Logs
from soda_core.common.yaml import DataSourceYamlSource

yaml_str = f"""
type: athena
name: ATHENA_TEST_DS
connection:
staging_dir: {ATHENA_S3_TEST_DIR}
region_name: {ATHENA_REGION_NAME}
catalog: {ATHENA_CATALOG}
"""

logs = Logs()
data_source_impl = DataSourceImpl.from_yaml_source(DataSourceYamlSource.from_str(yaml_str=yaml_str))

assert data_source_impl is not None
assert not logs.has_errors
assert data_source_impl.data_source_model.connection_properties.access_key_id is None
assert data_source_impl.data_source_model.connection_properties.secret_access_key is None
Loading