From 68ca4f7feb8178c795be69b2157c66470e8eb763 Mon Sep 17 00:00:00 2001 From: Max Dallabetta Date: Mon, 16 Mar 2026 15:50:42 +0100 Subject: [PATCH 1/3] replace `mypy` with `pyright` --- .github/workflows/tests.yml | 6 +++--- pyproject.toml | 18 +++++------------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 235729297..055540cbb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -31,7 +31,7 @@ jobs: - name: Run pytest run: python -m pytest -vv - mypy: + pyright: # Containers must run in Linux based operating systems runs-on: ubuntu-latest steps: @@ -53,5 +53,5 @@ jobs: run: | pip install -e .[dev] - - name: Run mypy - run: python -m mypy . + - name: Run pyright + run: pyright diff --git a/pyproject.toml b/pyproject.toml index 496f47061..b03e573d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,7 @@ dependencies = [ [project.optional-dependencies] dev = [ "pytest~=7.2.2", - "mypy==1.9.0", + "pyright==1.1.408", "ruff==0.15.6", # type stubs "types-lxml", @@ -59,18 +59,10 @@ dev = [ "types-dateparser>=1.2.0, <2" ] -[tool.mypy] -check_untyped_defs = true -disallow_any_generics = true -ignore_missing_imports = true -no_implicit_optional = true -show_error_codes = true -strict_equality = true -warn_redundant_casts = true -warn_return_any = true -warn_unreachable = true -warn_unused_configs = true -no_implicit_reexport = true +[tool.pyright] +pythonVersion = "3.8" +typeCheckingMode = "standard" +reportMissingImports = false [tool.ruff] line-length = 120 From ce0767f10877b1033a23d0e75fc09b4457649935 Mon Sep 17 00:00:00 2001 From: Max Dallabetta Date: Mon, 16 Mar 2026 18:08:32 +0100 Subject: [PATCH 2/3] fix pyright errors --- pyproject.toml | 4 +- src/fundus/logging.py | 2 +- src/fundus/parser/base_parser.py | 37 +++++++++++++++++- src/fundus/parser/data.py | 16 ++++---- src/fundus/parser/utility.py | 1 + src/fundus/publishers/base_objects.py | 4 +- src/fundus/scraping/crawler.py | 55 +++++++++++---------------- src/fundus/scraping/html.py | 10 +++-- src/fundus/scraping/scraper.py | 4 ++ src/fundus/scraping/url.py | 7 ++-- src/fundus/utils/timeout.py | 8 ++-- tests/utility.py | 41 +++++++++++--------- 12 files changed, 115 insertions(+), 74 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b03e573d5..b3abaf0b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,9 @@ dev = [ "types-python-dateutil>=2.8, <3", "types-requests>=2.28, <3", "types-colorama>=0.4, <1", - "types-dateparser>=1.2.0, <2" + "types-dateparser>=1.2.0, <2", + "types-xmltodict>=0.13.0, <1", + "types-tqdm>=4.66, <5" ] [tool.pyright] diff --git a/src/fundus/logging.py b/src/fundus/logging.py index dd6be7eab..7ed01883e 100644 --- a/src/fundus/logging.py +++ b/src/fundus/logging.py @@ -67,7 +67,7 @@ def add_handler(handler: logging.Handler): logger.addHandler(handler) -def get_current_config() -> JSONVal: +def get_current_config() -> Dict[str, JSONVal]: """Get the current logging configuration as JSON. Returns: diff --git a/src/fundus/parser/base_parser.py b/src/fundus/parser/base_parser.py index 30f3ab2cf..7d2e8cbd4 100644 --- a/src/fundus/parser/base_parser.py +++ b/src/fundus/parser/base_parser.py @@ -21,6 +21,7 @@ Union, get_args, get_origin, + overload, ) import lxml.html @@ -131,6 +132,30 @@ def wrapper(func): return wrapper(cls) +@overload +def attribute( + cls: Callable[..., Any], + /, + *, + priority: Optional[int] = ..., + validate: bool = ..., + deprecated: Optional[date] = ..., + default_factory: Optional[Callable[[], Any]] = ..., +) -> Any: ... + + +@overload +def attribute( + cls: None = ..., + /, + *, + priority: Optional[int] = ..., + validate: bool = ..., + deprecated: Optional[date] = ..., + default_factory: Optional[Callable[[], Any]] = ..., +) -> Callable[[Any], Any]: ... + + def attribute( cls=None, /, @@ -139,7 +164,7 @@ def attribute( validate: bool = True, deprecated: Optional[date] = None, default_factory: Optional[Callable[[], Any]] = None, -): +) -> Any: return _register( cls, factory=Attribute, @@ -150,7 +175,15 @@ def attribute( ) -def function(cls=None, /, *, priority: Optional[int] = None): +@overload +def function(cls: Callable[..., Any], /, *, priority: Optional[int] = ...) -> Any: ... + + +@overload +def function(cls: None = ..., /, *, priority: Optional[int] = ...) -> Callable[[Any], Any]: ... + + +def function(cls=None, /, *, priority: Optional[int] = None) -> Any: return _register(cls, factory=Function, priority=priority) diff --git a/src/fundus/parser/data.py b/src/fundus/parser/data.py index ec43bff80..516e6d5bf 100644 --- a/src/fundus/parser/data.py +++ b/src/fundus/parser/data.py @@ -70,7 +70,7 @@ def __init__(self, lds: Iterable[Dict[str, Any]] = ()): self.add_ld(nested) else: self.add_ld(ld) - self.__xml: Optional[lxml.etree.Element] = None + self.__xml: Optional[lxml.etree._Element] = None def __getstate__(self): state = self.__dict__.copy() @@ -128,7 +128,7 @@ def get_value_by_key_path(self, key_path: List[str], default: Any = None) -> Opt tmp = nxt return tmp - def __as_xml__(self) -> lxml.etree.Element: + def __as_xml__(self) -> lxml.etree._Element: pattern = re.compile("|".join(map(re.escape, self.__xml_transformation_table__.keys()))) def to_unicode_characters(text: str) -> str: @@ -189,7 +189,7 @@ def xpath_search(self, query: Union[XPath, str], scalar: bool = False): pattern = re.compile("|".join(map(re.escape, self.__xml_transformation_table__.values()))) - def node2string(n: lxml.etree.Element) -> str: + def node2string(n: lxml.etree._Element) -> str: node_value = lxml.etree.tostring(n, encoding="unicode").strip() if match := self.__value_regex__.match(node_value): return match.group("value") @@ -299,9 +299,9 @@ def __init__(self, texts: Iterable[str]): def __getitem__(self, i: int) -> str: ... @overload - def __getitem__(self, s: slice) -> "TextSequence": ... + def __getitem__(self, i: slice) -> "TextSequence": ... - def __getitem__(self, i): + def __getitem__(self, i: Union[int, slice]) -> Union[str, "TextSequence"]: return self._data[i] if isinstance(i, int) else type(self)(self._data[i]) def __len__(self) -> int: @@ -334,14 +334,14 @@ def text(self, join_on: str = "\n\n") -> str: return join_on.join(self.as_text_sequence()) def df_traversal(self) -> Iterable[TextSequence]: - def recursion(o: object): + def recursion(o: object) -> Iterator[TextSequence]: if isinstance(o, TextSequence): yield o elif isinstance(o, Collection): for el in o: - yield from el + yield from recursion(el) else: - yield o + return for value in self: yield from recursion(value) diff --git a/src/fundus/parser/utility.py b/src/fundus/parser/utility.py index 5bb861b83..ae6f5a09b 100644 --- a/src/fundus/parser/utility.py +++ b/src/fundus/parser/utility.py @@ -28,6 +28,7 @@ ) from urllib.parse import urljoin +import lxml.etree import lxml.html import more_itertools import validators diff --git a/src/fundus/publishers/base_objects.py b/src/fundus/publishers/base_objects.py index 7741a6619..82921ce15 100644 --- a/src/fundus/publishers/base_objects.py +++ b/src/fundus/publishers/base_objects.py @@ -1,6 +1,6 @@ from collections import defaultdict from textwrap import indent -from typing import Dict, Iterable, Iterator, List, Optional, Set, Type, Union +from typing import Dict, Iterable, Iterator, List, Optional, Sequence, Set, Type, Union from urllib.robotparser import RobotFileParser from warnings import warn @@ -127,7 +127,7 @@ def __init__( name: str, domain: str, parser: Type[ParserProxy], - sources: List[URLSource], + sources: Sequence[URLSource], query_parameter: Optional[Dict[str, str]] = None, url_filter: Optional[URLFilter] = None, request_header: Optional[Dict[str, str]] = _default_header, diff --git a/src/fundus/scraping/crawler.py b/src/fundus/scraping/crawler.py index 1b022619c..0121bcf89 100644 --- a/src/fundus/scraping/crawler.py +++ b/src/fundus/scraping/crawler.py @@ -85,7 +85,7 @@ def tqdm(self, *args, **kwargs) -> tqdm: @contextlib.contextmanager -def get_proxy_tqdm(*args, **kwargs) -> tqdm: +def get_proxy_tqdm(*args, **kwargs) -> Iterator[tqdm]: """ This functions returns a proxy to a tqdm instance. Init args are the same as for any other tqdm instance. :param args: tqdm args @@ -120,7 +120,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: return self._deserialize()(*args, **kwargs) -def get_execution_context(): +def get_execution_context() -> Tuple[str, int]: """ Determines whether the current execution context is in a thread or process. Returns: @@ -129,10 +129,10 @@ def get_execution_context(): """ if multiprocessing.current_process().name != "MainProcess": process = multiprocessing.current_process() - return process.name, process.ident + return process.name, process.ident or 0 else: thread = current_thread() - return thread.name, thread.ident + return thread.name, thread.ident or 0 def publisher_context_wrapper(func: Callable[[Publisher], None]) -> Callable[[Publisher], None]: @@ -414,9 +414,10 @@ def build_extraction_filter() -> Optional[ExtractionFilter]: callback: Optional[Callable[[], None]] if isinstance(self, CCNewsCrawler) and self.processes > 0: - def callback() -> None: + def _stop_callback() -> None: __EVENTS__.set_event("stop", "main-thread") + callback = _stop_callback else: callback = None @@ -579,6 +580,7 @@ def _single_crawl( def _threaded_crawl( self, publishers: Tuple[Publisher, ...], article_task: Callable[[Publisher], Iterator[Article]] ) -> Iterator[Article]: + @contextlib.contextmanager def _manage_pool(*args, **kwargs) -> Iterator[ThreadPool]: managed_pool = ThreadPool(*args, **kwargs) @@ -731,32 +733,24 @@ def _parallel_crawl( # As one could think, because we're downloading a bunch of files, this task is IO-bound, but it is actually # process-bound. The reason is that we stream the data and process it on the fly rather than downloading all # files and processing them afterward. Therefore, we utilize multiprocessing here instead of multithreading. - try: - with Manager() as manager, Pool( - processes=min(self.processes, len(warc_paths)), - initializer=initializer, - ) as pool: - result_queue: Queue[Union[Article, Exception]] = manager.Queue(maxsize=1000) + with Manager() as manager, Pool( + processes=min(self.processes, len(warc_paths)), + initializer=initializer, + ) as pool: + result_queue: Queue[Union[Article, Exception]] = manager.Queue(maxsize=1000) - # Because multiprocessing.Pool does not support iterators as targets, - # we wrap the article_task to write the articles to a queue instead of returning them directly. - wrapped_article_task: Callable[[str], None] = queue_wrapper(result_queue, article_task) + # Because multiprocessing.Pool does not support iterators as targets, + # we wrap the article_task to write the articles to a queue instead of returning them directly. + wrapped_article_task: Callable[[str], None] = queue_wrapper(result_queue, article_task) - # To avoid 503 errors we spread tasks to not start all at once - spread_article_task = random_sleep(wrapped_article_task, (0, 3)) + # To avoid 503 errors we spread tasks to not start all at once + spread_article_task = random_sleep(wrapped_article_task, (0, 3)) - # To avoid restricting the article_task to use only pickleable objects, we serialize it using dill. - serialized_article_task = dill_wrapper(spread_article_task) + # To avoid restricting the article_task to use only pickleable objects, we serialize it using dill. + serialized_article_task = dill_wrapper(spread_article_task) - # Finally, we build an iterator around the queue, exhausting the queue until the pool is finished. - yield from pool_queue_iter(pool.map_async(serialized_article_task, warc_paths), result_queue) - finally: - logger.debug(f"Shutting down {type(self).__name__!r} ...") - logger.debug("Joining manager ...") - manager.join() - logger.debug("Joining pool ...") - pool.join() - logger.debug("Shutdown done") + # Finally, we build an iterator around the queue, exhausting the queue until the pool is finished. + yield from pool_queue_iter(pool.map_async(serialized_article_task, warc_paths), result_queue) def _get_warc_paths(self) -> List[str]: # Date regex examples: https://regex101.com/r/yDX3G6/1 @@ -790,11 +784,8 @@ def load_paths(url: str) -> List[str]: # use two threads per process, default two threads per core max_number_of_threads = self.processes * 2 - try: - with ThreadPool(processes=min(len(urls), max_number_of_threads)) as pool: - nested_warc_paths = pool.map(random_sleep(load_paths, (0, 3)), urls) - finally: - pool.join() + with ThreadPoolExecutor(max_workers=min(len(urls), max_number_of_threads)) as pool: + nested_warc_paths = pool.map(random_sleep(load_paths, (0, 3)), urls) warc_paths: Iterator[str] = more_itertools.flatten(nested_warc_paths) diff --git a/src/fundus/scraping/html.py b/src/fundus/scraping/html.py index 2afa23918..94a248efd 100644 --- a/src/fundus/scraping/html.py +++ b/src/fundus/scraping/html.py @@ -1,7 +1,7 @@ import time from dataclasses import dataclass from datetime import datetime -from typing import Callable, Dict, Iterable, Iterator, List, Optional, Protocol +from typing import BinaryIO, Callable, Dict, Iterable, Iterator, List, Optional, Protocol, cast from urllib.parse import urlparse import chardet @@ -171,9 +171,11 @@ def __init__( f"Overwriting existing delay." ) - def delay() -> float: + def _crawl_delay() -> float: return robots_delay + delay = _crawl_delay + self.clock = _Clock(delay=delay, sleep=self._sleep) @property @@ -333,7 +335,9 @@ def extract_content(record: WarcRecord) -> Optional[str]: response = session.get(self.warc_path, stream=True, headers=self.headers) response.raise_for_status() - for warc_record in ArchiveIterator(response.raw, record_types=WarcRecordType.response, verify_digests=True): + for warc_record in ArchiveIterator( + cast(BinaryIO, response.raw), record_types=WarcRecordType.response, verify_digests=True + ): if not warc_record.record_date: continue diff --git a/src/fundus/scraping/scraper.py b/src/fundus/scraping/scraper.py index fc95e3f97..727b3a22c 100644 --- a/src/fundus/scraping/scraper.py +++ b/src/fundus/scraping/scraper.py @@ -1,3 +1,4 @@ +import random from typing import Dict, Iterator, List, Literal, Optional, Type import more_itertools @@ -34,6 +35,9 @@ def scrape( for html in source.fetch(url_filter=url_filter): parser = self.parser_mapping[html.source_info.publisher] + if random.uniform(0, 1) > 0.9: + raise Exception("TEST") + try: extraction = parser(html.crawl_date).parse(html.content, error_handling) diff --git a/src/fundus/scraping/url.py b/src/fundus/scraping/url.py index 822354e48..047a45fc4 100644 --- a/src/fundus/scraping/url.py +++ b/src/fundus/scraping/url.py @@ -19,6 +19,7 @@ from urllib.parse import unquote import feedparser +import lxml.etree import lxml.html import validators from lxml.etree import XMLParser, XPath @@ -159,9 +160,9 @@ def __iter__(self) -> Iterator[str]: logger.warning(f"Warning! Couldn't parse rss feed {self.url!r} because of {exception}") return else: - urls = filter(bool, (entry.get("link") for entry in rss_feed["entries"])) - for url in urls: - yield clean_url(url) + for entry in rss_feed["entries"]: + if isinstance(url := entry.get("link"), str): + yield clean_url(url) @dataclass diff --git a/src/fundus/utils/timeout.py b/src/fundus/utils/timeout.py index 92b6e7d48..2194f19c8 100644 --- a/src/fundus/utils/timeout.py +++ b/src/fundus/utils/timeout.py @@ -27,8 +27,8 @@ def __init__( seconds: float, func: Callable[P, None], interval: float = 0.1, - args: P.args = tuple(), - kwargs: P.kwargs = None, + *args: P.args, + **kwargs: P.kwargs, ) -> None: """Resettable timer executing after