Skip to content
25 changes: 0 additions & 25 deletions src/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import logging
from functools import wraps

from airflow.utils import db
from airflow.utils.state import State

SECTION_NAME = "ergo"
Expand All @@ -21,26 +19,3 @@ def task_state(code):
return State.QUEUED
else:
return State.FAILED


def ergo_initdb(func):
    """Wrap ``func`` so that Ergo's ``initdb`` is run whenever it is called.

    The returned wrapper calls ``func`` with the given arguments, logs (and
    otherwise ignores) any exception it raises, and then calls ``initdb``.
    A ``_wrappers`` list on the wrapper records that the Ergo wrapper has been
    applied, so that wrapping an already-wrapped callable returns it unchanged.

    :param func: the callable to wrap (applied to ``db.upgradedb`` in this module).
    :return: ``func`` itself if it already carries the Ergo marker, otherwise
        the new wrapper. The wrapper always returns ``None``.
    """
    # Imported lazily, at decoration time, rather than at module import.
    from ergo.migrations.utils import initdb

    prev_wrappers = getattr(func, '_wrappers', [])
    if SECTION_NAME in prev_wrappers:
        # Already wrapped by Ergo: wrapping again would run initdb twice.
        return func

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception as e:
            # Deliberate best-effort: a failure in the wrapped call is logged
            # but must not stop initdb from running.
            logger.warning('Ignoring error', exc_info=e)
        # NOTE(review): placed after the try/except so it runs on every call;
        # confirm against the original indentation.
        initdb()

    # Store the marker as ONE element. list(SECTION_NAME) would split the
    # string into single characters, so the membership guard above could
    # never match and the callable would be re-wrapped on every application.
    wrapper._wrappers = [*prev_wrappers, SECTION_NAME]

    return wrapper


db.upgradedb = ergo_initdb(db.upgradedb)
2 changes: 1 addition & 1 deletion src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from airflow.configuration import conf

from ergo import SECTION_NAME
SECTION_NAME = "ergo"


class Config(object):
Expand Down
7 changes: 3 additions & 4 deletions src/dags/aries_task_queuer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from datetime import timedelta
from datetime import datetime, timedelta

from airflow import DAG
from airflow.utils import timezone
from airflow.utils.dates import days_ago

from ergo.config import Config
from ergo.operators.sqs.sqs_task_pusher import SqsTaskPusherOperator
Expand All @@ -17,7 +16,7 @@
'depends_on_past': False,
'retries': 2,
'retry_delay': timedelta(minutes=1),
'start_date': days_ago(1),
'start_date': datetime(2024, 1, 1),

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🏢 Use timezone-aware start_date

A naive datetime can trigger scheduling warnings and run offsets; use Airflow’s timezone-aware datetime helper.

Suggested change
'start_date': datetime(2024, 1, 1),
'start_date': timezone.datetime(2024, 1, 1),

'priority_weight': 900,
}

Expand All @@ -30,7 +29,7 @@
'aries_ergo_task_queuer',
default_args=default_args,
is_paused_upon_creation=False,
schedule_interval=timedelta(seconds=10),
schedule=timedelta(seconds=10),
catchup=False,
max_active_runs=max_concurrent_runs,
dagrun_timeout=timedelta(minutes=5)
Expand Down
7 changes: 3 additions & 4 deletions src/dags/calipso_task_queuer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from datetime import timedelta
from datetime import datetime, timedelta

from airflow import DAG
from airflow.utils import timezone
from airflow.utils.dates import days_ago

from ergo.config import Config
from ergo.operators.sqs.sqs_task_pusher import SqsTaskPusherOperator
Expand All @@ -17,7 +16,7 @@
'depends_on_past': False,
'retries': 2,
'retry_delay': timedelta(minutes=1),
'start_date': days_ago(1),
'start_date': datetime(2024, 1, 1),

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🏢 Use timezone-aware start_date

A naive datetime can trigger scheduling warnings and run offsets; use Airflow’s timezone-aware datetime helper.

Suggested change
'start_date': datetime(2024, 1, 1),
'start_date': timezone.datetime(2024, 1, 1),

'priority_weight': 900,
}

