From cec7570dde936297fda25c4f5b472835a913267a Mon Sep 17 00:00:00 2001 From: Max Dallabetta Date: Mon, 8 Jun 2026 19:54:19 +0200 Subject: [PATCH 1/8] refactor utility helpers and add concurrency and timing modules --- src/fundus/utils/concurrency.py | 88 +++++++++++++++++++++++++++++++ src/fundus/utils/events.py | 48 +++++++++++++++++ src/fundus/utils/serialization.py | 41 ++++++++++++++ src/fundus/utils/timeout.py | 81 ++++++++++------------------ src/fundus/utils/timing.py | 20 +++++++ 5 files changed, 225 insertions(+), 53 deletions(-) create mode 100644 src/fundus/utils/concurrency.py create mode 100644 src/fundus/utils/timing.py diff --git a/src/fundus/utils/concurrency.py b/src/fundus/utils/concurrency.py new file mode 100644 index 000000000..30d46ab95 --- /dev/null +++ b/src/fundus/utils/concurrency.py @@ -0,0 +1,88 @@ +import contextlib +import multiprocessing +from functools import lru_cache +from multiprocessing.managers import BaseManager +from threading import current_thread +from typing import Callable, Generic, Iterator, Optional, Tuple, TypeVar, cast + +import dill +from tqdm import tqdm +from typing_extensions import ParamSpec + +_T = TypeVar("_T") +_P = ParamSpec("_P") + + +def get_execution_context() -> Tuple[str, Optional[int]]: + """Return the name and identifier of the current execution context. + + If running inside a non-main process, returns that process's name and PID; otherwise + returns the current thread's name and thread id. + + Returns: + Tuple[str, Optional[int]]: The context's name and its integer identifier. + """ + if multiprocessing.current_process().name != "MainProcess": + process = multiprocessing.current_process() + return process.name, process.ident + else: + thread = current_thread() + return thread.name, thread.ident + + +class TQDMManager(BaseManager): + """multiprocessing manager exposing a shared tqdm proxy so worker processes drive one progress bar.""" + + def __init__(self, *args, **kwargs): + """Initialize the manager and register tqdm so it can be created behind a proxy.""" + super().__init__(*args, **kwargs) + self.register("_tqdm", tqdm) + + def tqdm(self, *args, **kwargs) -> tqdm: + """Create and return a manager-hosted (proxied) tqdm instance from the given tqdm args.""" + return getattr(self, "_tqdm")(*args, **kwargs) + + +@contextlib.contextmanager +def get_proxy_tqdm(*args, **kwargs) -> Iterator[tqdm]: + """Yield a manager-backed tqdm proxy that can be shared across processes. + + Init args are forwarded verbatim and are the same as for any other tqdm instance. The + backing manager is started on entry and shut down on exit. + + Args: + *args: Positional tqdm arguments. + **kwargs: Keyword tqdm arguments. + + Yields: + tqdm: A self-managed, proxied tqdm instance. + """ + manager = TQDMManager() + try: + manager.start() + yield manager.tqdm(*args, **kwargs) + finally: + manager.shutdown() + + +class dill_wrapper(Generic[_P, _T]): + """Callable wrapper that dill-serializes its target so it survives multiprocessing pickling.""" + + def __init__(self, target: Callable[_P, _T]): + """Wraps function in dill serialization. + + This is in order to use unpickable functions within multiprocessing. + + Args: + target: The function to wrap. + """ + self._serialized_target: bytes = dill.dumps(target) + + @lru_cache + def _deserialize(self) -> Callable[_P, _T]: + """Deserialize and cache the wrapped target on first use (once per process).""" + return cast(Callable[_P, _T], dill.loads(self._serialized_target)) + + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: + """Deserialize the target (cached) and invoke it with the given arguments.""" + return self._deserialize()(*args, **kwargs) diff --git a/src/fundus/utils/events.py b/src/fundus/utils/events.py index a9b40eb27..27d86df8f 100644 --- a/src/fundus/utils/events.py +++ b/src/fundus/utils/events.py @@ -7,12 +7,60 @@ from fundus.logging import create_logger +# TODO (planned redesign): replace __EVENTS__ with explicit CancellationToken objects. +# +# Current state. __EVENTS__ is a global registry that maps a string alias (publisher +# name, "main-thread") to a dict of named threading.Event objects, plus a bidict +# linking aliases to thread ids so callers running inside a thread context can resolve +# `key=None` to "their own" events. It mashes three concerns into one mechanism: +# 1. cooperative cancellation (per-publisher stop signal) +# 2. shutdown propagation (system-wide stop via set_for_all(future=True)) +# 3. post-mortem queryability (main thread asking "did publisher X already stop?" +# after its worker exited, hence aliases-persist-after-thread-exit) +# +# Pain points: implicit thread-id resolution, string-keyed events (only "stop" exists +# in practice), the `future=True` hack for shutdown, leaky test setup (every test that +# touches WebSource/CCNewsSource needs __EVENTS__.context("test") aliasing), and an +# unclear seam for multiprocessing (threading.Event does not cross process boundaries). +# +# Planned shape: +# +# class CancellationToken: +# def __init__(self) -> None: +# self._event = threading.Event() +# self._children: list[CancellationToken] = [] +# def cancel(self) -> None: ... +# def is_cancelled(self) -> bool: ... +# def wait(self, timeout: float) -> bool: ... +# def child(self) -> "CancellationToken": ... # cancelled when parent is +# +# Mapping current usage onto tokens: +# - Source classes (WebSource, CCNewsSource): receive a CancellationToken via +# constructor instead of reading from __EVENTS__. +# - Crawler: holds a `dict[Publisher, CancellationToken]`. On per-publisher limit +# reached, calls `tokens[publisher].cancel()`. Replaces __EVENTS__.set_event( +# "stop", publisher_name) and __EVENTS__.is_event_set(...) at the same site. +# - Shutdown: a root token; each publisher's token is `root.child()`. Cancelling +# root cancels all children — replaces set_for_all(future=True) / clear_for_all. +# - queueing.enqueue_results: takes the shutdown token, replaces the +# __EVENTS__.is_event_set("stop", __MAIN_THREAD_ALIAS__) probe in _delivered. +# - Session.get_with_interrupt: takes a CancellationToken, polls +# token.is_cancelled() instead of __EVENTS__. +# +# What disappears: aliases, the thread-id bidict, main_context_lock, string event +# keys, default_events, future=True / clear_for_all, the test-context fixture, and +# every `from fundus.utils.events import __EVENTS__` in non-orchestrator code. + _T = TypeVar("_T") logger = create_logger(__name__) __DEFAULT_EVENTS__: List[str] = ["stop"] +# Alias under which the main thread registers its context in __EVENTS__; crawlers set/probe +# the "stop" event against it to drive cooperative shutdown. +__MAIN_THREAD_ALIAS__ = "main-thread" + _sentinel = object() diff --git a/src/fundus/utils/serialization.py b/src/fundus/utils/serialization.py index 0b15da0a4..9d3286976 100644 --- a/src/fundus/utils/serialization.py +++ b/src/fundus/utils/serialization.py @@ -1,10 +1,13 @@ import inspect import json from dataclasses import asdict, fields, is_dataclass +from datetime import datetime from typing import ( Any, Callable, Dict, + Optional, + Protocol, Sequence, Type, TypeVar, @@ -12,6 +15,7 @@ get_args, get_origin, get_type_hints, + runtime_checkable, ) from typing_extensions import TypeAlias @@ -21,6 +25,43 @@ JSONVal: TypeAlias = Union[None, bool, str, float, int, Sequence["JSONVal"], Dict[str, "JSONVal"]] +@runtime_checkable +class Serializable(Protocol): + """Anything that knows how to convert itself into a JSON-compatible structure. + + Implementing types opt into the export path used by Article.to_json. + """ + + def serialize(self) -> JSONVal: ... + + +def serialize_value(value: Any, field_name: Optional[str] = None) -> JSONVal: + """Recursively convert a value to JSON-compatible form. + + Args: + value: The value to serialize. + field_name: Optional originating field name, used only for error messages. + + Returns: + A JSON-serializable structure. + + Raises: + TypeError: If the value's type has no defined serialization. + """ + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, datetime): + return value.isoformat() + if isinstance(value, (list, tuple)): + return [serialize_value(item, field_name) for item in value] + if isinstance(value, dict): + return {str(k): serialize_value(v, field_name) for k, v in value.items()} + if isinstance(value, Serializable): + return value.serialize() + location = f"field {field_name!r}" if field_name else "value" + raise TypeError(f"Cannot serialize {location} of type {type(value).__name__}") + + def is_jsonable(x): try: json.dumps(x) diff --git a/src/fundus/utils/timeout.py b/src/fundus/utils/timeout.py index 92b6e7d48..3e7d3121e 100644 --- a/src/fundus/utils/timeout.py +++ b/src/fundus/utils/timeout.py @@ -4,91 +4,67 @@ import time from typing import Callable, Iterator, Optional -from typing_extensions import ParamSpec -P = ParamSpec("P") - - -class Stopwatch: - def __init__(self): - self._start = time.time() +def _interrupt_handler() -> None: + thread.interrupt_main() - @property - def time(self) -> float: - return max(0.0, time.time() - self._start) - def reset(self): - self._start = time.time() +class ResettableTimer: + class _Stopwatch: + def __init__(self) -> None: + self._start = time.time() + @property + def elapsed(self) -> float: + return max(0.0, time.time() - self._start) -class ResettableTimer(threading.Thread): - def __init__( - self, - seconds: float, - func: Callable[P, None], - interval: float = 0.1, - args: P.args = tuple(), - kwargs: P.kwargs = None, - ) -> None: - """Resettable timer executing after