Expand All @@ -30,7 +29,7 @@
'calipso_ergo_task_queuer',
default_args=default_args,
is_paused_upon_creation=False,
schedule_interval=timedelta(seconds=10),
schedule=timedelta(seconds=10),
catchup=False,
max_active_runs=max_concurrent_runs,
dagrun_timeout=timedelta(minutes=5)
Expand Down
11 changes: 5 additions & 6 deletions src/dags/dag_job_collector.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from datetime import timedelta
from datetime import datetime, timedelta

from airflow import DAG
from airflow.contrib.sensors.aws_sqs_sensor import SQSSensor
from airflow.providers.amazon.aws.sensors.sqs import SqsSensor
from airflow.utils import timezone
from airflow.utils.dates import days_ago

from ergo.config import Config
from ergo.operators.sqs.result_from_messages import \
Expand All @@ -16,7 +15,7 @@
'depends_on_past': False,
'retries': 10,
'retry_delay': timedelta(seconds=30),

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🏢 Use timezone-aware start_date

A naive datetime can trigger scheduling warnings and run offsets; use Airflow’s timezone-aware datetime helper.

Suggested change
'start_date': datetime(2024, 1, 1),
'start_date': timezone.datetime(2024, 1, 1),

'start_date': days_ago(1),
'start_date': datetime(2024, 1, 1),
'priority_weight': 900,
}

Expand All @@ -27,12 +26,12 @@
'ergo_job_collector',
default_args=default_args,
is_paused_upon_creation=False,
schedule_interval=timedelta(seconds=10),
schedule=timedelta(seconds=10),
catchup=False,
dagrun_timeout=timedelta(minutes=15),
max_active_runs=Config.max_runs_dag_job_collector
) as dag:
sqs_collector = SQSSensor(
sqs_collector = SqsSensor(
task_id=TASK_ID_SQS_COLLECTOR,
sqs_queue=sqs_queue_url,
max_messages=10,
Expand Down
7 changes: 3 additions & 4 deletions src/dags/selenium_task_queuer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from datetime import timedelta
from datetime import datetime, timedelta

from airflow import DAG
from airflow.utils import timezone
from airflow.utils.dates import days_ago

from ergo.config import Config
from ergo.operators.sqs.sqs_task_pusher import SqsTaskPusherOperator
Expand All @@ -17,7 +16,7 @@
'depends_on_past': False,
'retries': 2,
'retry_delay': timedelta(minutes=1),
'start_date': days_ago(1),
'start_date': datetime(2024, 1, 1),

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🏢 Use timezone-aware start_date

A naive datetime can trigger scheduling warnings and run offsets; use Airflow’s timezone-aware datetime helper.

Suggested change
'start_date': datetime(2024, 1, 1),
'start_date': timezone.datetime(2024, 1, 1),

'priority_weight': 900,
}

Expand All @@ -30,7 +29,7 @@
'selenium_ergo_task_queuer',
default_args=default_args,
is_paused_upon_creation=False,
schedule_interval=timedelta(seconds=10),
schedule=timedelta(seconds=10),
catchup=False,
max_active_runs=max_concurrent_runs,
dagrun_timeout=timedelta(minutes=5)
Expand Down
11 changes: 3 additions & 8 deletions src/links/ergo_task_detail.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from airflow.models.baseoperator import BaseOperatorLink
from flask import url_for
from airflow.sdk import BaseOperatorLink

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import urlencode for safe query construction

The link now builds a query string manually; import urlencode so parameters are encoded safely.

Suggested change
from airflow.sdk import BaseOperatorLink
from urllib.parse import urlencode

from airflow.sdk import BaseOperatorLink



class ErgoTaskDetailLink(BaseOperatorLink):
Expand All @@ -9,9 +8,5 @@ class ErgoTaskDetailLink(BaseOperatorLink):
name = 'Ergo'

def get_link(self, operator, dttm):
return url_for(
'ErgoView.task_detail',
ti_task_id=operator.task_id,
ti_dag_id=operator.dag_id,
ti_execution_date=dttm
)
# In Airflow 3, the www module is gone; return a simple URL path
return f'/ergo/task_detail?ti_task_id={operator.task_id}&ti_dag_id={operator.dag_id}&ti_execution_date={dttm}'

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🏢 URL-encode task detail parameters

dttm renders with spaces/timezone and can create malformed URLs. Build the query using urlencode and use an ISO string; also handle None to avoid parse failures in the view.

Suggested change
# In Airflow 3, the www module is gone; return a simple URL path
return f'/ergo/task_detail?ti_task_id={operator.task_id}&ti_dag_id={operator.dag_id}&ti_execution_date={dttm}'
query = urlencode({
"ti_task_id": operator.task_id,
"ti_dag_id": operator.dag_id,
"ti_execution_date": dttm.isoformat() if dttm else ""
})
return f"/ergo/task_detail?{query}"

8 changes: 6 additions & 2 deletions src/migrations/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def run_migrations_offline():
script output.

"""
url = config.get_main_option("sqlalchemy.url")
url = os.getenv('AIRFLOW__CORE__SQL_ALCHEMY_CONN', config.get_main_option("sqlalchemy.url"))
context.configure(
url=url,
target_metadata=target_metadata,
Expand All @@ -75,8 +75,12 @@ def run_migrations_online():
and associate a connection with the context.

"""
cfg = config.get_section(config.config_ini_section)
db_url = os.getenv('AIRFLOW__CORE__SQL_ALCHEMY_CONN')
if db_url:
cfg['sqlalchemy.url'] = db_url
connectable = engine_from_config(
config.get_section(config.config_ini_section),
cfg,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
Comment on lines +83 to 86

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Guard against missing alembic config section

config.get_section(...) can return None; assigning to cfg['sqlalchemy.url'] would raise a TypeError and prevent migrations. Use a default dict when the section is missing.

Suggested change
cfg = config.get_section(config.config_ini_section)
db_url = os.getenv('AIRFLOW__CORE__SQL_ALCHEMY_CONN')
if db_url:
cfg['sqlalchemy.url'] = db_url
cfg = config.get_section(config.config_ini_section) or {}
db_url = os.getenv('AIRFLOW__CORE__SQL_ALCHEMY_CONN')
if db_url:
cfg['sqlalchemy.url'] = db_url

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def upgrade():
sa.Column('ti_task_id', sa.String(length=250), nullable=False),
sa.Column('ti_dag_id', sa.String(length=250), nullable=False),
sa.Column('ti_run_id', sa.String(length=250), nullable=False),
sa.ForeignKeyConstraint(['ti_task_id', 'ti_dag_id', 'ti_run_id'], ['task_instance.task_id', 'task_instance.dag_id', 'task_instance.run_id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('ti_task_id', 'ti_dag_id', 'ti_run_id', name='ix_unique_task_instance')
)
Expand Down
19 changes: 6 additions & 13 deletions src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
from functools import cached_property

from airflow.models.base import ID_LEN
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.sdk import timezone
from airflow.utils.sqlalchemy import UtcDateTime
from airflow.utils.state import State
from ergo import JobResultStatus
from sqlalchemy import (Column, ForeignKey, ForeignKeyConstraint, Integer,
from sqlalchemy import (Column, ForeignKey, Integer,
String, Text, UniqueConstraint)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import relationship

Base = declarative_base()
Expand Down Expand Up @@ -43,14 +41,9 @@ class ErgoTask(Base):
# task_instance = relationship(TaskInstance, back_populates='ergo_task')

__table_args__ = (
ForeignKeyConstraint(
(ti_task_id, ti_dag_id, ti_run_id),
(TaskInstance.task_id, TaskInstance.dag_id, TaskInstance.run_id),
ondelete='CASCADE'
),
UniqueConstraint(
ti_task_id, ti_dag_id, ti_run_id, name='ix_unique_task_instance'
)
),
)

def __str__(self):
Expand All @@ -76,8 +69,8 @@ class ErgoJob(Base):
unique=True
)
result_data = Column(Text, nullable=True)
result_code = Column(Integer, default=JobResultStatus.NONE,
nullable=False) # enum{JobResultStatus}
result_code = Column(Integer, default=0,
nullable=False) # enum{JobResultStatus} 0=NONE
_error_msg = Column('error_msg', Text, nullable=True)

created_at = Column(
Expand Down
9 changes: 3 additions & 6 deletions src/operators/deferred_job_result.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
from datetime import datetime, timedelta
from airflow.utils.db import provide_session
from airflow.utils.decorators import apply_defaults
from airflow.utils.session import provide_session
from airflow.utils.state import State
from airflow.models import BaseOperator
from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.temporal import TimeDeltaTrigger
from airflow.sdk.bases.sensor import BaseSensorOperator
from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger
from ergo.exceptions import ErgoFailedResultException
from ergo.models import ErgoJob, ErgoTask
from ergo.triggers.task_poll import TaskPollTrigger
from sqlalchemy.orm import joinedload
from airflow.triggers.temporal import TimeDeltaTrigger


class ErgoDeferredJobResult(BaseOperator):

@apply_defaults
def __init__(
self,
pusher_task_id: str,
Expand Down
8 changes: 3 additions & 5 deletions src/operators/ergo_task_producer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import json
from typing import Union, List, Tuple
from airflow.contrib.hooks.aws_sqs_hook import SQSHook
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
from airflow.models import BaseOperator
from ergo.links.ergo_task_detail import ErgoTaskDetailLink
from airflow.utils.db import provide_session
from airflow.utils.decorators import apply_defaults
from airflow.utils.session import provide_session
from airflow.utils.state import State
from ergo.models import ErgoJob, ErgoTask
from ergo.config import Config
Expand All @@ -17,7 +16,6 @@ class ErgoTaskQueuerOperator(BaseOperator):

operator_extra_links = (ErgoTaskDetailLink(),)

@apply_defaults
def __init__(
self,
ergo_task_callable: callable = None,
Expand Down Expand Up @@ -126,7 +124,7 @@ def execute(self, context, session=None):
session.commit()

def _send_to_sqs(self, queue_url, task) -> Tuple[List, List]:
sqs_client = SQSHook(aws_conn_id=self.aws_conn_id).get_conn()
sqs_client = SqsHook(aws_conn_id=self.aws_conn_id).get_conn()
self.log.info('Trying to push a message on queue: %s\n', queue_url)
self.log.info('Request task: %s', task.task_id)
entries = [
Expand Down
4 changes: 1 addition & 3 deletions src/operators/sqs/result_from_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

from airflow.models import BaseOperator
from airflow.utils import timezone
from airflow.utils.db import provide_session
from airflow.utils.decorators import apply_defaults
from airflow.utils.session import provide_session
from airflow.utils.state import State
from sqlalchemy.orm import joinedload

Expand All @@ -12,7 +11,6 @@


class JobResultFromMessagesOperator(BaseOperator):
@apply_defaults
def __init__(
self,
sqs_sensor_task_id: str,
Expand Down
8 changes: 3 additions & 5 deletions src/operators/sqs/sqs_task_pusher.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from typing import List, Tuple

from airflow.contrib.hooks.aws_sqs_hook import SQSHook
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
from airflow.models import BaseOperator
from airflow.utils.db import provide_session
from airflow.utils.decorators import apply_defaults
from airflow.utils.session import provide_session
from airflow.utils.state import State
from ergo.models import ErgoJob, ErgoTask

class SqsTaskPusherOperator(BaseOperator):
filter_ergo_task = ErgoTask.state.in_([State.SCHEDULED, State.UP_FOR_RESCHEDULE])

@apply_defaults
def __init__(
self,
task_id_collector: str,
Expand Down Expand Up @@ -66,7 +64,7 @@ def execute(self, context, session=None):

def _send_to_sqs(self, queue_url, query) -> Tuple[List, List]:
tasks = list(query)
sqs_client = SQSHook(aws_conn_id=self.aws_conn_id).get_conn()
sqs_client = SqsHook(aws_conn_id=self.aws_conn_id).get_conn()
self.log.info('Trying to push %d messages on queue: %s\n',
len(tasks), queue_url)
self.log.info('Request tasks: ' + '\n'.join([str(task) for task in tasks]))
Expand Down
6 changes: 2 additions & 4 deletions src/operators/task_producer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import json
from typing import Union

from airflow.contrib.hooks.aws_sqs_hook import SQSHook
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
from airflow.models import BaseOperator
from airflow.utils.db import provide_session
from airflow.utils.decorators import apply_defaults
from airflow.utils.session import provide_session

from ergo.config import Config
from ergo.links.ergo_task_detail import ErgoTaskDetailLink
Expand All @@ -16,7 +15,6 @@ class ErgoTaskProducerOperator(BaseOperator):

operator_extra_links = (ErgoTaskDetailLink(),)

@apply_defaults
def __init__(
self,
ergo_task_callable: callable = None,
Expand Down
Loading