`.
+ Note that this can only be found within the actual sitemap and not the index map.
#### Filter noisy sitemaps
@@ -255,12 +259,12 @@ sitemap_filter=inverse(regex_filter("sitemap-content-"))
````
will exclude all sitemap URLs not containing the substring `sitemap-content-`.
-### Finishing the Publisher Specification
+### Finishing the publisher specification
-1. If your publisher requires to use custom request headers to work properly you can alter it by using the `request_header` parameter of `PublisherSpec`.
+1. If your publisher requires custom request headers to work properly you can set them using the `request_header` parameter of `Publisher`.
The default is: `{"user-agent": "Fundus/2.0 (contact: github.com/flairnlp/fundus)"}`.
2. If you want to block URLs for the entire publisher use the `url_filter` parameter of `Publisher`.
-3. In some cases it can be necessary to append query parameters to the end of the URL, e.g. to load the article as one page. This can be achieved by adding the `query_parameter` attribute of `PublisherSpec` and assigning it a dictionary object containing the key - value pairs: e.g. `{"page": "all"}`. These key - value pairs will be appended to all crawled URLs.
+3. In some cases it can be necessary to append query parameters to the end of the URL, e.g. to load the article as one page. This can be achieved by setting the `query_parameter` parameter of `Publisher` and assigning it a dictionary containing the key-value pairs, e.g. `{"page": "all"}`. These key-value pairs will be appended to all crawled URLs.
4. If the publisher is only reachable through a browser-like TLS/HTTP fingerprint (i.e. plain `requests`/`curl` get blocked by an anti-bot layer such as Cloudflare or Akamai), you can declare a browser profile via the `impersonate` parameter, e.g. `impersonate="chrome"`. See [curl_cffi's supported targets](https://curl-cffi.readthedocs.io/en/latest/impersonate/targets.html) for the full list.
Because browser impersonation is an opt-in feature on the user side (see [Browser impersonation](5_advanced_topics.md#browser-impersonation)), the profile only takes effect when the user constructs the `Crawler` with `impersonate=True`; with the default `impersonate=False` your publisher will be requested without impersonation and will likely fail. Only set this when the publisher genuinely cannot be crawled without it.
@@ -284,7 +288,7 @@ class US(PublisherGroup):
)
```
-## 4. Validating the Current Implementation Progress
+## 3. Validating the current implementation progress
Now validate your implementation progress by crawling some example articles from your publisher.
The following script fits The Intercept and is adaptable by changing the publisher variable accordingly.
@@ -319,14 +323,14 @@ Fundus-Article:
Since we didn't add any specific implementation to the parser yet, most entries are empty.
-## 5. Implementing the Parser
+## 4. Implementing the parser
Now bring your parser to life and define the attributes you want to extract.
One important caveat to consider is the type of content on a particular page.
Some news outlets feature live tickers, displaying podcasts, or hub sites that link to other pages but are not articles themselves.
-At this stage, there's no need to concern yourself with handling non-article pages.
-our parser should concentrate on extracting desired attributes from most pages that can be classified as articles.
+At this stage, there's no need to concern yourself with handling non-article pages.
+Your parser should concentrate on extracting the desired attributes from most pages that can be classified as articles.
Pages lacking the desired attributes will be filtered out by the library during a later phase of the processing pipeline.
You can add attributes by decorating the methods of your parser with the `@attribute` decorator.
@@ -337,8 +341,8 @@ There you can locate an attribute named `title`, which precisely corresponds to
It is essential to adhere to the specified return types, as they are enforced through our unit tests.
While you're welcome to experiment locally, contributions to the repository won't be accepted if your pull request deviates from the guidelines.
-**_NOTE:_**
-Should you wish to add an attribute not covered in the guidelines, set the `validate` parameter of the attribute decorator to `False`, like this:
+> [!NOTE]
+> Should you wish to add an attribute not covered in the guidelines, set the `validate` parameter of the attribute decorator to `False`, like this:
``` python
@attribute(validate=False)
@@ -373,10 +377,10 @@ This is a title
This is a title
```
-Fundus will automatically add your decorated attributes as instance attributes to the `article` object during parsing.
-Additionally, attributes defined in the attribute guidelines are explicitly defined as `dataclasses.fields`.
+Fundus will automatically expose your decorated attributes on the `article` object during parsing.
+Attributes defined in the attribute guidelines are additionally available as typed properties of `Article`, each with a default value, so they can be accessed safely even on articles whose parser didn't extract them.
-### Extracting Attributes from Precomputed
+### Extracting attributes from Precomputed
One way to extract useful information from articles rather than placeholders is to utilize the `ld` and `meta` attributes of the `Article`.
These attributes are automatically extracted when they are present in the currently parsed HTML.
@@ -412,10 +416,11 @@ For instance, to extract the title for an article in The Intercept, we can acces
return self.precomputed.ld.get_value_by_key_path(["NewsArticle", "headline"])
```
-**_NOTE:_** In case a `class` is present in the HTML `meta` tag, it will be appended as a namespace to avoid collisions.
-I.e. the content of the following meta tag ` [!NOTE]
+> In case a `class` is present in the HTML `meta` tag, it will be appended as a namespace to avoid collisions.
+> I.e. the content of the following meta tag ` [!NOTE]
+> The nodes are returned in depth-first pre-order.
Similarly, you can select based on the `class` attribute of a tag.
For instance, selecting all `` tags with class `A` looks like this.
@@ -537,8 +543,9 @@ Output:
This is a paragraph with a weird attribute
````
-**_NOTE:_** It's also possible to select solely by the existence of an attribute by omitting the equality.
-Sticking to the above example you can simply use `CSSSelector("p[additional-attribute]")` instead.
+> [!NOTE]
+> It's also possible to select solely by the existence of an attribute by omitting the equality.
+> Sticking to the above example you can simply use `CSSSelector("p[additional-attribute]")` instead.
#### XPath
@@ -546,11 +553,13 @@ Sticking to the above example you can simply use `CSSSelector("p[additional-attr
Given the complexity of XPath compared to CSS-Select, we refrain from providing an extensive tutorial here.
Instead, we recommend referring to [this](https://devhints.io/xpath) documentation for a translation table and a concise overview of XPath functionalities beyond CSS-Select.
-**_NOTE:_** Although it's possible to select nodes using the built-in methods of `lxml.html.HtmlElement`, it's recommended to use the dedicated selectors [`CSSSelect`](https://lxml.de/cssselect.html) and [`XPath`](https://lxml.de/xpathxslt.html), as demonstrated in the above examples.
+> [!NOTE]
+> Although it's possible to select nodes using the built-in methods of `lxml.html.HtmlElement`, it's recommended to use the dedicated selectors [`CSSSelect`](https://lxml.de/cssselect.html) and [`XPath`](https://lxml.de/xpathxslt.html), as demonstrated in the above examples.
-**_NOTE:_** The `fundus/parser/utility.py` module includes several utility functions that can assist you in implementing parser attributes.
-Make sure to examine other parsers and consult the [attribute guidelines](attribute_guidelines.md) for specifics on attribute implementation.
-We strongly encourage utilizing these utility functions, especially when parsing the `ArticleBody`.
+> [!NOTE]
+> The `fundus/parser/utility.py` module includes several utility functions that can assist you in implementing parser attributes.
+> Make sure to examine other parsers and consult the [attribute guidelines](attribute_guidelines.md) for specifics on attribute implementation.
+> We strongly encourage utilizing these utility functions, especially when parsing the `ArticleBody`.
### Extracting the ArticleBody
@@ -620,7 +629,7 @@ def free_access(self) -> bool:
Usually you can identify a premium article by an indicator within the URL or by using XPath or CSSSelector and selecting
the element asking to purchase a subscription to view the article.
-### Finishing the Parser
+### Finishing the parser
Bringing all the above together, the The Intercept Parser now looks like this.
@@ -682,7 +691,7 @@ class TheInterceptParser(ParserProxy):
```
-Now, execute the example script from step 4 to validate your implementation.
+Now, execute the example script from step 3 to validate your implementation.
If the attributes are implemented correctly, they appear in the printout accordingly.
```console
@@ -700,7 +709,7 @@ Fundus-Article:
- From: The Intercept (2024-06-06 17:16)
```
-## 6. Generate unit tests and update tables
+## 5. Generate unit tests and update tables
### Add unit tests
@@ -719,7 +728,7 @@ Then in most cases it should be enough to simply run
python -m scripts.generate_parser_test_files -p
````
-with being the class name of the `Publisher` your working on.
+with being the class name of the `Publisher` you're working on.
In our case, we would run:
@@ -729,8 +738,9 @@ python -m scripts.generate_parser_test_files -p TheIntercept
to generate a unit test for our parser.
-Note: If you need to modify your parser slightly after already adding a unit test, there's no need to create a new test case and load a new HTML file.
-You can simply run the script with the `-oj` flag.
+> [!NOTE]
+> If you need to modify your parser slightly after already adding a unit test, there's no need to create a new test case and load a new HTML file.
+> You can simply run the script with the `-oj` flag.
In our scenario, the command would be:
@@ -755,14 +765,14 @@ Now to test your newly added publisher you should run pytest with the following
pytest
````
-## 7. Opening a Pull Request
+## 6. Opening a pull request
1. Make sure you tested your parser using `pytest`.
2. Run `ruff format src`, `ruff check --fix src`, and `mypy src` with no errors.
3. Push and open a new PR
-4. Congratulation and thank you very much.
+4. Congratulations and thank you very much.
-## 8. Maintaining publishers
+## 7. Maintaining publishers
Website layouts change over time, so we may occasionally need to update a publisher's parser.
If you run into an issue, feel free to correct it and submit a pull request (PR).
diff --git a/docs/how_to_contribute.md b/docs/how_to_contribute.md
index 7ebf80019..4ef37650d 100644
--- a/docs/how_to_contribute.md
+++ b/docs/how_to_contribute.md
@@ -29,7 +29,7 @@ If you haven't done this yet or are uncertain, follow these steps:
3. Navigate to the root of the repository.
4. Run `pip install -e .[dev]`
-## Known issues:
+## Known issues
1. `zsh: no matches found: .[dev]`
When using zsh, you have to wrap the optional dependencies in quotes like this: `pip install -e .'[dev]'`.
@@ -40,4 +40,5 @@ See [this issue](https://github.com/mu-editor/mu/issues/852#issue-451861103) for
1. [How to add a publisher](how_to_add_a_publisher.md)
-**_NOTE:_** If you run into any problems while contributing don't hesitate to ask questions in the [**issue**](https://github.com/flairNLP/fundus/issues) tab.
\ No newline at end of file
+> [!NOTE]
+> If you run into any problems while contributing don't hesitate to ask questions in the [**issue**](https://github.com/flairNLP/fundus/issues) tab.
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 6d79a93da..c5423e87c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -88,4 +88,7 @@ quote-style = "double"
filterwarnings = [
"error"
]
+markers = [
+ "integration: slow integration tests requiring mocked I/O",
+]
diff --git a/scripts/generate_parser_test_files.py b/scripts/generate_parser_test_files.py
index daf49e187..72b387f53 100644
--- a/scripts/generate_parser_test_files.py
+++ b/scripts/generate_parser_test_files.py
@@ -11,9 +11,9 @@
from fundus.publishers.base_objects import Publisher
from fundus.scraping.article import Article
from fundus.scraping.filter import RequiresAll
-from fundus.scraping.html import WebSource
-from fundus.scraping.scraper import BaseScraper
-from tests.test_parser import attributes_required_to_cover
+from fundus.scraping.pipeline import Pipeline
+from fundus.scraping.pipeline.source.web import WebSource
+from tests.publishers.test_parser_coverage import attributes_required_to_cover
from tests.utility import HTMLTestFile, get_test_case_json, load_html_test_file_mapping
logger = create_logger(__name__)
@@ -22,11 +22,11 @@
def get_test_article(publisher: Publisher, url: Optional[str] = None) -> Optional[Article]:
if url is not None:
source = WebSource([url], publisher=publisher)
- scraper = BaseScraper(source, parser_mapping={publisher.name: publisher.parser})
- return next(scraper.scrape(error_handling="suppress", extraction_filter=RequiresAll()), None)
+ pipeline = Pipeline(source, publishers=[publisher])
+ return next(pipeline.run(raise_on_error=False, extraction_filter=RequiresAll()), None)
crawler = Crawler(publisher)
- return next(crawler.crawl(max_articles=1, error_handling="suppress", only_complete=RequiresAll()), None)
+ return next(crawler.crawl(max_articles=1, only_complete=RequiresAll()), None)
def parse_arguments() -> Namespace:
diff --git a/scripts/publisher_coverage.py b/scripts/publisher_coverage.py
index 4d9acdb59..173ec1495 100644
--- a/scripts/publisher_coverage.py
+++ b/scripts/publisher_coverage.py
@@ -8,7 +8,7 @@
import sys
import traceback
from argparse import ArgumentParser
-from typing import Any, Callable, List, Optional, Union
+from typing import List, Optional
from fundus import Crawler, PublisherCollection
from fundus.publishers.base_objects import Publisher, PublisherGroup
@@ -55,58 +55,37 @@ def main() -> None:
crawler: Crawler = Crawler(publisher, delay=0.4, ignore_robots=True)
complete_article: Optional[Article] = next(
- crawler.crawl(
- max_articles=1, timeout=timeout_in_seconds, only_complete=True, error_handling="suppress"
- ),
+ crawler.crawl(max_articles=1, timeout=timeout_in_seconds, only_complete=True),
None,
)
if complete_article is None:
- incomplete_article: Optional[Article] = next(
- crawler.crawl(
- max_articles=1, timeout=timeout_in_seconds, only_complete=False, error_handling="catch"
- ),
- None,
- )
+ try:
+ incomplete_article: Optional[Article] = next(
+ crawler.crawl(
+ max_articles=1, timeout=timeout_in_seconds, only_complete=False, raise_on_error=True
+ ),
+ None,
+ )
+ except Exception as exception:
+ print(f"❌ FAILED: {publisher_name!r} - Encountered exception during crawling")
+ traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stdout)
+ failed += 1
+ continue
if incomplete_article is None:
print(f"❌ FAILED: {publisher_name!r} - No articles received")
- elif incomplete_article.exception is not None:
- print(
- f"❌ FAILED: {publisher_name!r} - Encountered exception during crawling "
- f"(URL: {incomplete_article.html.requested_url})"
- )
- traceback.print_exception(
- etype=type(incomplete_article.exception),
- value=incomplete_article.exception,
- tb=incomplete_article.exception.__traceback__,
- file=sys.stdout,
- )
-
else:
-
- def guard(field, fnc: Callable[[Any], bool] = lambda x: x is not None) -> Union[bool, str]:
- """Makes a boolean evaluation of based on and guards exceptions
-
- Args:
- field: The article field to evaluate
- fnc: The evaluation function
-
- Returns:
- Either True, False or Exception if isinstance(field, Exception) = True
- """
- return fnc(field) if not isinstance(field, Exception) else repr(field)
-
print(
f"❌ FAILED: {publisher_name!r} - No complete articles received "
f"(URL of an incomplete article: {incomplete_article.html.requested_url}) with attributes:\n"
- f"title: {guard(incomplete_article.title)}\n"
- f"plaintext: {guard(incomplete_article.body, bool)}\n"
- f"publishing_date: {guard(incomplete_article.publishing_date)}\n"
- f"authors: {guard(incomplete_article.authors, bool)}\n"
- f"topics: {guard(incomplete_article.topics, bool)}\n"
- f"images: {guard(incomplete_article.images, bool)}\n"
+ f"title: {incomplete_article.title is not None}\n"
+ f"plaintext: {bool(incomplete_article.body)}\n"
+ f"publishing_date: {incomplete_article.publishing_date is not None}\n"
+ f"authors: {bool(incomplete_article.authors)}\n"
+ f"topics: {bool(incomplete_article.topics)}\n"
+ f"images: {bool(incomplete_article.images)}\n"
)
failed += 1
continue
diff --git a/src/fundus/parser/base_parser.py b/src/fundus/parser/base_parser.py
index 30f3ab2cf..a54b75ab6 100644
--- a/src/fundus/parser/base_parser.py
+++ b/src/fundus/parser/base_parser.py
@@ -13,7 +13,6 @@
Dict,
Iterator,
List,
- Literal,
Optional,
Tuple,
Type,
@@ -278,7 +277,7 @@ def _base_setup(self, html: str) -> None:
doc = lxml.html.document_fromstring(html)
self.precomputed = Precomputed(html, doc, get_meta_content(doc), get_ld_content(doc))
- def parse(self, html: str, error_handling: Literal["suppress", "catch", "raise"] = "raise") -> Dict[str, Any]:
+ def parse(self, html: str, raise_on_error: bool = True) -> Dict[str, Any]:
# wipe existing precomputed
self._base_setup(html)
@@ -294,18 +293,13 @@ def parse(self, html: str, error_handling: Literal["suppress", "catch", "raise"]
try:
parsed_data[attribute_name] = func()
except Exception as err:
- if error_handling == "suppress":
- parsed_data[attribute_name] = func.__default__
- logger.info(
- f"Couldn't parse attribute {attribute_name!r} for "
- f"{self.precomputed.meta.get('og:url')!r}: {err!r}"
- )
- elif error_handling == "catch":
- parsed_data[attribute_name] = err
- elif error_handling == "raise":
+ if raise_on_error:
raise err
- else:
- raise ValueError(f"Invalid value {error_handling!r} for parameter ")
+ parsed_data[attribute_name] = func.__default__
+ logger.info(
+ f"Couldn't parse attribute {attribute_name!r} for "
+ f"{self.precomputed.meta.get('og:url')!r}: {err!r}"
+ )
else:
raise TypeError(f"Invalid type for {func}. Only subclasses of 'RegisteredFunction' are allowed")
diff --git a/src/fundus/parser/data.py b/src/fundus/parser/data.py
index 819869df3..6f82e6300 100644
--- a/src/fundus/parser/data.py
+++ b/src/fundus/parser/data.py
@@ -20,7 +20,6 @@
Union,
overload,
)
-from urllib.parse import urljoin, urlparse
import lxml.etree
import lxml.html
@@ -30,7 +29,7 @@
from lxml.etree import XPath, fromstring, tostring
from typing_extensions import Self, TypeAlias, deprecated
-from fundus.scraping.url import is_valid_url
+from fundus.scraping.url import is_valid_url, strip_query_and_fragment
from fundus.utils.serialization import (
DataclassSerializationMixin,
JSONVal,
@@ -457,12 +456,6 @@ def from_ratio(
return None
-def remove_query_parameters_from_url(url: str) -> str:
- if any(parameter_indicator in url for parameter_indicator in ("?", "#")):
- return urljoin(url, urlparse(url).path)
- return url
-
-
@total_ordering
@dataclass
class ImageVersion(DataclassSerializationMixin):
@@ -475,7 +468,7 @@ class ImageVersion(DataclassSerializationMixin):
def __post_init__(self):
if not self.type:
- url_without_query = remove_query_parameters_from_url(self.url)
+ url_without_query = strip_query_and_fragment(self.url)
self.type = self._parse_type(url_without_query)
def _parse_type(self, url: str) -> Optional[str]:
diff --git a/src/fundus/parser/utility.py b/src/fundus/parser/utility.py
index fbcffd9ab..6565ad788 100644
--- a/src/fundus/parser/utility.py
+++ b/src/fundus/parser/utility.py
@@ -46,7 +46,6 @@
LinkedDataMapping,
TextSequence,
)
-from fundus.scraping.url import is_valid_url
from fundus.utils.regex import _get_match_dict
from fundus.utils.serialization import JSONVal
@@ -607,15 +606,6 @@ def parse_title_from_root(root: lxml.html.HtmlElement) -> Optional[str]:
return strip_nodes_to_text(title_node)
-def preprocess_url(url: str, domain: str) -> str:
- url = re.sub(r"\\/", "/", url)
- # Some publishers use relative URLs
- if not is_valid_url(url):
- publisher_domain = "https://" + domain
- url = urljoin(publisher_domain, url)
- return url
-
-
def image_author_parsing(authors: Union[str, List[str]]) -> List[str]:
credit_keywords = [
"Источник",
diff --git a/src/fundus/publishers/base_objects.py b/src/fundus/publishers/base_objects.py
index 83fd20fc2..10ce798c2 100644
--- a/src/fundus/publishers/base_objects.py
+++ b/src/fundus/publishers/base_objects.py
@@ -67,7 +67,10 @@ def read(self) -> None:
" Defaulting to disallow all."
)
self.disallow_all = True
- elif 400 <= err.response.status_code < 500:
+ else:
+ # Any other HTTP error — a 4xx without a robots.txt, or a 5xx server error —
+ # leaves us with no retrievable rules, so default to allow-all rather than an
+ # unset parser state. (Inside this except, raise_for_status guarantees >= 400.)
self.allow_all = True
else:
self.parse(response.text.splitlines())
@@ -231,6 +234,9 @@ def source_types(self) -> Set[Type[URLSource]]:
def __str__(self) -> str:
return f"{self.name}"
+ def serialize(self) -> str:
+ return self.name
+
def __hash__(self) -> int:
return hash(self.name)
diff --git a/src/fundus/scraping/article.py b/src/fundus/scraping/article.py
index a64502bc0..95db316a9 100644
--- a/src/fundus/scraping/article.py
+++ b/src/fundus/scraping/article.py
@@ -1,6 +1,6 @@
from datetime import datetime
from textwrap import TextWrapper, dedent
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, ClassVar, Dict, List, Optional, Tuple, TypedDict, cast
import langdetect
import lxml.html
@@ -9,89 +9,138 @@
from fundus.logging import create_logger
from fundus.parser import ArticleBody, Image
from fundus.scraping.html import HTML
-from fundus.utils.serialization import JSONVal, is_jsonable
+from fundus.utils.serialization import JSONVal, serialize_value
logger = create_logger(__name__)
-class AttributeView:
- def __init__(self, key: str, extraction: Mapping[str, Any]):
- self.ref = extraction
- self.key = key
+class Extraction(TypedDict, total=False):
+ """Schema for the narrowly-typed subset of extraction keys.
- def __get__(self, instance: object, owner: type):
- return self.ref[self.key]
+ Parsers may pass additional keys; those live in __extraction__ alongside these
+ and are exposed via __getattr__ with type Any. Only the keys declared here are
+ type-checked at the property accessors.
+ """
- def __set__(self, obj, value):
- # For now, this is read-only
- raise AttributeError("attribute is read only")
+ # TODO: once PEP 728 (https://peps.python.org/pep-0728/) is accepted and supported
+ # by our mypy version, inherit from typing_extensions.TypedDict and add the
+ # `extra_items=Any` parameter. That lets us drop the `_narrow` cast workaround and
+ # annotate __init__ kwargs as `**extraction: Unpack[Extraction]` while still
+ # accepting parser-specific extras.
+
+ title: Optional[str]
+ body: Optional[ArticleBody]
+ authors: List[str]
+ publishing_date: Optional[datetime]
+ topics: List[str]
+ free_access: bool
+ images: List[Image]
class Article:
- __extraction__: Mapping[str, Any] = {}
+ """A parsed news article: the source HTML plus the parser's extracted attributes.
+
+ Declared attributes (title, body, authors, publishing_date, topics, free_access,
+ images) are exposed as type-checked properties; any extra keys a parser returns are
+ accessible as read-only attributes via __getattr__. Derived properties (plaintext,
+ lang, publisher) are computed on access. Use to_json() to export selected fields.
+ """
+
+ DEFAULT_EXPORT_FIELDS: ClassVar[Tuple[str, ...]] = (
+ "title",
+ "authors",
+ "publishing_date",
+ "topics",
+ "free_access",
+ "body",
+ "images",
+ "plaintext",
+ "lang",
+ "publisher",
+ )
+
+ def __init__(self, *, html: HTML, **extraction: Any) -> None:
+ """Build an article from its source HTML and the parser's extracted attributes.
+
+ Args:
+ html (HTML): The source document the article was parsed from.
+ **extraction (Any): Attributes produced by the parser (e.g. title, body,
+ authors). Declared keys are surfaced through typed properties; any
+ additional keys are exposed as read-only attributes via __getattr__.
- def __init__(self, *, html: HTML, exception: Optional[Exception] = None, **extraction: Any) -> None:
+ """
self.html = html
- self.exception = exception
- self.__extraction__ = extraction
+ self.__extraction__: Dict[str, Any] = extraction
- # create descriptors for attributes that aren't pre-defined as properties.
- for attribute in extraction.keys():
- if not hasattr(self, attribute):
- setattr(self, attribute, AttributeView(attribute, self.__extraction__))
+ @property
+ def _narrow(self) -> Extraction:
+ """View of __extraction__ restricted to the narrowly-typed schema.
+
+ Storage stays Dict[str, Any] because the dict legitimately holds parser-extras
+ outside the schema. This cast applies the schema only where it's true: at the
+ narrow accessors below.
+ """
+ return cast(Extraction, self.__extraction__)
@property
def title(self) -> Optional[str]:
- return self.__extraction__.get("title")
+ return self._narrow.get("title")
@property
def body(self) -> Optional[ArticleBody]:
- return self.__extraction__.get("body")
+ return self._narrow.get("body")
@property
def authors(self) -> List[str]:
- return self.__extraction__.get("authors", [])
+ return self._narrow.get("authors", [])
@property
def publishing_date(self) -> Optional[datetime]:
- return self.__extraction__.get("publishing_date")
+ return self._narrow.get("publishing_date")
@property
def topics(self) -> List[str]:
- return self.__extraction__.get("topics", [])
+ return self._narrow.get("topics", [])
@property
def free_access(self) -> bool:
- return self.__extraction__.get("free_access", False)
+ return self._narrow.get("free_access", False)
@property
def images(self) -> List[Image]:
- return self.__extraction__.get("images", [])
+ return self._narrow.get("images", [])
@property
def publisher(self) -> str:
return self.html.source_info.publisher
- def __getattribute__(self, item: str):
- if (attribute := object.__getattribute__(self, item)) and hasattr(attribute, "__get__"):
- return attribute.__get__(self, type(self))
- return attribute
-
- def __setattr__(self, key: str, value: object):
- if hasattr(self, key):
- # we can't use getattr here, because it would invoke __get__, so unfortunately no default value
- attribute = object.__getattribute__(self, key)
- if hasattr(attribute, "__set__"):
- attribute.__set__(key, value)
- return
- object.__setattr__(self, key, value)
+ def __getattr__(self, item: str) -> Any:
+ """Expose parser-extra extraction keys as read-only attributes; raise AttributeError otherwise.
- def __getattr__(self, item: str):
- raise AttributeError(f"{type(self).__name__!r} object has no attribute {str(item)!r}")
+ Only invoked when normal attribute lookup fails.
+ """
+ # Read from __dict__ directly to avoid infinite recursion when __extraction__ itself isn't
+ # set yet (e.g., during unpickling before __setstate__ restores instance state).
+ extraction = self.__dict__.get("__extraction__")
+ if extraction is None or item not in extraction:
+ raise AttributeError(f"{type(self).__name__!r} object has no attribute {item!r}")
+ return extraction[item]
+
+ def __setattr__(self, key: str, value: object) -> None:
+ """Block writes to extraction-backed attributes; allow all others."""
+ # During __init__, html/__extraction__ are assigned before __extraction__ exists;
+ # check via __dict__ to avoid triggering __getattr__.
+ extraction = self.__dict__.get("__extraction__")
+ if extraction is not None and key in extraction:
+ raise AttributeError(f"attribute {key!r} is read only")
+ object.__setattr__(self, key, value)
@property
def plaintext(self) -> Optional[str]:
- return str(self.body) or None if not isinstance(self.body, Exception) else None
+ body = self.body
+ if body is None or isinstance(body, Exception):
+ return None
+ return str(body) or None
@property
def lang(self) -> Optional[str]:
@@ -104,53 +153,39 @@ def lang(self) -> Optional[str]:
logger.debug(f"Unable to detect language for article {self.html.responded_url!r}")
# use @lang attribute of tag as fallback
- if not language or language == langdetect.detector_factory.Detector.UNKNOWN_LANG:
+ if (not language or language == langdetect.detector_factory.Detector.UNKNOWN_LANG) and self.html.content:
language = lxml.html.fromstring(self.html.content).get("lang")
if language and "-" in language:
language = language.split("-")[0]
return language
- def to_json(self, *attributes: str) -> Dict[str, JSONVal]:
- """Converts article object into a JSON serializable dictionary.
-
- One can specify which attributes should be included by passing attribute names as parameters.
- Default: title, plaintext, authors, publishing_date, topics, free_access + unvalidated attributes
+ def to_json(self, *fields: str) -> Dict[str, JSONVal]:
+ """Export selected article fields as a JSON-compatible dict.
Args:
- *attributes: The attributes to serialize. Default: see docstring.
+ *fields: Field names to export. Each must resolve to an attribute of this
+ article (a built-in property or an extraction key). If empty,
+ DEFAULT_EXPORT_FIELDS is used. Pass "html" to include the source
+ document with its provenance metadata.
Returns:
- A json serializable dictionary
- """
-
- # default value for attributes
- if not attributes:
- attributes = tuple(set(self.__extraction__.keys()) - {"meta", "ld"})
+ A JSON-serializable dict. Key order matches the order of .
- def serialize(v: Any) -> JSONVal:
- if hasattr(v, "serialize"):
- return v.serialize() # type: ignore[no-any-return]
- elif isinstance(v, datetime):
- return str(v)
- elif not is_jsonable(v):
- raise TypeError(f"Attribute {attribute!r} of type {type(v)!r} is not JSON serializable")
- return v # type: ignore[no-any-return]
-
- serialization: Dict[str, JSONVal] = {}
- for attribute in attributes:
- if not hasattr(self, attribute):
- continue
- value = getattr(self, attribute)
-
- if isinstance(value, list):
- serialization[attribute] = [serialize(item) for item in value]
- else:
- serialization[attribute] = serialize(value)
-
- return serialization
+ Raises:
+ KeyError: If a requested field is not present on this article.
+ TypeError: If a value's type has no defined serialization.
+ """
+ selected = fields or self.DEFAULT_EXPORT_FIELDS
+ output: Dict[str, JSONVal] = {}
+ for field in selected:
+ if not hasattr(self, field):
+ raise KeyError(field)
+ output[field] = serialize_value(getattr(self, field), field)
+ return output
def __str__(self):
+ """Render a compact, human-readable summary (title, truncated text, URL, publisher, date)."""
# the subsequent indent here is a bit wacky, but textwrapper.dedent won't work with tabs, so we have to use
# whitespaces instead.
title_wrapper = TextWrapper(width=80, max_lines=1, initial_indent="")
diff --git a/src/fundus/scraping/crawler.py b/src/fundus/scraping/crawler.py
deleted file mode 100644
index ebb9a8236..000000000
--- a/src/fundus/scraping/crawler.py
+++ /dev/null
@@ -1,872 +0,0 @@
-from __future__ import annotations
-
-import contextlib
-import gzip
-import json
-import logging.config
-import multiprocessing
-import os
-import random
-import re
-import time
-import traceback
-from abc import ABC, abstractmethod
-from collections import defaultdict
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from datetime import datetime
-from functools import lru_cache, partial, wraps
-from multiprocessing import Manager
-from multiprocessing.context import TimeoutError
-from multiprocessing.managers import BaseManager
-from multiprocessing.pool import MapResult, Pool, ThreadPool
-from pathlib import Path
-from queue import Empty, Full, Queue
-from threading import current_thread
-from typing import (
- Any,
- Callable,
- Dict,
- Generic,
- Iterator,
- List,
- Literal,
- Optional,
- Pattern,
- Set,
- Tuple,
- Type,
- TypeVar,
- Union,
- cast,
-)
-
-import dill
-import fastwarc.stream_io
-import more_itertools
-import requests
-import urllib3.exceptions
-from dateutil.rrule import MONTHLY, rrule
-from more_itertools import roundrobin
-from tqdm import tqdm
-from typing_extensions import ParamSpec, TypeAlias
-
-from fundus.logging import create_logger, get_current_config
-from fundus.parser.data import remove_query_parameters_from_url
-from fundus.publishers.base_objects import FilteredPublisher, Publisher, PublisherGroup
-from fundus.scraping.article import Article
-from fundus.scraping.delay import Delay
-from fundus.scraping.filter import ExtractionFilter, Requires, RequiresAll, URLFilter
-from fundus.scraping.html import CCNewsSource
-from fundus.scraping.scraper import CCNewsScraper, WebScraper
-from fundus.scraping.session import CrashThread, session_handler
-from fundus.scraping.url import URLSource
-from fundus.utils.events import __EVENTS__
-from fundus.utils.timeout import Timeout
-
-logger = create_logger(__name__)
-
-__MAIN_THREAD_ALIAS__ = "main-thread"
-
-_T = TypeVar("_T")
-_P = ParamSpec("_P")
-
-PublisherType: TypeAlias = Union[Publisher, PublisherGroup]
-
-
-class RemoteException(Exception):
- pass
-
-
-class TQDMManager(BaseManager):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.register("_tqdm", tqdm)
-
- def tqdm(self, *args, **kwargs) -> tqdm:
- return getattr(self, "_tqdm")(*args, **kwargs)
-
-
-@contextlib.contextmanager
-def get_proxy_tqdm(*args, **kwargs) -> tqdm:
- """
- This functions returns a proxy to a tqdm instance. Init args are the same as for any other tqdm instance.
- :param args: tqdm args
- :param kwargs: tqdm kwargs
- :return: a self-managed, proxied tqdm instance
- """
- manager = TQDMManager()
- try:
- manager.start()
- yield manager.tqdm(*args, **kwargs)
- finally:
- manager.shutdown()
-
-
-# noinspection PyPep8Naming
-class dill_wrapper(Generic[_P, _T]):
- def __init__(self, target: Callable[_P, _T]):
- """Wraps function in dill serialization.
-
- This is in order to use unpickable functions within multiprocessing.
-
- Args:
- target: The function to wrap.
- """
- self._serialized_target: bytes = dill.dumps(target)
-
- @lru_cache
- def _deserialize(self) -> Callable[_P, _T]:
- return cast(Callable[_P, _T], dill.loads(self._serialized_target))
-
- def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
- return self._deserialize()(*args, **kwargs)
-
-
-def get_execution_context():
- """
- Determines whether the current execution context is in a thread or process.
- Returns:
- context (str): "thread" or "process"
- ident (int): Thread ID or Process ID
- """
- if multiprocessing.current_process().name != "MainProcess":
- process = multiprocessing.current_process()
- return process.name, process.ident
- else:
- thread = current_thread()
- return thread.name, thread.ident
-
-
-def publisher_context_wrapper(func: Callable[[Publisher], None]) -> Callable[[Publisher], None]:
- """Wraps a callable to register an ``__EVENTS__`` alias context for the publisher argument.
-
- The alias is entered as the very first thing the thread does and stays alive for the
- entire call — including any exception handling in the caller — so that
- ``__EVENTS__.get_alias`` always resolves while the thread is running.
-
- Args:
- func: A callable whose first positional argument is a :class:`Publisher`.
-
- Returns:
- The wrapped callable.
- """
-
- @wraps(func)
- def wrapper(publisher: Publisher) -> None:
- with __EVENTS__.context(publisher.name):
- func(publisher)
-
- return wrapper
-
-
-def queue_wrapper(
- queue: Queue[Union[_T, Exception]],
- target: Callable[_P, Iterator[_T]],
- silenced_exceptions: Tuple[Type[BaseException], ...] = (),
-) -> Callable[_P, None]:
- """Wraps the target callable to add its results to the queue instead of returning them directly.
-
- Args:
- queue: The buffer queue.
- target: A target callable.
- silenced_exceptions: Exception types that should be silenced
-
- Returns:
- (Callable[_P, None]) The wrapped target.
- """
-
- @wraps(target)
- def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
- def _guarded_put(obj: _T) -> bool:
- """Safely putting results on the queue avoiding deadlocks"""
- while True:
- try:
- # We use nowait here to avoid a deadlock on the put when the pool is already shutting down
- # and therefore the queue never will never be free.
- queue.put_nowait(obj)
- except Full:
- if __EVENTS__.is_event_set("stop", __MAIN_THREAD_ALIAS__):
- return False
- time.sleep(0.05)
- else:
- return True
-
- def _process_target():
- """Iterate over and put results into """
- for obj in target(*args, **kwargs):
- if not _guarded_put(obj):
- return
-
- try:
- _process_target()
- except silenced_exceptions:
- pass
- except Exception as err:
- tb_str = "".join(traceback.TracebackException.from_exception(err).format())
- context, ident = get_execution_context()
- alias = __EVENTS__.get_alias(ident, "")
- queue.put(
- RemoteException(
- f"There was a(n) {type(err).__name__!r} occurring in {context} "
- f"with ident {ident} ({alias})\n{tb_str}"
- )
- )
-
- logger.debug(f"Encountered remote exception in thread {ident} ({alias}): {err!r}")
-
- return wrapper
-
-
-def pool_queue_iter(handle: MapResult[Any], queue: Queue[Union[_T, Exception]]) -> Iterator[_T]:
- """Utility function to iterate exhaustively over a pool queue.
-
- The underlying iterator of this function repeatedly exhausts the given queue.
- Then, if the queue is empty only if all the pool's jobs have finished, the iterator reruns.
- Otherwise, it waits for the queue to be populated with the next result from the pool.
-
- Args:
- handle: A handle of the MappedResult of the underling multiprocessing pool.
- queue: The pool queue.
-
- Returns:
- Iterator[_T]: The iterator over the queue as it is populated.
- """
-
- def _exception_guard() -> _T:
- if isinstance(nxt := queue.get_nowait(), Exception):
- raise Exception("There was an exception occurring in a remote thread/process") from nxt
- return nxt
-
- while True:
- try:
- yield _exception_guard()
- except Empty:
- try:
- handle.get(timeout=0.01)
- except TimeoutError:
- # listen for stop-event set for main-thread
- if __EVENTS__.is_event_set("stop", __MAIN_THREAD_ALIAS__):
- __EVENTS__.clear_event("stop", __MAIN_THREAD_ALIAS__)
- break
- continue
-
- # empty queue and look for exception
- while not queue.empty():
- yield _exception_guard()
-
- return
-
-
-def random_sleep(func: Callable[_P, _T], between: Tuple[float, float]) -> Callable[_P, _T]:
- @wraps(func)
- def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
- time.sleep(random.uniform(*between))
- return func(*args, **kwargs)
-
- return wrapper
-
-
-class CrawlerBase(ABC):
- def __init__(self, *publishers: PublisherType):
- self.publishers: List[Union[Publisher, FilteredPublisher]] = list(set(more_itertools.collapse(publishers)))
- if not self.publishers:
- raise ValueError("param of must include at least one publisher.")
-
- @abstractmethod
- def _build_article_iterator(
- self,
- publishers: Tuple[Publisher, ...],
- error_handling: Literal["suppress", "catch", "raise"],
- extraction_filter: Optional[ExtractionFilter],
- url_filter: Optional[URLFilter],
- language_filter: Optional[List[str]],
- skip_publishers_disallowing_training: bool = False,
- ) -> Iterator[Article]:
- raise NotImplementedError
-
- def crawl(
- self,
- max_articles: Optional[int] = None,
- max_articles_per_publisher: Optional[int] = None,
- timeout: Optional[int] = None,
- error_handling: Literal["suppress", "catch", "raise"] = "suppress",
- only_complete: Union[bool, ExtractionFilter] = Requires("title", "body", "publishing_date"),
- url_filter: Optional[URLFilter] = None,
- language_filter: Optional[List[str]] = None,
- only_unique: bool = True,
- save_to_file: Union[None, str, Path] = None,
- skip_publishers_disallowing_training: bool = False,
- ) -> Iterator[Article]:
- """Yields articles from initialized scrapers
-
- Args:
- max_articles (Optional[int]): Number of articles to crawl. If there are fewer articles
- than max_articles the Iterator will stop before max_articles. If None, all retrievable
- articles are returned. Defaults to None.
- max_articles_per_publisher: Specify the number of articles to crawl per publisher.
- Disables . Defaults to None.
- timeout (Optional[int]): timeout (Optional[int]): Specifies the duration in seconds the crawler
- will wait without receiving any articles before stopping. If set <= 0, or if not provided,
- the crawler will run until all sources are exhausted. Defaults to None.
- error_handling (Literal["suppress", "catch", "raise"]): Define how to handle errors
- encountered during extraction. If set to "suppress", all errors will be skipped, either
- with None values for respective attributes in the extraction or by skipping entire articles.
- If set to "catch", errors will be caught as attribute values or, if an entire article fails,
- through Article.exception. If set to "raise" all errors encountered during extraction will
- be raised. Defaults to "suppress".
- only_complete (Union[bool, ExtractionFilter]): Set a callable satisfying the ExtractionFilter
- protocol as an extraction filter or use a boolean. If False, all articles will be yielded,
- if True, only those with all attributes extracted. Defaults to ExtractionFilter letting
- through all articles with at least title, body, and publishing_date set.
- url_filter (Optional[URLFilter]): A callable object satisfying the URLFilter protocol to skip
- URLs before download. This filter applies on both requested and responded URL. Defaults to None.
- language_filter (Optional[List[str]]): A set of language codes to filter the articles by. If set,
- articles of different languages will be skipped and not counted towards the article count. Defaults
- to None.
- only_unique (bool): If set to True, articles yielded will be unique on the responded URL.
- Always returns the first encountered article. Defaults to True.
- save_to_file (Union[None, str, Path]): If set, the crawled articles will be collected saved to the
- specified file as a JSON list.
- skip_publishers_disallowing_training (bool): If set to True, publishers that disallow training
- are skipped. Note that this is an indicator only and users with the intention of using Fundus to gather
- training data should always check the publisher's terms of use beforehand.
-
- Returns:
- Iterator[Article]: An iterator yielding objects of type Article.
- """
-
- if max_articles == 0:
- return
-
- max_articles = max_articles or -1
- timeout = timeout or -1
-
- if max_articles_per_publisher:
- if timeout < 120:
- print(
- "It is recommended to set a minimum of 120 seconds when using max_articles_per_publisher."
- )
- max_articles = -1
-
- def build_extraction_filter() -> Optional[ExtractionFilter]:
- if isinstance(only_complete, bool):
- return None if only_complete is False else RequiresAll()
- else:
- return only_complete
-
- response_cache: Set[str] = set()
-
- extraction_filter = build_extraction_filter()
- fitting_publishers: List[Union[Publisher, FilteredPublisher]] = []
-
- if isinstance(extraction_filter, Requires):
- for publisher in self.publishers:
- supported_attributes = set(
- more_itertools.flatten(
- collection.names for collection in publisher.parser.attribute_mapping.values()
- )
- )
- if missing_attributes := extraction_filter.required_attributes - supported_attributes:
- logger.warning(
- f"The required attribute(s) `{', '.join(missing_attributes)}` "
- f"is(are) not supported by {publisher.name}. Skipping publisher"
- )
- elif language_filter and not publisher.supports(languages=language_filter):
- logger.warning(
- f"None of the required language(s) `{', '.join(language_filter)}` "
- f"is(are) supported by {publisher.name}. Skipping publisher"
- )
- else:
- fitting_publishers.append(publisher)
-
- if not fitting_publishers:
- logger.error(
- f"Could not find any fitting publishers for required attributes "
- f"`{', '.join(extraction_filter.required_attributes)}`"
- )
- return
- else:
- fitting_publishers = self.publishers
-
- # check if there are filtered publishers and if so, adopt their language restrictions
- publisher_language_filter = set()
- for publisher in fitting_publishers:
- if isinstance(publisher, FilteredPublisher):
- publisher_language_filter.update(publisher.language_filter)
-
- if language_filter and publisher_language_filter:
- language_filter = list(set(language_filter).union(publisher_language_filter))
- logger.info(
- f"Publisher language filter: {publisher_language_filter} will be added to the given language filter: "
- f"{language_filter}. "
- )
- elif publisher_language_filter:
- language_filter = list(publisher_language_filter)
- logger.info(f"Publisher language filter: {publisher_language_filter} will be used as the language filter. ")
-
- article_count: Dict[str, int] = defaultdict(int)
- crawled_articles: Dict[str, List[Article]] = defaultdict(list)
-
- # Unfortunately we relly on this little workaround here to terminate the 'Pool' used within
- # the 'CCNewsCrawler'. The 'Timeout' contextmanager utilizes '_thread.interrupt_main',
- # throwing a KeyboardInterrupt in the main thread after seconds. My guess (MaxDall)
- # is, that within 'queue_wrapper's 'handle.get(timeout=0.1)', the main thread cannot be
- # interrupted via a KeyboardInterrupt. The workaround is to have a modul global event
- # that can be set within the 'Timeout' thread using a callback.
- # With Python 3.10 we can pass a signum to '_thread.interrupt_main', maybe that's the way to go.
- callback: Optional[Callable[[], None]]
- if isinstance(self, CCNewsCrawler) and self.processes > 0:
-
- def callback() -> None:
- __EVENTS__.set_event("stop", __MAIN_THREAD_ALIAS__)
-
- else:
- callback = None
-
- try:
- with __EVENTS__.main_context(__MAIN_THREAD_ALIAS__), Timeout(
- seconds=timeout, silent=True, callback=callback, disable=timeout <= 0
- ) as timer:
- for article in self._build_article_iterator(
- tuple(fitting_publishers),
- error_handling,
- build_extraction_filter(),
- url_filter,
- language_filter,
- skip_publishers_disallowing_training,
- ):
- if max_articles_per_publisher and article_count[article.publisher] == max_articles_per_publisher:
- if (isinstance(self, Crawler) and self.threading) and not __EVENTS__.is_event_set(
- "stop", article.publisher
- ):
- __EVENTS__.set_event("stop", article.publisher)
- if sum(article_count.values()) == len(self.publishers) * max_articles_per_publisher:
- break
- continue
- timer.reset()
- url_without_query_parameters = remove_query_parameters_from_url(article.html.responded_url)
- if not only_unique or url_without_query_parameters not in response_cache:
- response_cache.add(url_without_query_parameters)
- article_count[article.publisher] += 1
- if save_to_file:
- crawled_articles[article.publisher].append(article)
- yield article
- if sum(article_count.values()) == max_articles:
- break
- finally:
- session_handler.close_sessions()
- if save_to_file is not None:
- if isinstance(save_to_file, str):
- save_to_file = Path(save_to_file)
- save_to_file.parent.mkdir(parents=True, exist_ok=True)
- with open(save_to_file, "w", encoding="utf-8") as file:
- logger.info(f"Writing crawled articles to {save_to_file!r}")
- file.write(
- json.dumps(crawled_articles, default=lambda o: o.to_json(), ensure_ascii=False, indent=4)
- )
-
-
-class Crawler(CrawlerBase):
- def __init__(
- self,
- *publishers: PublisherType,
- restrict_sources_to: Optional[List[Type[URLSource]]] = None,
- ignore_deprecated: bool = False,
- delay: Optional[Union[float, Delay]] = 1.0,
- threading: bool = True,
- ignore_robots: bool = False,
- ignore_crawl_delay: bool = False,
- impersonate: bool = False,
- ):
- """Fundus base class for crawling articles from the web.
-
- Examples:
- >>> from fundus import PublisherCollection, Crawler
- >>> crawler = Crawler(*PublisherCollection)
- >>> # Crawler(PublisherCollection.us) to crawl only american news
- >>> for article in crawler.crawl():
- >>> print(article)
-
- Args:
- *publishers (Union[Publisher, PublisherGroup]): The publishers to crawl.
- restrict_sources_to (Optional[List[Type[URLSource]]]): Lets you restrict sources defined in the publisher
- specs. If set, only articles from given source types will be yielded.
- ignore_deprecated (bool): If set to True, Publishers marked as deprecated will be skipped.
- Defaults to False.
- delay (Optional[Union[float, Delay]]): Set a delay time in seconds to be used between article
- downloads. You can set a delay directly using float or any callable satisfying the Delay
- protocol. If set to None, no delay will be used between batches. See Delay for more
- information. Defaults to None.
- threading (bool): If True, the crawler will use a dedicated thread per publisher, if set to False,
- the crawler will use a single thread for all publishers and load articles successively. This will
- greatly influence performance, and it is highly recommended to use a threaded crawler.
- Defaults to True.
- ignore_robots (bool): Determines whether to bypass the consideration of the robots.txt file when
- filtering URLs from publishers. If set to True, the URLs will not be filtered based on the
- robots.txt file. Defaults to False.
- ignore_crawl_delay (bool): Determines whether to ignore a crawl delay given by a publisher.
- If set to False, this will overwrite . If ignore_robots is set to True, the crawl delay
- will also be ignored.
- impersonate (bool): If True, publishers that declare an `impersonate` browser profile will use
- curl_cffi's TLS/HTTP fingerprint impersonation. If False (default), the profile is ignored
- and requests go out with Fundus' regular fingerprint — publishers gated by anti-bot checks
- will likely return 4xx/5xx. Defaults to False.
- """
-
- def filter_publishers(publisher: Publisher) -> bool:
- if publisher.deprecated and ignore_deprecated:
- logger.warning(f"Skipping deprecated publisher: {publisher.name}")
- return False
- return True
-
- fitting_publishers = list(filter(filter_publishers, more_itertools.collapse(publishers)))
- if not fitting_publishers:
- raise ValueError(
- "All given publishers are deprecated. Either set to `False` or "
- "include at least one publisher that isn't deprecated."
- )
-
- super().__init__(*fitting_publishers)
-
- self.restrict_sources_to = restrict_sources_to
- self.delay = delay
- self.threading = threading
- self.ignore_robots = ignore_robots
- self.ignore_crawl_delay = ignore_crawl_delay
- self.impersonate = impersonate
-
- def _fetch_articles(
- self,
- publisher: Publisher,
- error_handling: Literal["suppress", "catch", "raise"],
- extraction_filter: Optional[ExtractionFilter] = None,
- url_filter: Optional[URLFilter] = None,
- language_filter: Optional[List[str]] = None,
- skip_publishers_disallowing_training: bool = False,
- ) -> Iterator[Article]:
- if skip_publishers_disallowing_training and publisher.disallows_training:
- logger.info(f"Skipping publisher {publisher.name} because it disallows training.")
- return
- elif publisher.robots.disallow_all():
- logger.info(f"Skipping publisher {publisher.name} because it disallows all URLs.")
- return
-
- def build_delay() -> Optional[Delay]:
- if isinstance(self.delay, float):
- delay = self.delay
-
- def constant_delay() -> float:
- return delay
-
- return constant_delay
-
- elif isinstance(self.delay, Delay):
- return self.delay
-
- else:
- raise TypeError("param of ")
-
- scraper = WebScraper(
- publisher,
- self.restrict_sources_to,
- build_delay(),
- ignore_robots=self.ignore_robots,
- ignore_crawl_delay=self.ignore_crawl_delay,
- impersonate=self.impersonate,
- )
- if not scraper.sources and self.restrict_sources_to:
- logger.warning(
- f"No sources of type {[source_type.__name__ for source_type in self.restrict_sources_to]} "
- f"found for publisher {publisher.name}. Skipping publisher."
- )
- return
- yield from scraper.scrape(error_handling, extraction_filter, url_filter, language_filter)
-
- @staticmethod
- def _single_crawl(
- publishers: Tuple[Publisher, ...], article_task: Callable[[Publisher], Iterator[Article]]
- ) -> Iterator[Article]:
- article_iterators = [article_task(publisher) for publisher in publishers]
- yield from roundrobin(*article_iterators)
-
- def _threaded_crawl(
- self, publishers: Tuple[Publisher, ...], article_task: Callable[[Publisher], Iterator[Article]]
- ) -> Iterator[Article]:
- @contextlib.contextmanager
- def _manage_pool(*args, **kwargs) -> Iterator[ThreadPool]:
- managed_pool = ThreadPool(*args, **kwargs)
- try:
- yield managed_pool
- finally:
- logger.debug(f"Shutting down {type(self).__name__!r} ...")
- managed_pool.close()
- __EVENTS__.set_for_all("stop", future=True, active_only=True)
- managed_pool.join()
- __EVENTS__.clear_for_all("stop")
- logger.debug("Shutdown done")
-
- result_queue: Queue[Union[Article, Exception]] = Queue(len(publishers))
- wrapped_article_task = publisher_context_wrapper(
- queue_wrapper(result_queue, article_task, silenced_exceptions=(CrashThread,))
- )
-
- with _manage_pool(processes=len(publishers) or None) as pool:
- yield from pool_queue_iter(pool.map_async(wrapped_article_task, publishers), result_queue)
-
- def _build_article_iterator(
- self,
- publishers: Tuple[Publisher, ...],
- error_handling: Literal["suppress", "catch", "raise"],
- extraction_filter: Optional[ExtractionFilter],
- url_filter: Optional[URLFilter],
- language_filter: Optional[List[str]],
- skip_publishers_disallowing_training: bool = False,
- ) -> Iterator[Article]:
- article_task = partial(
- self._fetch_articles,
- error_handling=error_handling,
- extraction_filter=extraction_filter,
- url_filter=url_filter,
- language_filter=language_filter,
- skip_publishers_disallowing_training=skip_publishers_disallowing_training,
- )
-
- if self.threading:
- yield from self._threaded_crawl(publishers, article_task)
- else:
- yield from self._single_crawl(publishers, article_task)
-
-
-class CCNewsCrawler(CrawlerBase):
- def __init__(
- self,
- *publishers: PublisherType,
- start: datetime = datetime(2016, 8, 1),
- end: datetime = datetime.now(),
- processes: int = -1,
- retries: int = 3,
- disable_tqdm: bool = False,
- server_address: str = "https://data.commoncrawl.org/",
- ):
- """Initializes a crawler for the CC-NEWS dataset.
-
- The crawler crawls the CC-NEWS dataset from to .
-
- Args:
- *publishers: The publishers to crawl.
- start: The date to start crawling from. Refers to the date the WARC record was added to CC-NEWS,
- not when it was published. Defaults to 2016/8/1.
- end: The date to end crawling. Refers to the date the WARC record was added to CC-NEWS, not when
- it was published. Defaults to datetime.now().
- processes: Number of additional process to use for crawling.
- If -1, the number of processes is set to `os.cpu_count()`.
- If `os.cpu_count()` is not available, the number of processes is set to 0.
- If 0, only the main process is used. Defaults to -1.
- retries: The number of times to retry crawling a WARC record when a connection error occurs. Between
- retries, the crawler sleeps for * 30 seconds. Defaults to 3.
- disable_tqdm: Disable the usage of tqdm within the crawler. Defaults to False.
- server_address: The CC-NEWS dataset server address. Defaults to 'https://data.commoncrawl.org/'.
- """
-
- super().__init__(*publishers)
-
- self.start = start
- self.end = end
-
- if processes < 0:
- print(
- f"{type(self).__name__} will automatically use all available cores: {os.cpu_count()}. "
- f"For optimal performance, we recommend manually setting the number of processes "
- f"using the parameter. A good rule of thumb is to allocate `one process per "
- f"200 Mbps of bandwidth`."
- )
- self.processes = os.cpu_count() or 0
- else:
- self.processes = processes
-
- self.retries = retries
- self.disable_tqdm = disable_tqdm
- self.server_address = server_address
-
- def _fetch_articles(
- self,
- warc_path: str,
- publishers: Tuple[Publisher, ...],
- error_handling: Literal["suppress", "catch", "raise"],
- extraction_filter: Optional[ExtractionFilter] = None,
- url_filter: Optional[URLFilter] = None,
- language_filter: Optional[List[str]] = None,
- bar: Optional[tqdm] = None,
- ) -> Iterator[Article]:
- retries: int = 0
- while True:
- source = CCNewsSource(*publishers, warc_path=warc_path)
- scraper = CCNewsScraper(source)
- try:
- yield from scraper.scrape(error_handling, extraction_filter, url_filter, language_filter)
- except (requests.HTTPError, fastwarc.stream_io.StreamError, urllib3.exceptions.HTTPError) as exception:
- if retries >= self.retries:
- logger.error(f"Failed to load WARC file {warc_path!r} after {retries} retries")
- break
- else:
- retries += 1
- sleep_time = (30 * retries) + random.uniform(-2, 2)
- logger.warning(
- f"Could not load WARC file {warc_path!r}. Retry after {sleep_time:.2f} seconds: {exception!r}"
- )
- time.sleep(sleep_time)
- else:
- break
-
- if bar is not None:
- bar.update()
-
- @staticmethod
- def _single_crawl(
- warc_paths: Tuple[str, ...], article_task: Callable[[str], Iterator[Article]]
- ) -> Iterator[Article]:
- for warc_path in warc_paths:
- yield from article_task(warc_path)
-
- def _parallel_crawl(
- self, warc_paths: Tuple[str, ...], article_task: Callable[[str], Iterator[Article]]
- ) -> Iterator[Article]:
- # because logging configurations are overwritten when using 'spawn' as start method,
- # we have to get current logging configurations and initialize them in the new process
- if multiprocessing.get_start_method() == "spawn":
- logging_config = get_current_config()
- initializer = partial(logging.config.dictConfig, config=logging_config)
- else:
- initializer = None
-
- # As one could think, because we're downloading a bunch of files, this task is IO-bound, but it is actually
- # process-bound. The reason is that we stream the data and process it on the fly rather than downloading all
- # files and processing them afterward. Therefore, we utilize multiprocessing here instead of multithreading.
- with Manager() as manager, Pool(
- processes=min(self.processes, len(warc_paths)),
- initializer=initializer,
- ) as pool:
- result_queue: Queue[Union[Article, Exception]] = manager.Queue(maxsize=1000)
-
- # Because multiprocessing.Pool does not support iterators as targets,
- # we wrap the article_task to write the articles to a queue instead of returning them directly.
- wrapped_article_task: Callable[[str], None] = queue_wrapper(result_queue, article_task)
-
- # To avoid 503 errors we spread tasks to not start all at once
- spread_article_task = random_sleep(wrapped_article_task, (0, 3))
-
- # To avoid restricting the article_task to use only pickleable objects, we serialize it using dill.
- serialized_article_task = dill_wrapper(spread_article_task)
-
- # Finally, we build an iterator around the queue, exhausting the queue until the pool is finished.
- yield from pool_queue_iter(pool.map_async(serialized_article_task, warc_paths), result_queue)
-
- logger.debug(f"Shutting down {type(self).__name__!r} ...")
-
- def _get_warc_paths(self) -> List[str]:
- # Date regex examples: https://regex101.com/r/yDX3G6/1
- date_pattern: Pattern[str] = re.compile(r"CC-NEWS-(?P\d{14})-")
-
- if self.start >= self.end:
- raise ValueError("Start date has to be < end date.")
-
- if self.start < datetime(2016, 8, 1):
- raise ValueError("The default, and earliest possible, start date is 2016/08/01.")
-
- if self.end > datetime.now():
- raise ValueError("The specified end date is in the future. We don't want to give spoilers, do we?")
-
- date_sequence: List[datetime] = list(rrule(MONTHLY, dtstart=self.start, until=self.end))
- urls: List[str] = [
- f"{self.server_address}crawl-data/CC-NEWS/{date.strftime('%Y/%m')}/warc.paths.gz" for date in date_sequence
- ]
-
- with tqdm(total=len(urls), desc="Loading WARC Paths", leave=False, disable=self.disable_tqdm) as bar:
-
- def load_paths(url: str) -> List[str]:
- with requests.Session() as session:
- paths = gzip.decompress(session.get(url).content).decode("utf-8").split()
- bar.update()
- return paths
-
- if self.processes == 0:
- nested_warc_paths = [load_paths(url) for url in urls]
- else:
- # use two threads per process, default two threads per core
- max_number_of_threads = self.processes * 2
-
- try:
- with ThreadPool(processes=min(len(urls), max_number_of_threads)) as pool:
- nested_warc_paths = pool.map(random_sleep(load_paths, (0, 3)), urls)
- finally:
- pool.join()
-
- warc_paths: Iterator[str] = more_itertools.flatten(nested_warc_paths)
-
- start_strf = self.start.strftime("%Y%m%d%H%M%S")
- end_strf = self.end.strftime("%Y%m%d%H%M%S")
-
- def filter_warc_path_by_date(path: str) -> bool:
- match: Optional[re.Match[str]] = date_pattern.search(path)
- if match is None:
- raise AssertionError(f"Invalid WARC path {path!r}")
- return start_strf <= match["date"] <= end_strf
-
- return sorted(
- (f"{self.server_address}{warc_path}" for warc_path in filter(filter_warc_path_by_date, warc_paths)),
- reverse=True,
- )
-
- def _build_article_iterator(
- self,
- publishers: Tuple[Publisher, ...],
- error_handling: Literal["suppress", "catch", "raise"],
- extraction_filter: Optional[ExtractionFilter],
- url_filter: Optional[URLFilter],
- language_filter: Optional[List[str]],
- skip_publishers_disallowing_training: bool = False,
- **kwargs,
- ) -> Iterator[Article]:
- if skip_publishers_disallowing_training:
- max_workers = self.processes if self.processes > 0 else min(len(publishers), 5)
- verified_publishers: List["Publisher"] = []
-
- def run_disallow_training(publisher: Publisher) -> bool:
- return publisher.disallows_training
-
- with ThreadPoolExecutor(max_workers=max_workers) as executor, session_handler.context(timeout=10):
- future_to_publisher = {
- executor.submit(run_disallow_training, publisher=publisher): publisher for publisher in publishers
- }
-
- warc_paths = tuple(self._get_warc_paths())
-
- for future in as_completed(future_to_publisher.keys()):
- publisher = future_to_publisher[future]
- try:
- if not future.result():
- verified_publishers.append(publisher)
- else:
- logger.warning(f"Skipping publisher {publisher.name!r} because it disallows training.")
- except Exception as exc:
- logger.warning(f"Could not verify training policy for {publisher.name!r}: {exc}", exc_info=True)
- publishers = tuple(verified_publishers)
-
- else:
- warc_paths = tuple(self._get_warc_paths())
-
- with get_proxy_tqdm(total=len(warc_paths), desc="Process WARC files", disable=self.disable_tqdm) as bar:
- article_task = partial(
- self._fetch_articles,
- publishers=publishers,
- error_handling=error_handling,
- extraction_filter=extraction_filter,
- url_filter=url_filter,
- language_filter=language_filter,
- bar=bar,
- )
-
- if self.processes == 0:
- yield from self._single_crawl(warc_paths, article_task)
- else:
- yield from self._parallel_crawl(warc_paths, article_task)
diff --git a/src/fundus/scraping/crawler/__init__.py b/src/fundus/scraping/crawler/__init__.py
new file mode 100644
index 000000000..c5dc5c70c
--- /dev/null
+++ b/src/fundus/scraping/crawler/__init__.py
@@ -0,0 +1,5 @@
+from fundus.scraping.crawler.base import CrawlerBase
+from fundus.scraping.crawler.ccnews import CCNewsCrawler
+from fundus.scraping.crawler.web import Crawler
+
+__all__ = ["CrawlerBase", "Crawler", "CCNewsCrawler"]
diff --git a/src/fundus/scraping/crawler/base.py b/src/fundus/scraping/crawler/base.py
new file mode 100644
index 000000000..b48b217c6
--- /dev/null
+++ b/src/fundus/scraping/crawler/base.py
@@ -0,0 +1,269 @@
+from __future__ import annotations
+
+import json
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from pathlib import Path
+from typing import (
+ Dict,
+ Iterator,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ Union,
+)
+
+import more_itertools
+
+from fundus.logging import create_logger
+from fundus.publishers.base_objects import FilteredPublisher, Publisher, PublisherGroup
+from fundus.scraping.article import Article
+from fundus.scraping.filter import ExtractionFilter, Requires, RequiresAll, URLFilter
+from fundus.scraping.session import session_handler
+from fundus.scraping.url import strip_query_and_fragment
+from fundus.utils.events import __EVENTS__, __MAIN_THREAD_ALIAS__
+from fundus.utils.timeout import Timeout
+
+logger = create_logger(__name__)
+
+PublisherType = Union[Publisher, PublisherGroup]
+
+
+class _CrawlState:
+ """Tracks per-publisher and total article counts, a dedup cache, and optionally the kept articles."""
+
+ def __init__(self, only_unique: bool, track_articles: bool) -> None:
+ """Initialize counters; only_unique enables URL dedup, track_articles retains accepted articles."""
+ self._only_unique = only_unique
+ self._track_articles = track_articles
+
+ self._response_cache: Set[str] = set()
+
+ self.article_count: Dict[str, int] = defaultdict(int)
+ self.total_count: int = 0
+ self.crawled_articles: Dict[str, List[Article]] = defaultdict(list)
+
+ def accept(self, article: Article) -> bool:
+ """Record the article in the running counts; return False if dropped as a duplicate."""
+ url = strip_query_and_fragment(article.html.responded_url)
+ if self._only_unique and url in self._response_cache:
+ return False
+ self._response_cache.add(url)
+ self.article_count[article.publisher] += 1
+ self.total_count += 1
+ if self._track_articles:
+ self.crawled_articles[article.publisher].append(article)
+ return True
+
+
+class CrawlerBase(ABC):
+ """Base class for crawlers: holds the publisher set and drives the shared crawl() loop.
+
+ Subclasses implement _build_article_iterator to supply articles from a concrete backend
+ (the live web in Crawler, the CC-NEWS archive in CCNewsCrawler); crawl() layers on the
+ publisher/attribute/language filtering, limits, timeout, dedup, and optional file export.
+ """
+
+ def __init__(self, *publishers: PublisherType) -> None:
+ """Collect and de-duplicate the publishers to crawl.
+
+ Args:
+ *publishers (PublisherType): Publishers or publisher groups to crawl. Groups are
+ flattened and duplicate publishers removed.
+
+ Raises:
+ ValueError: If no publishers are supplied.
+
+ """
+ self.publishers: List[Union[Publisher, FilteredPublisher]] = list(set(more_itertools.collapse(publishers)))
+ if not self.publishers:
+ raise ValueError("param of must include at least one publisher.")
+
+ @abstractmethod
+ def _build_article_iterator(
+ self,
+ publishers: Tuple[Publisher, ...],
+ raise_on_error: bool,
+ extraction_filter: Optional[ExtractionFilter],
+ url_filter: Optional[URLFilter],
+ language_filter: Optional[List[str]],
+ skip_publishers_disallowing_training: bool = False,
+ ) -> Iterator[Article]:
+ """Yield articles from the concrete backend. Implemented by each crawler subclass."""
+ raise NotImplementedError
+
+ def _on_timeout(self) -> None:
+ """Hook invoked when the crawl timeout fires; no-op by default, overridden where cleanup is needed."""
+ pass
+
+ def _on_publisher_limit_reached(self, publisher_name: str) -> None:
+ """Hook invoked when a publisher hits its per-publisher article limit; no-op by default."""
+ pass
+
+ @staticmethod
+ def _build_extraction_filter(only_complete: Union[bool, ExtractionFilter]) -> Optional[ExtractionFilter]:
+ """Resolve the only_complete argument into an ExtractionFilter, or None to keep everything."""
+ if isinstance(only_complete, bool):
+ return None if only_complete is False else RequiresAll()
+ return only_complete
+
+ def _filter_publishers(
+ self,
+ extraction_filter: Optional[ExtractionFilter],
+ language_filter: Optional[List[str]],
+ ) -> List[Union[Publisher, FilteredPublisher]]:
+ """Drop publishers that can't supply the required attributes or any requested language."""
+ if not isinstance(extraction_filter, Requires):
+ return list(self.publishers)
+
+ fitting_publishers: List[Union[Publisher, FilteredPublisher]] = []
+ for publisher in self.publishers:
+ supported_attributes = set(
+ more_itertools.flatten(collection.names for collection in publisher.parser.attribute_mapping.values())
+ )
+ if missing_attributes := extraction_filter.required_attributes - supported_attributes:
+ logger.warning(
+ f"The required attribute(s) `{', '.join(missing_attributes)}` "
+ f"is(are) not supported by {publisher.name}. Skipping publisher"
+ )
+ elif language_filter and not publisher.supports(languages=language_filter):
+ logger.warning(
+ f"None of the required language(s) `{', '.join(language_filter)}` "
+ f"is(are) supported by {publisher.name}. Skipping publisher"
+ )
+ else:
+ fitting_publishers.append(publisher)
+
+ if not fitting_publishers:
+ logger.error(
+ f"Could not find any fitting publishers for required attributes "
+ f"`{', '.join(extraction_filter.required_attributes)}`"
+ )
+
+ return fitting_publishers
+
+ @staticmethod
+ def _resolve_language_filter(
+ publishers: List[Union[Publisher, FilteredPublisher]],
+ language_filter: Optional[List[str]],
+ ) -> Optional[List[str]]:
+ """Merge the caller's language filter with each FilteredPublisher's own language filter."""
+ publisher_language_filter: Set[str] = set()
+ for publisher in publishers:
+ if isinstance(publisher, FilteredPublisher):
+ publisher_language_filter.update(publisher.language_filter)
+
+ if language_filter and publisher_language_filter:
+ return list(set(language_filter).union(publisher_language_filter))
+ if publisher_language_filter:
+ return list(publisher_language_filter)
+ return language_filter
+
+ @staticmethod
+ def _save_articles(path: Union[str, Path], articles: Dict[str, List[Article]]) -> None:
+ """Write the collected articles to as a JSON list, creating parent dirs as needed."""
+ if isinstance(path, str):
+ path = Path(path)
+ path.parent.mkdir(parents=True, exist_ok=True)
+ with open(path, "w", encoding="utf-8") as file:
+ logger.info(f"Writing crawled articles to {path!r}")
+ file.write(json.dumps(articles, default=lambda o: o.to_json(), ensure_ascii=False, indent=4))
+
+ def crawl(
+ self,
+ max_articles: Optional[int] = None,
+ max_articles_per_publisher: Optional[int] = None,
+ timeout: Optional[float] = None,
+ raise_on_error: bool = False,
+ only_complete: Union[bool, ExtractionFilter] = Requires("title", "body", "publishing_date"),
+ url_filter: Optional[URLFilter] = None,
+ language_filter: Optional[List[str]] = None,
+ only_unique: bool = True,
+ save_to_file: Union[None, str, Path] = None,
+ skip_publishers_disallowing_training: bool = False,
+ ) -> Iterator[Article]:
+ """Yields articles from the initialized crawlers.
+
+ Args:
+ max_articles (Optional[int]): Total number of articles to crawl. The iterator stops early
+ if fewer articles are available. If None, yields every retrievable article. Defaults to None.
+ max_articles_per_publisher (Optional[int]): Number of articles to crawl per publisher.
+ Overrides . Defaults to None.
+ timeout (Optional[float]): How long, in seconds, the crawler waits without receiving an
+ article before stopping. If <= 0 or None, it runs until all sources are exhausted.
+ Defaults to None.
+ raise_on_error (bool): If True, errors encountered while parsing an Article are raised
+ immediately, failing fast. If False, errors are skipped and attributes that fail to
+ extract fall back to their default values. Defaults to False.
+ only_complete (Union[bool, ExtractionFilter]): An ExtractionFilter, or a boolean shorthand.
+ False yields every article; True yields only fully extracted ones. Defaults to a filter
+ that passes articles with at least title, body, and publishing_date set.
+ url_filter (Optional[URLFilter]): A URLFilter callable used to skip articles by URL, both
+ before and after download. Applied to the requested and the responded URL. Defaults to None.
+ language_filter (Optional[List[str]]): Language codes to keep. Articles in other languages
+ are skipped and excluded from the article count. Defaults to None.
+ only_unique (bool): If True, deduplicates articles by their responded URL, yielding only
+ the first article seen per URL. Defaults to True.
+ save_to_file (Union[None, str, Path]): If set, collects the crawled articles and writes them
+ to the given file as a JSON list. Defaults to None.
+ skip_publishers_disallowing_training (bool): If True, skips publishers that disallow training.
+ This is only an indicator; anyone gathering training data with Fundus should still review
+ each publisher's terms of use. Defaults to False.
+
+ Yields:
+ Article: The extracted articles.
+ """
+ if max_articles == 0:
+ return
+
+ if max_articles_per_publisher:
+ if timeout is None or timeout < 120:
+ logger.warning(
+ "It is recommended to set a minimum of 120 seconds when using max_articles_per_publisher."
+ )
+ max_articles = None
+
+ extraction_filter = self._build_extraction_filter(only_complete)
+ fitting_publishers = self._filter_publishers(extraction_filter, language_filter)
+
+ if not fitting_publishers:
+ return
+
+ language_filter = self._resolve_language_filter(fitting_publishers, language_filter)
+
+ state = _CrawlState(only_unique=only_unique, track_articles=save_to_file is not None)
+
+ try:
+ with __EVENTS__.main_context(__MAIN_THREAD_ALIAS__), Timeout(
+ seconds=timeout,
+ silent=True,
+ callback=self._on_timeout,
+ ) as timer:
+ for article in self._build_article_iterator(
+ tuple(fitting_publishers),
+ raise_on_error,
+ extraction_filter,
+ url_filter,
+ language_filter,
+ skip_publishers_disallowing_training,
+ ):
+ if (
+ max_articles_per_publisher
+ and state.article_count[article.publisher] == max_articles_per_publisher
+ ):
+ self._on_publisher_limit_reached(article.publisher)
+ if state.total_count == len(self.publishers) * max_articles_per_publisher:
+ break
+ continue
+
+ timer.reset()
+ if state.accept(article):
+ yield article
+
+ if max_articles is not None and state.total_count == max_articles:
+ break
+ finally:
+ session_handler.close_sessions()
+ if save_to_file is not None:
+ self._save_articles(save_to_file, state.crawled_articles)
diff --git a/src/fundus/scraping/crawler/ccnews.py b/src/fundus/scraping/crawler/ccnews.py
new file mode 100644
index 000000000..017b1ce17
--- /dev/null
+++ b/src/fundus/scraping/crawler/ccnews.py
@@ -0,0 +1,255 @@
+from __future__ import annotations
+
+import gzip
+import logging.config
+import multiprocessing
+import os
+import random
+import re
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from datetime import datetime
+from functools import partial
+from multiprocessing import Manager, Pool
+from multiprocessing.pool import ThreadPool
+from typing import Callable, Iterator, List, Optional, Pattern, Tuple
+
+import more_itertools
+import requests
+from dateutil.rrule import MONTHLY, rrule
+from tqdm import tqdm
+
+from fundus.logging import create_logger, get_current_config
+from fundus.publishers.base_objects import Publisher
+from fundus.scraping.article import Article
+from fundus.scraping.crawler.base import CrawlerBase, PublisherType
+from fundus.scraping.crawler.queueing import (
+ enqueue_results,
+ iter_pool_results,
+)
+from fundus.scraping.filter import ExtractionFilter, URLFilter
+from fundus.scraping.pipeline import Pipeline
+from fundus.scraping.pipeline.source.ccnews import CCNewsSource, WarcFileLoadError
+from fundus.scraping.session import session_handler
+from fundus.utils.concurrency import dill_wrapper, get_proxy_tqdm
+from fundus.utils.events import __EVENTS__, __MAIN_THREAD_ALIAS__
+from fundus.utils.timing import random_sleep
+
+logger = create_logger(__name__)
+
+
+class CCNewsCrawler(CrawlerBase):
+ """Crawler for the CC-NEWS archive: extracts articles from Common Crawl's monthly WARC files.
+
+ Resolves the WARC paths covering the requested date range, then processes each archive with a
+ CCNewsSource-backed Pipeline. Runs across multiple processes by default; WARC downloads are
+ retried on transient errors.
+ """
+
+ def __init__(
+ self,
+ *publishers: PublisherType,
+ start: datetime = datetime(2016, 8, 1),
+ end: Optional[datetime] = None,
+ processes: int = -1,
+ retries: int = 3,
+ disable_tqdm: bool = False,
+ server_address: str = "https://data.commoncrawl.org/",
+ ):
+ """Initializes a crawler for the CC-NEWS dataset.
+
+ Args:
+ *publishers: The publishers to crawl.
+ start: Start date for WARC records. Defaults to 2016/8/1.
+ end: End date for WARC records. Defaults to datetime.now().
+ processes: Number of additional processes. -1 uses all CPU cores. Defaults to -1.
+ retries: Retry count on connection errors. Defaults to 3.
+ disable_tqdm: Disable tqdm progress bars. Defaults to False.
+ server_address: CC-NEWS server address.
+ """
+ super().__init__(*publishers)
+
+ self.start = start
+ self.end = end if end is not None else datetime.now()
+
+ if start >= self.end:
+ raise ValueError("Start date has to be < end date.")
+ if start < datetime(2016, 8, 1):
+ raise ValueError("The default, and earliest possible, start date is 2016/08/01.")
+ if self.end > datetime.now():
+ raise ValueError("The specified end date is in the future.")
+
+ if processes < 0:
+ logger.warning(
+ f"{type(self).__name__} will automatically use all available cores: {os.cpu_count()}. "
+ f"For optimal performance, we recommend manually setting the number of processes "
+ f"using the parameter. A good rule of thumb is to allocate `one process per "
+ f"200 Mbps of bandwidth`."
+ )
+ self.processes = os.cpu_count() or 0
+ else:
+ self.processes = processes
+
+ self.retries = retries
+ self.disable_tqdm = disable_tqdm
+ self.server_address = server_address
+
+ def _on_timeout(self) -> None:
+ """Set the main-thread-aliased stop event on timeout."""
+ if self.processes > 0:
+ __EVENTS__.set_event("stop", __MAIN_THREAD_ALIAS__)
+
+ def _fetch_articles(
+ self,
+ warc_path: str,
+ publishers: Tuple[Publisher, ...],
+ raise_on_error: bool,
+ extraction_filter: Optional[ExtractionFilter] = None,
+ url_filter: Optional[URLFilter] = None,
+ language_filter: Optional[List[str]] = None,
+ bar: Optional[tqdm] = None,
+ ) -> Iterator[Article]:
+ """Process one WARC file through a pipeline, retrying on WarcFileLoadError up to self.retries."""
+ retries: int = 0
+ while True:
+ pipeline = Pipeline(
+ CCNewsSource(*publishers, warc_path=warc_path),
+ publishers=publishers,
+ )
+ try:
+ yield from pipeline.run(raise_on_error, extraction_filter, url_filter, language_filter)
+ except WarcFileLoadError as exception:
+ if retries >= self.retries:
+ logger.error(f"Failed to load WARC file {warc_path!r} after {retries} retries")
+ break
+ retries += 1
+ sleep_time = (30 * retries) + random.uniform(-2, 2)
+ logger.warning(
+ f"Could not load WARC file {warc_path!r}. Retry after {sleep_time:.2f} seconds: {exception!r}"
+ )
+ time.sleep(sleep_time)
+ else:
+ break
+
+ if bar is not None:
+ bar.update()
+
+ @staticmethod
+ def _single_crawl(
+ warc_paths: Tuple[str, ...], article_task: Callable[[str], Iterator[Article]]
+ ) -> Iterator[Article]:
+ """Process every WARC path sequentially in the calling process (no multiprocessing)."""
+ for warc_path in warc_paths:
+ yield from article_task(warc_path)
+
+ def _parallel_crawl(
+ self, warc_paths: Tuple[str, ...], article_task: Callable[[str], Iterator[Article]]
+ ) -> Iterator[Article]:
+ """Process WARC paths across a process pool, funnelling articles through a managed queue."""
+ if multiprocessing.get_start_method() == "spawn":
+ logging_config = get_current_config()
+ initializer = partial(logging.config.dictConfig, config=logging_config)
+ else:
+ initializer = None
+
+ with Manager() as manager, Pool(
+ processes=min(self.processes, len(warc_paths)),
+ initializer=initializer,
+ ) as pool:
+ result_queue = manager.Queue(maxsize=1000)
+ wrapped_task = enqueue_results(result_queue, article_task)
+ spread_task = random_sleep(wrapped_task, (0, 3))
+ serialized_task = dill_wrapper(spread_task)
+ yield from iter_pool_results(pool.map_async(serialized_task, warc_paths), result_queue)
+ logger.debug(f"Shutting down {type(self).__name__!r} ...")
+
+ def _get_warc_paths(self) -> List[str]:
+ """Resolve and return the WARC archive URLs covering [start, end], newest first."""
+ date_pattern: Pattern[str] = re.compile(r"CC-NEWS-(?P\d{14})-")
+
+ date_sequence: List[datetime] = list(rrule(MONTHLY, dtstart=self.start, until=self.end))
+ urls: List[str] = [
+ f"{self.server_address}crawl-data/CC-NEWS/{date.strftime('%Y/%m')}/warc.paths.gz" for date in date_sequence
+ ]
+
+ with tqdm(total=len(urls), desc="Loading WARC Paths", leave=False, disable=self.disable_tqdm) as bar:
+
+ def load_paths(url: str) -> List[str]:
+ with requests.Session() as session:
+ paths = gzip.decompress(session.get(url).content).decode("utf-8").split()
+ bar.update()
+ return paths
+
+ if self.processes == 0:
+ nested_warc_paths = [load_paths(url) for url in urls]
+ else:
+ max_threads = self.processes * 2
+ with ThreadPool(processes=min(len(urls), max_threads)) as pool:
+ nested_warc_paths = pool.map(random_sleep(load_paths, (0, 3)), urls)
+
+ warc_paths_iter = more_itertools.flatten(nested_warc_paths)
+ start_strf = self.start.strftime("%Y%m%d%H%M%S")
+ end_strf = self.end.strftime("%Y%m%d%H%M%S")
+
+ def filter_by_date(path: str) -> bool:
+ match = date_pattern.search(path)
+ if match is None:
+ raise AssertionError(f"Invalid WARC path {path!r}")
+ return start_strf <= match["date"] <= end_strf
+
+ return sorted(
+ (f"{self.server_address}{p}" for p in filter(filter_by_date, warc_paths_iter)),
+ reverse=True,
+ )
+
+ def _build_article_iterator(
+ self,
+ publishers: Tuple[Publisher, ...],
+ raise_on_error: bool,
+ extraction_filter: Optional[ExtractionFilter],
+ url_filter: Optional[URLFilter],
+ language_filter: Optional[List[str]],
+ skip_publishers_disallowing_training: bool = False,
+ ) -> Iterator[Article]:
+ """Yield articles from the CC-NEWS archive backend: optionally drop publishers that disallow training,
+ resolve the WARC paths covering the date range, then dispatch them to the sequential or multi-process crawl.
+ """
+ if skip_publishers_disallowing_training:
+ max_workers = self.processes if self.processes > 0 else min(len(publishers), 5)
+ verified_publishers: List[Publisher] = []
+
+ with ThreadPoolExecutor(max_workers=max_workers) as executor, session_handler.context(timeout=10):
+ future_to_publisher = {
+ executor.submit(lambda p: p.disallows_training, publisher): publisher for publisher in publishers
+ }
+
+ # resolve warc paths within the ThreadPoolExecutor context
+ warc_paths = tuple(self._get_warc_paths())
+
+ for future in as_completed(future_to_publisher):
+ publisher = future_to_publisher[future]
+ try:
+ if not future.result():
+ verified_publishers.append(publisher)
+ else:
+ logger.warning(f"Skipping publisher {publisher.name!r} because it disallows training.")
+ except Exception as exc:
+ logger.warning(f"Could not verify training policy for {publisher.name!r}: {exc}", exc_info=True)
+ publishers = tuple(verified_publishers)
+ else:
+ warc_paths = tuple(self._get_warc_paths())
+
+ with get_proxy_tqdm(total=len(warc_paths), desc="Process WARC files", disable=self.disable_tqdm) as bar:
+ article_task = partial(
+ self._fetch_articles,
+ publishers=publishers,
+ raise_on_error=raise_on_error,
+ extraction_filter=extraction_filter,
+ url_filter=url_filter,
+ language_filter=language_filter,
+ bar=bar,
+ )
+ if self.processes == 0:
+ yield from self._single_crawl(warc_paths, article_task)
+ else:
+ yield from self._parallel_crawl(warc_paths, article_task)
diff --git a/src/fundus/scraping/crawler/queueing.py b/src/fundus/scraping/crawler/queueing.py
new file mode 100644
index 000000000..9b20cd12c
--- /dev/null
+++ b/src/fundus/scraping/crawler/queueing.py
@@ -0,0 +1,136 @@
+from __future__ import annotations
+
+import multiprocessing
+import time
+import traceback
+from functools import wraps
+from multiprocessing.pool import MapResult
+from queue import Empty, Full, Queue
+from typing import Any, Callable, Iterator, Tuple, Type, TypeVar, Union
+
+from typing_extensions import ParamSpec
+
+from fundus.logging import create_logger
+from fundus.utils.concurrency import get_execution_context
+from fundus.utils.events import __EVENTS__, __MAIN_THREAD_ALIAS__
+
+logger = create_logger(__name__)
+
+_T = TypeVar("_T")
+_P = ParamSpec("_P")
+
+
+class RemoteException(Exception):
+ """Carries a worker thread/process exception (with formatted traceback) back to the consumer via the queue."""
+
+
+def enqueue_results(
+ queue: Queue[Union[_T, Exception]],
+ target: Callable[_P, Iterator[_T]],
+ silenced_exceptions: Tuple[Type[BaseException], ...] = (),
+) -> Callable[_P, None]:
+ """Wrap a result-yielding callable so it pushes its results onto the queue instead of returning them.
+
+ The wrapped callable drives ``target`` to exhaustion, putting each result onto the queue.
+ When the queue is full it blocks until space frees up, bailing out early if the main-thread
+ ``stop`` event is set. Exceptions in ``silenced_exceptions`` are swallowed; any other exception
+ is forwarded to the consumer as a :class:`RemoteException` put onto the queue (so it surfaces in
+ ``iter_pool_results`` rather than crashing the worker).
+
+ Args:
+ queue: The buffer queue results (and forwarded exceptions) are pushed onto.
+ target: A callable returning an iterator of results to enqueue.
+ silenced_exceptions: Exception types to swallow instead of forwarding.
+
+ Returns:
+ Callable[_P, None]: The wrapped target, which returns nothing and enqueues instead.
+ """
+
+ @wraps(target)
+ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
+ def _delivered(result: _T) -> bool:
+ """Block until result is queued; return False if the main-thread stop event aborts the wait."""
+ while True:
+ try:
+ queue.put_nowait(result)
+ except Full:
+ if __EVENTS__.is_event_set("stop", __MAIN_THREAD_ALIAS__):
+ return False
+ time.sleep(0.05)
+ else:
+ return True
+
+ def _forward_exception(exception: Exception) -> None:
+ """Push a worker exception to the consumer as a RemoteException with a formatted traceback."""
+ tb_str = "".join(traceback.TracebackException.from_exception(exception).format())
+ context, ident = get_execution_context()
+ alias = __EVENTS__.get_alias(ident, "") if ident is not None else ""
+ queue.put(
+ RemoteException(
+ f"There was a(n) {type(exception).__name__!r} occurring in {context} "
+ f"with ident {ident} ({alias})\n{tb_str}"
+ )
+ )
+ logger.debug(f"Encountered remote exception in thread {ident} ({alias}): {exception!r}")
+
+ try:
+ for obj in target(*args, **kwargs):
+ if not _delivered(obj):
+ return
+ except silenced_exceptions:
+ pass
+ except Exception as err:
+ _forward_exception(err)
+
+ return wrapper
+
+
+def iter_pool_results(handle: MapResult[Any], queue: Queue[Union[_T, Exception]]) -> Iterator[_T]:
+ """Yield results from a pool's queue as its workers produce them.
+
+ Results are drained from the queue and yielded as they arrive. When the queue runs
+ empty, the pool handle is polled: if every job has finished, any remaining buffered
+ results are flushed and iteration ends; otherwise iteration waits for more results.
+ If the main thread's ``stop`` event is set while waiting, the event is cleared and
+ iteration ends immediately without draining the queue. Any exception a worker
+ forwarded through the queue is re-raised to the consumer.
+
+ Args:
+ handle: The ``MapResult`` handle of the underlying multiprocessing pool.
+ queue: The queue workers push their results (and forwarded exceptions) onto.
+
+ Yields:
+ _T: Each result pulled from the queue.
+ """
+
+ def _next_result() -> _T:
+ """Pop the next buffered result, re-raising any exception a worker forwarded through the queue."""
+ if isinstance(nxt := queue.get_nowait(), Exception):
+ raise Exception("There was an exception occurring in a remote thread/process") from nxt
+ return nxt
+
+ def _pool_finished() -> bool:
+ """Return True once every job in the pool has completed."""
+ try:
+ handle.get(timeout=0.01)
+ except multiprocessing.TimeoutError:
+ return False
+ return True
+
+ # Phase 1: drain results as the pool produces them, until it finishes or is stopped.
+ while True:
+ try:
+ result = _next_result()
+ except Empty:
+ if _pool_finished():
+ break
+ # listen for stop-event set for main-thread
+ if __EVENTS__.is_event_set("stop", __MAIN_THREAD_ALIAS__):
+ __EVENTS__.clear_event("stop", __MAIN_THREAD_ALIAS__)
+ return
+ continue
+ yield result
+
+ # Phase 2: pool is done, so flush whatever results are still buffered.
+ while not queue.empty():
+ yield _next_result()
diff --git a/src/fundus/scraping/crawler/web.py b/src/fundus/scraping/crawler/web.py
new file mode 100644
index 000000000..9088fea57
--- /dev/null
+++ b/src/fundus/scraping/crawler/web.py
@@ -0,0 +1,246 @@
+from __future__ import annotations
+
+import contextlib
+from functools import partial, wraps
+from multiprocessing.pool import ThreadPool
+from queue import Queue
+from typing import Callable, Iterator, List, Optional, Tuple, Type, Union
+
+import more_itertools
+from more_itertools import roundrobin
+
+from fundus.logging import create_logger
+from fundus.publishers.base_objects import Publisher
+from fundus.scraping.article import Article
+from fundus.scraping.crawler.base import CrawlerBase, PublisherType
+from fundus.scraping.crawler.queueing import enqueue_results, iter_pool_results
+from fundus.scraping.delay import Delay
+from fundus.scraping.filter import ExtractionFilter, URLFilter
+from fundus.scraping.pipeline import Pipeline
+from fundus.scraping.pipeline.source.web import WebSource
+from fundus.scraping.session import CrashThread
+from fundus.scraping.url import URLSource
+from fundus.utils.events import __EVENTS__
+
+logger = create_logger(__name__)
+
+
+def publisher_context_wrapper(func: Callable[[Publisher], None]) -> Callable[[Publisher], None]:
+ """Wraps a callable to register an ``__EVENTS__`` alias context for the publisher argument.
+
+ The alias is entered as the very first thing the thread does and stays alive for the
+ entire call — including any exception handling in the caller — so that
+ ``__EVENTS__.get_alias`` always resolves while the thread is running.
+
+ Args:
+ func: A callable whose first positional argument is a :class:`Publisher`.
+
+ Returns:
+ The wrapped callable.
+ """
+
+ @wraps(func)
+ def wrapper(publisher: Publisher) -> None:
+ with __EVENTS__.context(publisher.name):
+ func(publisher)
+
+ return wrapper
+
+
+class Crawler(CrawlerBase):
+ """Crawler for the live web: fetches articles by requesting each publisher's URL sources.
+
+ Builds one WebSource-backed Pipeline per URL source and, when threading is enabled, runs one
+ thread per publisher. Honors robots.txt and crawl delays.
+ """
+
+ def __init__(
+ self,
+ *publishers: PublisherType,
+ restrict_sources_to: Optional[List[Type[URLSource]]] = None,
+ ignore_deprecated: bool = False,
+ delay: Optional[Union[int, float, Delay]] = 1.0,
+ threading: bool = True,
+ ignore_robots: bool = False,
+ ignore_crawl_delay: bool = False,
+ impersonate: bool = False,
+ ):
+ """Crawler for fetching articles from the web.
+
+ Examples:
+ >>> from fundus import PublisherCollection, Crawler
+ >>> crawler = Crawler(*PublisherCollection)
+ >>> for article in crawler.crawl():
+ >>> print(article)
+
+ Args:
+ *publishers: The publishers to crawl.
+ restrict_sources_to: Restrict sources to the given URL source types.
+ ignore_deprecated: Skip deprecated publishers. Defaults to False.
+ delay: Delay in seconds between article downloads. Defaults to 1.0.
+ threading: Use one thread per publisher. Defaults to True.
+ ignore_robots: Bypass robots.txt filtering. Defaults to False.
+ ignore_crawl_delay: Ignore crawl-delay from robots.txt. Defaults to False.
+ impersonate: If True, publishers that declare an `impersonate` browser profile
+ will use curl_cffi's TLS/HTTP fingerprint impersonation. If False (default),
+ the profile is ignored and requests go out with Fundus' regular fingerprint —
+ publishers gated by anti-bot checks will likely return 4xx/5xx. Defaults to False.
+ """
+
+ fitting_publishers = []
+ for publisher in more_itertools.collapse(publishers):
+ if publisher.deprecated and ignore_deprecated:
+ logger.warning(f"Skipping deprecated publisher: {publisher.name}")
+ else:
+ fitting_publishers.append(publisher)
+ if not fitting_publishers:
+ raise ValueError(
+ "All given publishers are deprecated. Either set to `False` or "
+ "include at least one publisher that isn't deprecated."
+ )
+
+ super().__init__(*fitting_publishers)
+
+ self.restrict_sources_to = restrict_sources_to
+ self.threading = threading
+ self.ignore_robots = ignore_robots
+ self.ignore_crawl_delay = ignore_crawl_delay
+ self.impersonate = impersonate
+
+ self._delay = self._resolve_delay(delay)
+
+ @staticmethod
+ def _resolve_delay(delay: Optional[Union[int, float, Delay]]) -> Optional[Delay]:
+ """Normalize the delay argument into a Delay callable (or None); wraps a constant in a thunk."""
+ if delay is None:
+ return None
+ if isinstance(delay, (int, float)):
+
+ def constant_delay() -> float:
+ return delay
+
+ return constant_delay
+ if callable(delay):
+ return delay
+ raise TypeError("param of must be float, Delay, or None")
+
+ def _build_pipelines(self, publisher: Publisher) -> List[Pipeline]:
+ """Build one WebSource-backed Pipeline per (optionally restricted) URL source of the publisher."""
+ if self.restrict_sources_to:
+ url_sources = list(
+ more_itertools.flatten(
+ publisher.source_mapping[source_type]
+ for source_type in self.restrict_sources_to
+ if source_type in publisher.source_mapping
+ )
+ )
+ else:
+ url_sources = list(more_itertools.flatten(publisher.source_mapping.values()))
+
+ if not url_sources and self.restrict_sources_to:
+ logger.warning(
+ f"No sources of type {[s.__name__ for s in self.restrict_sources_to]} "
+ f"found for publisher {publisher.name}. Skipping publisher."
+ )
+ return []
+
+ return [
+ Pipeline(
+ WebSource(
+ url_source=url_source,
+ publisher=publisher,
+ delay=self._delay,
+ url_filter=publisher.url_filter,
+ query_parameters=publisher.query_parameter,
+ ignore_robots=self.ignore_robots,
+ ignore_crawl_delay=self.ignore_crawl_delay,
+ impersonate=self.impersonate,
+ stop_event=__EVENTS__.get("stop"),
+ ),
+ publishers=[publisher],
+ )
+ for url_source in url_sources
+ ]
+
+ def _on_publisher_limit_reached(self, publisher_name: str) -> None:
+ """Set the publisher's stop event so its thread halts once the per-publisher limit is hit."""
+ if self.threading and not __EVENTS__.is_event_set("stop", publisher_name):
+ __EVENTS__.set_event("stop", publisher_name)
+
+ def _fetch_articles(
+ self,
+ publisher: Publisher,
+ raise_on_error: bool,
+ extraction_filter: Optional[ExtractionFilter] = None,
+ url_filter: Optional[URLFilter] = None,
+ language_filter: Optional[List[str]] = None,
+ skip_publishers_disallowing_training: bool = False,
+ ) -> Iterator[Article]:
+ """Run every pipeline for one publisher, yielding its articles; skip it if training is disallowed."""
+ if skip_publishers_disallowing_training and publisher.disallows_training:
+ logger.info(f"Skipping publisher {publisher.name} because it disallows training.")
+ return
+ elif publisher.robots.disallow_all():
+ logger.info(f"Skipping publisher {publisher.name} because it disallows all URLs.")
+ return
+
+ for pipeline in self._build_pipelines(publisher):
+ yield from pipeline.run(raise_on_error, extraction_filter, url_filter, language_filter)
+
+ @staticmethod
+ def _single_crawl(
+ publishers: Tuple[Publisher, ...], article_task: Callable[[Publisher], Iterator[Article]]
+ ) -> Iterator[Article]:
+ """Round-robin articles from all publishers in the calling thread (no threading)."""
+ yield from roundrobin(*[article_task(publisher) for publisher in publishers])
+
+ @contextlib.contextmanager
+ def _thread_pool(self, processes: int) -> Iterator[ThreadPool]:
+ """Yield a ThreadPool, signalling all publisher threads to stop and joining them on exit."""
+ pool = ThreadPool(processes or None)
+ try:
+ yield pool
+ finally:
+ logger.debug(f"Shutting down {type(self).__name__!r} ...")
+ pool.close()
+ __EVENTS__.set_for_all("stop", future=True, active_only=True)
+ pool.join()
+ __EVENTS__.clear_for_all("stop")
+ logger.debug("Shutdown done")
+
+ def _threaded_crawl(
+ self, publishers: Tuple[Publisher, ...], article_task: Callable[[Publisher], Iterator[Article]]
+ ) -> Iterator[Article]:
+ """Run one publisher per pool thread, funnelling their articles through a shared queue."""
+ result_queue: Queue[Union[Article, Exception]] = Queue(len(publishers))
+ wrapped_task = publisher_context_wrapper(
+ enqueue_results(result_queue, article_task, silenced_exceptions=(CrashThread,))
+ )
+
+ with self._thread_pool(len(publishers)) as pool:
+ yield from iter_pool_results(pool.map_async(wrapped_task, publishers), result_queue)
+
+ def _build_article_iterator(
+ self,
+ publishers: Tuple[Publisher, ...],
+ raise_on_error: bool,
+ extraction_filter: Optional[ExtractionFilter],
+ url_filter: Optional[URLFilter],
+ language_filter: Optional[List[str]],
+ skip_publishers_disallowing_training: bool = False,
+ ) -> Iterator[Article]:
+ """Yield articles from the live-web backend: bind the per-publisher article task,
+ then dispatch the publishers to the threaded or single-threaded crawl.
+ """
+ article_task = partial(
+ self._fetch_articles,
+ raise_on_error=raise_on_error,
+ extraction_filter=extraction_filter,
+ url_filter=url_filter,
+ language_filter=language_filter,
+ skip_publishers_disallowing_training=skip_publishers_disallowing_training,
+ )
+ if self.threading:
+ yield from self._threaded_crawl(publishers, article_task)
+ else:
+ yield from self._single_crawl(publishers, article_task)
diff --git a/src/fundus/scraping/filter.py b/src/fundus/scraping/filter.py
index 35c6f22e2..7a5c3a874 100644
--- a/src/fundus/scraping/filter.py
+++ b/src/fundus/scraping/filter.py
@@ -7,14 +7,7 @@
def inverse(filter_func: Callable[P, bool]) -> Callable[P, bool]:
- """Logical not operator that can be used on filters
-
- Args:
- filter_func: The filter function to inverse.
-
- Returns:
- bool: boolean value of the evaluation
- """
+ """Returns a filter that evaluates to the logical NOT of `filter_func`."""
def __call__(*args: P.args, **kwargs: P.kwargs) -> bool:
return not filter_func(*args, **kwargs)
@@ -23,14 +16,7 @@ def __call__(*args: P.args, **kwargs: P.kwargs) -> bool:
def lor(*filters: Callable[P, bool]) -> Callable[P, bool]:
- """Logical or operator that can be used on filters
-
- Args:
- *filters: The filter functions to or.
-
- Returns:
- bool: boolean value of the evaluation
- """
+ """Returns a filter that passes when any of `filters` passes (logical OR)."""
def __call__(*args: P.args, **kwargs: P.kwargs) -> bool:
return any(f(*args, **kwargs) for f in filters)
@@ -39,14 +25,7 @@ def __call__(*args: P.args, **kwargs: P.kwargs) -> bool:
def land(*filters: Callable[P, bool]) -> Callable[P, bool]:
- """Logical and operator that can be used on filters
-
- Args:
- *filters: The filter functions to and.
-
- Returns:
- bool: boolean value of the evaluation
- """
+ """Returns a filter that passes only when all of `filters` pass (logical AND)."""
def __call__(*args: P.args, **kwargs: P.kwargs) -> bool:
return all(f(*args, **kwargs) for f in filters)
@@ -55,62 +34,56 @@ def __call__(*args: P.args, **kwargs: P.kwargs) -> bool:
class URLFilter(Protocol):
- """Protocol to define filter used before article download.
+ """Filter applied before article download. True means filtered out, False means kept."""
- Filters satisfying this protocol should work inverse to build in filter(),
- so that True gets filtered and False don't.
- """
-
- def __call__(self, url: str) -> bool:
- """Filters a website, represented by a given , on the criterion if it represents an
-
- Args:
- url: The url the evaluation should be based on.
-
- Returns:
- bool: True if an should be filtered out and not
- considered for extraction, False otherwise.
-
- """
- ...
+ def __call__(self, url: str) -> bool: ...
def regex_filter(regex: str) -> URLFilter:
+ """Returns a URLFilter that filters out URLs matching `regex`."""
+ pattern = re.compile(regex)
+
def url_filter(url: str) -> bool:
- return bool(re.search(regex, url))
+ return bool(pattern.search(url))
return url_filter
class SupportsBool(Protocol):
+ """Anything convertible to bool; the return type of an ExtractionFilter call."""
+
def __bool__(self) -> bool: ...
class ExtractionFilter(Protocol):
- """Protocol to define filters used after article extraction.
+ """Callable protocol for filters applied after article extraction.
+
+ A truthy return value excludes the article; falsy keeps it — intentionally
+ inverse to Python's built-in filter().
- Filters satisfying this protocol should work inverse to build in filter(),
- so that True gets filtered and False don't.
+ Example — exclude articles whose body is shorter than 500 characters::
+
+ def min_body_length(extraction: Dict[str, Any]) -> bool:
+ body = extraction.get("body")
+ return not body or len(str(body)) < 500
"""
def __call__(self, extraction: Dict[str, Any]) -> SupportsBool:
- """This should implement a selection based on .
-
- Extracted will be a dictionary returned by a parser mapping the attribute
- names of the parser to the extracted values.
+ """Evaluate the extraction and return whether it should be filtered out.
Args:
- extraction: The extracted values the evaluation
- should be based on.
+ extraction: Maps attribute names to their extracted values, as returned
+ by a parser. Attributes absent from the article are not present in the dict.
Returns:
- bool: True if extraction should be filtered out, False otherwise.
-
+ A truthy value to exclude the article, falsy to keep it.
"""
...
class FilterResultWithMissingAttributes:
+ """Return value of Requires.__call__. Truthy when one or more attributes are missing or falsy."""
+
def __init__(self, *attributes: str) -> None:
self.missing_attributes = attributes
@@ -118,7 +91,8 @@ def __bool__(self) -> bool:
return bool(self.missing_attributes)
-def _guarded_bool(value: Any):
+def _eval_unless_bool(value: Any) -> bool:
+ """Booleans always pass; only non-boolean values are evaluated with bool()."""
if isinstance(value, bool):
return True
else:
@@ -126,50 +100,60 @@ def _guarded_bool(value: Any):
class Requires:
- def __init__(self, *required_attributes: str, eval_booleans: bool = True) -> None:
- """Class to filter extractions based on attribute values
+ """Filters extractions based on the presence and truthiness of named attributes.
- If a required_attribute is not present in the extracted data or evaluates to bool() -> False,
- this filter won't be passed. By default, required boolean attributes are evaluated with bool().
+ When called with an extraction dict, returns a FilterResultWithMissingAttributes
+ that is truthy if any required attribute is absent, falsy, or an Exception.
+ With no required attributes specified, all keys in the extraction are evaluated.
- I.e.,
+ By default, boolean attributes are evaluated with bool():
- Requires("free_access")({"free_access": False}) -> will be filtered out
+ Requires("free_access")({"free_access": False}) # filtered out
- You can alter this behaviour by setting `eval_bool=False`
+ Set eval_booleans=False to let boolean values pass unconditionally:
- I.e.,
+ Requires("free_access", eval_booleans=False)({"free_access": False}) # passes
- Requires("free_access", eval_bool=False)({"free_access": False}) -> will pass
+ Args:
+ *required_attributes: Attributes that must be present and truthy. If none are
+ given, all keys in the extraction are evaluated.
+ eval_booleans: If True, boolean values are evaluated with bool(). If False,
+ boolean values always pass. Defaults to True.
+ """
- Args:
- *required_attributes: Attributes required to evaluate to True in order to
- pass the filter. If no attributes are given, all attributes will be evaluated
- eval_bool: If True the boolean values will also be evaluated with bool().
- If False, all boolean values evaluate to True. Defaults to True.
- """
+ def __init__(self, *required_attributes: str, eval_booleans: bool = True) -> None:
self.required_attributes = set(required_attributes)
# somehow mypy does not recognize bool as callable :(
- self._eval: Callable[[Any], bool] = bool if eval_booleans else _guarded_bool # type: ignore[assignment]
+ self._eval: Callable[[Any], bool] = bool if eval_booleans else _eval_unless_bool # type: ignore[assignment]
+
+ def _is_missing(self, value: Any) -> bool:
+ """True if value is absent, falsy, or an Exception."""
+ return not self._eval(value) or isinstance(value, Exception)
def __call__(self, extraction: Dict[str, Any]) -> FilterResultWithMissingAttributes:
- missing_attributes = [
- attribute
- for attribute in self.required_attributes or extraction.keys()
- if not self._eval(value := extraction.get(attribute)) or isinstance(value, Exception)
- ]
+ """Evaluate the extraction against the required attributes.
+
+ Args:
+ extraction: A dictionary mapping attribute names to their extracted values.
+
+ Returns:
+ FilterResultWithMissingAttributes that is truthy if any required attribute
+ is absent, falsy, or an Exception.
+ """
+ attributes = self.required_attributes if self.required_attributes else extraction.keys()
+ missing_attributes = [attribute for attribute in attributes if self._is_missing(extraction.get(attribute))]
return FilterResultWithMissingAttributes(*missing_attributes)
class RequiresAll(Requires):
- def __init__(self, eval_booleans: bool = False) -> None:
- """Name wrap for Requires(eval_booleans=False)
+ """Requires all attributes in the extraction to be present and truthy.
- This is for readability only. By default, it requires all non-boolean attributes of the extraction
- to evaluate to True. Set `eval_booleans=True` to include boolean values in the evaluation as well.
- See class:Requires docstring for more information.
+ Equivalent to Requires() with no specified attributes, but with eval_booleans=False
+ by default so boolean attributes are not counted as missing regardless of their value.
- Args:
- eval_booleans: See Requires docstring for more information. Defaults to False.
- """
+ Args:
+ eval_booleans: If True, boolean values are also evaluated. Defaults to False.
+ """
+
+ def __init__(self, eval_booleans: bool = False) -> None:
super().__init__(eval_booleans=eval_booleans)
diff --git a/src/fundus/scraping/html.py b/src/fundus/scraping/html.py
index 33e0517ff..da1e70666 100644
--- a/src/fundus/scraping/html.py
+++ b/src/fundus/scraping/html.py
@@ -1,337 +1,76 @@
-import time
-from dataclasses import dataclass
-from datetime import datetime
-from typing import Callable, Dict, Iterable, Iterator, List, Optional, Protocol, Union
-from urllib.parse import urlparse
+from __future__ import annotations
-import chardet
-import requests
-from curl_cffi.requests.exceptions import ConnectionError, HTTPError, ReadTimeout
-from fastwarc import ArchiveIterator, WarcRecord, WarcRecordType
+from dataclasses import dataclass, fields
+from datetime import datetime
+from typing import Dict
-from fundus.logging import create_logger
-from fundus.publishers.base_objects import Publisher, Robots
-from fundus.scraping.delay import Delay
-from fundus.scraping.filter import URLFilter
-from fundus.scraping.session import _default_header, session_handler
-from fundus.scraping.url import URLSource, is_valid_url
-from fundus.utils.events import __EVENTS__
+from fundus.utils.serialization import JSONVal, serialize_value
__all__ = [
"HTML",
"SourceInfo",
- "WarcSourceInfo",
- "WebSourceInfo",
- "HTMLSource",
- "WebSource",
- "CCNewsSource",
]
-logger = create_logger(__name__)
-
-
-@dataclass(frozen=True)
-class HTML:
- requested_url: str
- responded_url: str
- content: str
- crawl_date: datetime
- source_info: "SourceInfo"
-
@dataclass(frozen=True)
class SourceInfo:
- publisher: str
-
-
-@dataclass(frozen=True)
-class WarcSourceInfo(SourceInfo):
- warc_path: str
- warc_headers: Dict[str, str]
- http_headers: Dict[str, str]
-
-
-@dataclass(frozen=True)
-class WebSourceInfo(SourceInfo):
- type: str
- url: str
+ """Provenance metadata for an HTML record.
+ The base form carries only the publisher's name; needs to be pickable. Per-backend
+ subclasses (WebSourceInfo, WarcSourceInfo) add their own origin fields.
-class HTMLSource(Protocol):
- def fetch(self, url_filter: Optional[URLFilter] = None) -> Iterator[HTML]: ...
+ Attributes:
+ publisher (str): The publisher's name (its identity).
+ """
+ publisher: str
-class _Clock:
- def __init__(
- self, delay: Optional[Delay], sleep: Callable[[float], None] = time.sleep, warm_start: bool = True
- ) -> None:
- """Utility class for time-aligned delay.
-
- Keeps track of the time passed since last call or init and waits at most the remaining delay.
-
- Args:
- delay: A customized delay.
- sleep: A customized sleep function. Defaults to .
- warm_start: If true, skips first delay.
- """
- self.delay = delay
- self.timestamp = time.time()
-
- if warm_start and self.delay is not None:
- self.timestamp -= self.delay()
-
- self.sleep = sleep
-
- def __call__(self, blocking: bool = True) -> bool:
- """Waits at most seconds since last called.
-
- Args:
- blocking: If True, blocks until seconds have elapsed since last call.
- If non-blocking returns False if less time has elapsed, else returns True.
+ def serialize(self) -> Dict[str, JSONVal]:
+ """Serialize all dataclass fields to a JSON-compatible dict.
- Returns: True if seconds have elapsed since last call. False otherwise.
+ Subclasses inherit this unchanged and automatically pick up their extra fields,
+ since it reflects over the dataclass fields rather than naming them explicitly.
+ Returns:
+ Dict[str, JSONVal]: Field name to JSON-serializable value for every field.
"""
- if self.delay is None:
- return True
-
- if delay := max(0.0, self.delay() - time.time() + self.timestamp):
- if not blocking:
- return False
- self.sleep(delay)
- self.reset()
- return True
-
- def reset(self):
- self.timestamp = time.time()
-
-
-class WebSource:
- def __init__(
- self,
- url_source: Union[URLSource, Iterable[str]],
- publisher: Publisher,
- url_filter: Optional[URLFilter] = None,
- query_parameters: Optional[Dict[str, str]] = None,
- delay: Optional[Delay] = None,
- ignore_robots: bool = False,
- ignore_crawl_delay: bool = False,
- impersonate: bool = False,
- ):
- self.url_source = url_source
- self.publisher = publisher
- self.url_filter = url_filter
- self.query_parameters = query_parameters or {}
- self._impersonate_profile = publisher.impersonate if impersonate else None
-
- # parse robots:
- self.robots: Optional[Robots] = None
- if not ignore_robots:
- self.robots = self.publisher.robots
-
- if not ignore_crawl_delay:
- if robots_delay := self.robots.crawl_delay(self.publisher.request_header.get("user-agent", "*")):
- logger.debug(
- f"Found crawl-delay of {robots_delay} seconds in robots.txt for {self.publisher.name}. "
- f"Overwriting existing delay."
- )
-
- def delay() -> float:
- return robots_delay
-
- self.clock = _Clock(delay=delay, sleep=self._sleep)
-
- @property
- def _is_stopped(self):
- return __EVENTS__.is_event_set("stop")
-
- @staticmethod
- def _sleep(s: float):
- __EVENTS__.get("stop").wait(s)
+ return {f.name: serialize_value(getattr(self, f.name)) for f in fields(self)}
- def _fetch_html(self, url: str, url_filter: URLFilter) -> Optional[HTML]:
- # check if URL is malformed
- if not is_valid_url(url):
- logger.debug(f"Skipped requested URL {url!r} because the URL is malformed")
- return None
- # apply URL filter to requested URL
- if url_filter(url):
- logger.debug(f"Skipped requested URL {url!r} because of URL filter")
- return None
-
- # check robots
- if not (
- self.robots is None or self.robots.can_fetch(self.publisher.request_header.get("user-agent", "*"), url)
- ):
- logger.debug(f"Skipped requested URL {url!r} because of robots.txt")
- return None
-
- session = session_handler.get_session(self._impersonate_profile)
-
- # prepare query parameters
- for key, value in self.query_parameters.items():
- if "?" in url:
- url += "&" + key + "=" + value
- else:
- url += "?" + key + "=" + value
-
- # apply crawl-delay
- self.clock()
-
- # fetch html
- try:
- response = session.get_with_interrupt(url, headers=self.publisher.request_header)
-
- except (HTTPError, ConnectionError, ReadTimeout) as error:
- logger.warning(f"Skipped requested URL {url!r} because of {error!r}")
- if isinstance(error, HTTPError) and error.response.status_code >= 500:
- logger.warning(f"Skipped {self.publisher.name!r} due to server errors: {error!r}")
- return None
-
- # apply URL filter to responded URL
- if url_filter(str(response.url)):
- logger.debug(f"Skipped responded URL {str(response.url)!r} because of URL filter")
- return None
-
- html = response.text
-
- # check for redirects
- if response.history:
- logger.info(f"Got redirected {len(response.history)} time(s) from {url!r} -> {response.url!r}")
-
- # create WebSourceInfo
- source_info = (
- WebSourceInfo(self.publisher.name, type(self.url_source).__name__, self.url_source.url)
- if isinstance(self.url_source, URLSource)
- else SourceInfo(self.publisher.name)
- )
-
- # create HTML
- return HTML(
- requested_url=url,
- responded_url=str(response.url),
- content=html,
- crawl_date=datetime.now(),
- source_info=source_info,
- )
-
- def _build_url_filter(self, url_filter: Optional[URLFilter]) -> URLFilter:
- combined_filters: List[URLFilter] = ([self.url_filter] if self.url_filter else []) + (
- [url_filter] if url_filter else []
- )
-
- def combined_url_filter(url: str) -> bool:
- return any(f(url) for f in combined_filters)
+@dataclass(frozen=True)
+class HTML:
+ """A fetched HTML document together with its URLs, crawl time, and source provenance.
- return combined_url_filter
+ The unit of exchange between the Source and Pipeline layers: a Source yields HTML,
+ the Pipeline parses it into an Article. Frozen so it can be shared/pickled safely.
- def fetch(self, url_filter: Optional[URLFilter] = None) -> Iterator[HTML]:
- if isinstance(self.url_source, URLSource):
- url_iterator = self.url_source.fetch(
- session_handler.get_session(self._impersonate_profile),
- self.publisher.request_header,
- )
- else:
- url_iterator = iter(self.url_source)
+ Attributes:
+ requested_url (str): The URL that was requested.
+ responded_url (str): The final URL after redirects (equals requested_url when none).
+ content (str): The decoded HTML body.
+ crawl_date (datetime): When the document was fetched (or its WARC record date).
+ source_info (SourceInfo): Provenance metadata describing where the record came from.
+ """
- while not self._is_stopped:
- try:
- # check iterator
- if (url := next(url_iterator, None)) is None:
- return
- except Exception as error:
- logger.error(
- f"Warning! URLSource {self.url_source!r} crashed because of an unexpected error: {error!r}"
- )
- return
+ requested_url: str
+ responded_url: str
+ content: str
+ crawl_date: datetime
+ source_info: SourceInfo
- try:
- if html := self._fetch_html(url, self._build_url_filter(url_filter)):
- yield html
- except Exception as error:
- logger.error(f"Warning! Skipped requested URL {url!r} because of an unexpected error {error!r}")
- continue
+ def serialize(self) -> Dict[str, JSONVal]:
+ """Serialize the record to a JSON-compatible dict.
+ The crawl date is ISO-formatted and the source info is serialized via its own
+ serialize(); all other fields are emitted as-is.
-class CCNewsSource:
- def __init__(self, *publishers: Publisher, warc_path: str, headers: Optional[Dict[str, str]] = None):
- self.publishers = publishers
- self.warc_path = warc_path
- self.headers = headers or _default_header
- self._publisher_mapping: Dict[str, Publisher] = {
- urlparse(publisher.domain).netloc: publisher for publisher in self.publishers
+ Returns:
+ Dict[str, JSONVal]: The record's fields as JSON-serializable values.
+ """
+ return {
+ "requested_url": self.requested_url,
+ "responded_url": self.responded_url,
+ "content": self.content,
+ "crawl_date": self.crawl_date.isoformat(),
+ "source_info": self.source_info.serialize(),
}
-
- def fetch(self, url_filter: Optional[URLFilter] = None) -> Iterator[HTML]:
- def extract_content(record: WarcRecord) -> Optional[str]:
- warc_body: bytes = record.reader.read()
-
- try:
- return str(warc_body, encoding=record.http_charset) # type: ignore[arg-type]
- except (UnicodeDecodeError, TypeError):
- encoding: Optional[str] = chardet.detect(warc_body)["encoding"]
-
- if encoding is not None:
- logger.debug(
- f"Trying to decode record {record.record_id!r} from {target_url!r} "
- f"using detected encoding {encoding}."
- )
-
- try:
- return str(warc_body, encoding=encoding)
- except UnicodeDecodeError:
- logger.warning(
- f"Couldn't decode record {record.record_id!r} from {target_url!r} with "
- f"original charset {record.http_charset!r} using detected charset {encoding!r}."
- )
- else:
- logger.warning(
- f"Couldn't detect charset for record {record.record_id!r} from {target_url!r} "
- f"with invalid original charset {record.http_charset!r}."
- )
-
- return None
-
- with requests.Session() as session:
- response = session.get(self.warc_path, stream=True, headers=self.headers)
- response.raise_for_status()
-
- for warc_record in ArchiveIterator(response.raw, record_types=WarcRecordType.response, verify_digests=True):
- if not warc_record.record_date:
- continue
-
- target_url = str(warc_record.headers["WARC-Target-URI"])
-
- if url_filter is not None and url_filter(target_url):
- logger.debug(f"Skipped WARC record with target URI {target_url!r} because of URL filter")
- continue
-
- publisher_domain: str = urlparse(target_url).netloc
-
- if publisher_domain not in self._publisher_mapping:
- continue
-
- publisher = self._publisher_mapping[publisher_domain]
-
- if publisher.url_filter is not None and publisher.url_filter(target_url):
- logger.debug(
- f"Skipped WARC record with target URI {target_url!r} because of publisher specific URL filter"
- )
- continue
-
- if (content := extract_content(warc_record)) is None:
- continue
-
- yield HTML(
- requested_url=target_url,
- responded_url=target_url,
- content=content,
- crawl_date=warc_record.record_date,
- source_info=WarcSourceInfo(
- publisher=publisher.name,
- warc_path=self.warc_path,
- warc_headers=dict(warc_record.headers),
- http_headers=dict(warc_record.http_headers or {}),
- ),
- )
diff --git a/src/fundus/scraping/pipeline/__init__.py b/src/fundus/scraping/pipeline/__init__.py
new file mode 100644
index 000000000..e9f2d2a35
--- /dev/null
+++ b/src/fundus/scraping/pipeline/__init__.py
@@ -0,0 +1,142 @@
+from __future__ import annotations
+
+from typing import Collection, Dict, Iterator, List, Optional, Protocol
+
+from fundus.logging import create_logger
+from fundus.parser import ParserProxy
+from fundus.publishers.base_objects import Publisher
+from fundus.scraping.article import Article
+from fundus.scraping.filter import ExtractionFilter, FilterResultWithMissingAttributes, URLFilter
+from fundus.scraping.html import HTML, SourceInfo
+
+logger = create_logger(__name__)
+
+__all__ = [
+ "HTML",
+ "SourceInfo",
+ "HTMLSource",
+ "Pipeline",
+ "PipelineError",
+]
+
+
+class HTMLSource(Protocol):
+ """Protocol for HTML producers: yields HTML records, optionally gated by a URL filter.
+
+ Implemented by WebSource (live web) and CCNewsSource (CC-NEWS WARC archive).
+ """
+
+ def fetch(self, url_filter: Optional[URLFilter] = None) -> Iterator[HTML]:
+ """Stream HTML records from the underlying source.
+
+ Args:
+ url_filter (Optional[URLFilter]): Per-call URL filter; a truthy result skips the URL.
+ Combined with any source- or publisher-level filter by the implementor.
+
+ Yields:
+ HTML: One record per kept/fetched URL.
+
+ """
+ ...
+
+
+class PipelineError(Exception):
+ """Raised when an error occurs during a pipeline run."""
+
+ pass
+
+
+class Pipeline:
+ """Pairs an HTMLSource with publisher parsers, turning each fetched HTML into an Article.
+
+ Pulls HTML from the source, looks up the parser for the HTML's publisher by name, parses it,
+ and applies the extraction and language filters. HTML that fails parsing or any filter is dropped.
+ """
+
+ def __init__(self, source: HTMLSource, publishers: Collection[Publisher]) -> None:
+ """Build a pipeline over the given source and the parsers of the supplied publishers.
+
+ Args:
+ source (HTMLSource): The HTML producer to pull records from.
+ publishers (Collection[Publisher]): Publishers whose parsers may be needed to process
+ the source's HTML. Each HTML is re-associated with a parser by its publisher name.
+
+ """
+ self.source = source
+ # Identity -> parser. The parser is behavior and can't ride on the (picklable) HTML, so
+ # each HTML carries only its publisher's name and we re-associate the parser here.
+ self._parsers: Dict[str, ParserProxy] = {publisher.name: publisher.parser for publisher in publishers}
+
+ def _extract(
+ self,
+ html: HTML,
+ raise_on_error: bool,
+ extraction_filter: Optional[ExtractionFilter] = None,
+ language_filter: Optional[List[str]] = None,
+ ) -> Optional[Article]:
+ """Parse one HTML into an Article, or None if parsing fails or a filter drops it."""
+ if (parser := self._parsers.get(html.source_info.publisher)) is None:
+ raise PipelineError(
+ f"No parser for publisher {html.source_info.publisher!r}; "
+ f"pipeline was built for {sorted(self._parsers)}"
+ )
+
+ try:
+ extraction = parser(html.crawl_date).parse(html.content, raise_on_error)
+
+ except Exception as error:
+ if raise_on_error:
+ error_message = f"Run into an error processing article {html.requested_url!r}"
+ logger.error(error_message)
+ error.args = (str(error) + "\n\n" + error_message,)
+ raise
+ logger.info(f"Skipped article at {html.requested_url!r} because of: {error!r}")
+ return None
+
+ else:
+ if extraction_filter and (filter_result := extraction_filter(extraction)):
+ if isinstance(filter_result, FilterResultWithMissingAttributes):
+ logger.debug(
+ f"Skipped article at {html.requested_url!r} because attribute(s) "
+ f"{', '.join(filter_result.missing_attributes)!r} is(are) missing"
+ )
+ else:
+ logger.debug(f"Skipped article at {html.requested_url!r} because of extraction filter")
+ return None
+
+ article = Article(html=html, **extraction)
+ if language_filter and article.lang not in language_filter:
+ logger.debug(
+ f"Skipped article at {html.requested_url!r} because article language "
+ f"{article.lang!r} is not in allowed languages: {language_filter!r}"
+ )
+ return None
+
+ return article
+
+ def run(
+ self,
+ raise_on_error: bool,
+ extraction_filter: Optional[ExtractionFilter] = None,
+ url_filter: Optional[URLFilter] = None,
+ language_filter: Optional[List[str]] = None,
+ ) -> Iterator[Article]:
+ """Stream Articles by fetching HTML from the source and parsing each record.
+
+ Args:
+ raise_on_error (bool): If True, parser exceptions propagate; if False they are logged
+ and the offending article is skipped.
+ extraction_filter (Optional[ExtractionFilter]): Applied after extraction; a truthy
+ result drops the article.
+ url_filter (Optional[URLFilter]): Forwarded to the source's fetch() to skip URLs before
+ they are downloaded/parsed.
+ language_filter (Optional[List[str]]): If set, articles whose detected language is not
+ in this list are skipped.
+
+ Yields:
+ Article: One per HTML record that parses and passes all filters.
+
+ """
+ for html in self.source.fetch(url_filter=url_filter):
+ if article := self._extract(html, raise_on_error, extraction_filter, language_filter):
+ yield article
diff --git a/src/fundus/scraping/pipeline/source/__init__.py b/src/fundus/scraping/pipeline/source/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/fundus/scraping/pipeline/source/ccnews.py b/src/fundus/scraping/pipeline/source/ccnews.py
new file mode 100644
index 000000000..04b2f9ec8
--- /dev/null
+++ b/src/fundus/scraping/pipeline/source/ccnews.py
@@ -0,0 +1,171 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Dict, Iterator, Optional
+from urllib.parse import urlparse
+
+import chardet
+import requests
+import urllib3.exceptions
+from fastwarc import ArchiveIterator, WarcRecord, WarcRecordType
+from fastwarc.stream_io import StreamError
+
+from fundus.logging import create_logger
+from fundus.publishers.base_objects import Publisher
+from fundus.scraping.filter import URLFilter
+from fundus.scraping.html import HTML, SourceInfo
+from fundus.scraping.session import _default_header
+
+logger = create_logger(__name__)
+
+
+class WarcFileLoadError(Exception):
+ """Raised when a CC-NEWS WARC archive cannot be downloaded or its stream is corrupt or truncated."""
+
+
+@dataclass(frozen=True)
+class WarcSourceInfo(SourceInfo):
+ """Origin metadata attached to an HTML record extracted from a CC-NEWS WARC archive.
+
+ Attributes:
+ warc_path (str): HTTPS URL of the WARC archive the record came from.
+ warc_headers (Dict[str, str]): WARC envelope headers (e.g. WARC-Target-URI, WARC-Date).
+ http_headers (Dict[str, str]): HTTP response headers captured by the original crawl.
+ """
+
+ warc_path: str
+ warc_headers: Dict[str, str]
+ http_headers: Dict[str, str]
+
+
+class CCNewsSource:
+ """HTML source backed by a single CC-NEWS WARC archive on Common Crawl.
+
+ Streams the archive once, walks its response records, and yields HTML for those records whose
+ target URI matches one of the configured publishers' domains. Unlike WebSource, there is no
+ per-URL network request: the archive contains pages already crawled by Common Crawl, so this
+ source is effectively a selection-and-decode pipeline over a pre-fetched corpus.
+ """
+
+ def __init__(self, *publishers: Publisher, warc_path: str, headers: Optional[Dict[str, str]] = None) -> None:
+ """Initialize a source over a single CC-NEWS WARC archive.
+
+ Args:
+ *publishers (Publisher): Publishers whose articles should be extracted. WARC records
+ whose target URI does not belong to any of these publishers' domains are dropped
+ during iteration.
+ warc_path (str): HTTPS URL of the WARC archive to read (e.g. a CC-NEWS .warc.gz path).
+ headers (Optional[Dict[str, str]]): Request headers for the WARC download. Defaults to
+ the shared fundus user-agent header.
+
+ """
+ self.publishers = publishers
+ self.warc_path = warc_path
+ self.headers = headers or _default_header
+ self._publisher_mapping: Dict[str, Publisher] = {
+ urlparse(publisher.domain).netloc: publisher for publisher in self.publishers
+ }
+
+ @staticmethod
+ def _extract_content(record: WarcRecord, target_url: str) -> Optional[str]:
+ """Decode the WARC body using the declared charset, falling back to chardet detection."""
+ warc_body: bytes = record.reader.read()
+ try:
+ return str(warc_body, encoding=record.http_charset) # type: ignore[arg-type]
+ except (UnicodeDecodeError, TypeError):
+ encoding: Optional[str] = chardet.detect(warc_body)["encoding"]
+ if encoding is not None:
+ logger.debug(
+ f"Trying to decode record {record.record_id!r} from {target_url!r} "
+ f"using detected encoding {encoding}."
+ )
+ try:
+ return str(warc_body, encoding=encoding)
+ except UnicodeDecodeError:
+ logger.warning(
+ f"Couldn't decode record {record.record_id!r} from {target_url!r} with "
+ f"original charset {record.http_charset!r} using detected charset {encoding!r}."
+ )
+ else:
+ logger.warning(
+ f"Couldn't detect charset for record {record.record_id!r} from {target_url!r} "
+ f"with invalid original charset {record.http_charset!r}."
+ )
+ return None
+
+ def _validate(self, target_url: str, url_filter: Optional[URLFilter]) -> Optional[Publisher]:
+ """Return the matching publisher, or None if the URL should be skipped."""
+ if url_filter is not None and url_filter(target_url):
+ logger.debug(f"Skipped WARC record with target URI {target_url!r} because of URL filter")
+ return None
+ publisher = self._publisher_mapping.get(urlparse(target_url).netloc)
+ if publisher is None:
+ return None
+ if publisher.url_filter is not None and publisher.url_filter(target_url):
+ logger.debug(f"Skipped WARC record with target URI {target_url!r} because of publisher specific URL filter")
+ return None
+ return publisher
+
+ def _record_to_html(self, record: WarcRecord, url_filter: Optional[URLFilter]) -> Optional[HTML]:
+ """Validate, decode, and assemble a single WARC record. Returns None if skipped."""
+ record_date = record.record_date
+ if record_date is None:
+ return None
+ target_url = str(record.headers["WARC-Target-URI"])
+ if (publisher := self._validate(target_url, url_filter)) is None:
+ return None
+ if (content := self._extract_content(record, target_url)) is None:
+ return None
+ return HTML(
+ requested_url=target_url,
+ responded_url=target_url,
+ content=content,
+ crawl_date=record_date,
+ source_info=WarcSourceInfo(
+ publisher=publisher.name,
+ warc_path=self.warc_path,
+ warc_headers=dict(record.headers),
+ http_headers=dict(record.http_headers or {}),
+ ),
+ )
+
+ def _open_stream(self) -> requests.Response:
+ """Open a streaming GET against the WARC archive. Wraps transport errors in WarcFileLoadError."""
+ try:
+ session = requests.Session()
+ response = session.get(self.warc_path, stream=True, headers=self.headers)
+ response.raise_for_status()
+ return response
+ except (requests.HTTPError, urllib3.exceptions.HTTPError) as error:
+ raise WarcFileLoadError(f"{type(error).__name__}: {error}") from error
+
+ @staticmethod
+ def _iter_warc_records(response: requests.Response) -> Iterator[WarcRecord]:
+ """Iterate WARC response records from the open stream. Wraps StreamError in WarcFileLoadError."""
+ try:
+ yield from ArchiveIterator(response.raw, record_types=WarcRecordType.response, verify_digests=True)
+ except StreamError as error:
+ raise WarcFileLoadError(f"{type(error).__name__}: {error}") from error
+
+ def fetch(self, url_filter: Optional[URLFilter] = None) -> Iterator[HTML]:
+ """Stream HTML records from the configured WARC archive.
+
+ Walks every response record in the archive, keeps those whose target URI matches a
+ configured publisher and passes the URL filters, decodes the body, and yields the
+ resulting HTML record.
+
+ Args:
+ url_filter (Optional[URLFilter]): Per-call URL filter applied in addition to each
+ publisher's own url_filter. Truthy means skip the URL.
+
+ Yields:
+ HTML: One record per kept WARC entry.
+
+ Raises:
+ WarcFileLoadError: If the archive cannot be downloaded or the WARC stream is corrupt.
+
+ """
+ response = self._open_stream()
+ for record in self._iter_warc_records(response):
+ if (html := self._record_to_html(record, url_filter)) is not None:
+ yield html
diff --git a/src/fundus/scraping/pipeline/source/web.py b/src/fundus/scraping/pipeline/source/web.py
new file mode 100644
index 000000000..a3712e8a0
--- /dev/null
+++ b/src/fundus/scraping/pipeline/source/web.py
@@ -0,0 +1,298 @@
+from __future__ import annotations
+
+import threading
+import time
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
+from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
+
+from curl_cffi.requests import Response
+from curl_cffi.requests.exceptions import ConnectionError, HTTPError, ReadTimeout
+
+from fundus.logging import create_logger
+from fundus.publishers.base_objects import Publisher, Robots
+from fundus.scraping.delay import Delay
+from fundus.scraping.filter import URLFilter
+from fundus.scraping.html import HTML, SourceInfo
+from fundus.scraping.session import session_handler
+from fundus.scraping.url import URLSource, is_valid_url
+
+logger = create_logger(__name__)
+
+
+@dataclass(frozen=True)
+class WebSourceInfo(SourceInfo):
+ """Origin metadata attached to an HTML record fetched via a URLSource.
+
+ Attributes:
+ type (str): Class name of the URLSource that produced the URL (e.g. "RSSFeed", "Sitemap").
+ url (str): The feed/sitemap URL the article was discovered from.
+ """
+
+ type: str
+ url: str
+
+
+class _Pacer:
+ """Per-source rate limiter. Sleeps as needed so consecutive calls are at least `delay` apart."""
+
+ def __init__(
+ self, delay: Optional[Delay], sleep: Callable[[float], object] = time.sleep, warm_start: bool = True
+ ) -> None:
+ """Build a pacer with the given delay and sleep function. warm_start=True skips sleep on first call."""
+ self.delay = delay
+ self.timestamp = time.time()
+ if warm_start and self.delay is not None:
+ self.timestamp -= self.delay()
+ self.sleep = sleep
+
+ def __call__(self) -> None:
+ """Sleep just long enough to enforce the configured delay since the last call, then reset."""
+ if self.delay is None:
+ return
+ if delay := max(0.0, self.delay() - time.time() + self.timestamp):
+ self.sleep(delay)
+ self.reset()
+
+ def reset(self) -> None:
+ """Mark the current time as the last call; the next call will wait the full delay from now."""
+ self.timestamp = time.time()
+
+
+class WebSource:
+ """HTML source backed by live HTTP requests over a URLSource (RSSFeed/Sitemap/NewsMap) or any iterable of URLs.
+
+ Iterates URLs one at a time, applies URL filters and robots.txt, rate-limits requests via an
+ internal pacer, and fetches each URL through an InterruptableSession. Yields one HTML record per
+ successful fetch. Honors a cooperative stop_event for early cancellation.
+ """
+
+ def __init__(
+ self,
+ url_source: Union[URLSource, Iterable[str]],
+ publisher: Publisher,
+ url_filter: Optional[URLFilter] = None,
+ query_parameters: Optional[Dict[str, str]] = None,
+ delay: Optional[Delay] = None,
+ ignore_robots: bool = False,
+ ignore_crawl_delay: bool = False,
+ impersonate: bool = False,
+ stop_event: Optional[threading.Event] = None,
+ ):
+ """Initialize a source that fetches HTML from URLs produced by a URLSource or any iterable.
+
+ Args:
+ url_source (Union[URLSource, Iterable[str]]): A URLSource (RSSFeed/Sitemap/NewsMap) or
+ any iterable of URL strings. URLSource instances are passed the publisher's session
+ and request headers when iterated.
+ publisher (Publisher): Publisher the URLs belong to. Provides request headers, robots,
+ impersonate profile, and the publisher-level URL filter.
+ url_filter (Optional[URLFilter]): Source-level URL filter, OR-combined with any
+ per-call filter passed to fetch(). Truthy means skip the URL.
+ query_parameters (Optional[Dict[str, str]]): Query parameters appended to every URL
+ before it is requested. Existing query strings are preserved.
+ delay (Optional[Delay]): Per-request delay (seconds). Overridden by robots.txt
+ crawl-delay unless ignore_crawl_delay=True.
+ ignore_robots (bool): If True, skip robots.txt checks (both can_fetch and crawl-delay).
+ ignore_crawl_delay (bool): If True, keep the supplied delay even when robots.txt
+ declares its own crawl-delay.
+ stop_event (Optional[threading.Event]): Cooperative-cancellation flag. When set, any
+ in-flight sleep is interrupted and the source stops iterating URLs.
+
+ """
+ self.url_source = url_source
+ self.publisher = publisher
+ self.url_filter = url_filter
+ self.query_parameters = query_parameters or {}
+ self._impersonate_profile = publisher.impersonate if impersonate else None
+ self.robots: Optional[Robots] = None if ignore_robots else self.publisher.robots
+ self.stop_event = stop_event
+ self._delay = delay
+ self._ignore_crawl_delay = ignore_crawl_delay
+ # Built lazily on the first request (see _build_pacer): resolving the crawl-delay may
+ # read robots.txt, and construction must stay free of I/O.
+ self.pacer: Optional[_Pacer] = None
+
+ # source_info depends only on url_source's type, which is fixed at construction time.
+ self.source_info: SourceInfo = (
+ WebSourceInfo(publisher.name, type(url_source).__name__, url_source.url)
+ if isinstance(url_source, URLSource)
+ else SourceInfo(publisher.name)
+ )
+
+ @staticmethod
+ def _resolve_delay(
+ robots: Optional[Robots],
+ user_agent: str,
+ supplied_delay: Optional[Delay],
+ ignore_crawl_delay: bool,
+ publisher_name: str = "",
+ ) -> Optional[Delay]:
+ """Return the effective per-request delay.
+
+ Robots' crawl_delay (if present) overrides supplied_delay; the override is skipped
+ when robots is None or ignore_crawl_delay is True or robots has no crawl_delay set.
+ """
+ if robots is None or ignore_crawl_delay:
+ return supplied_delay
+ robots_delay = robots.crawl_delay(user_agent)
+ if robots_delay is None:
+ return supplied_delay
+ logger.debug(
+ f"Found crawl-delay of {robots_delay} seconds in robots.txt for {publisher_name}. "
+ f"Overwriting existing delay."
+ )
+ return lambda: robots_delay
+
+ def _build_pacer(self) -> _Pacer:
+ """Resolve the effective delay (may read robots.txt) and build the rate limiter.
+
+ Deferred out of __init__ so construction performs no I/O; called on the first request.
+ """
+ user_agent = self.publisher.request_header.get("user-agent", "*")
+ resolved_delay = self._resolve_delay(
+ self.robots, user_agent, self._delay, self._ignore_crawl_delay, publisher_name=self.publisher.name
+ )
+ # stop_event.wait makes the per-request delay interruptable; time.sleep does not.
+ sleep: Callable[[float], object]
+ if self.stop_event is None:
+ sleep = time.sleep
+ else:
+ sleep = self.stop_event.wait
+ return _Pacer(delay=resolved_delay, sleep=sleep)
+
+ @staticmethod
+ def _apply_query_parameters(url: str, params: Dict[str, str]) -> str:
+ """Append query parameters to a URL, preserving existing ones and URL-encoding values."""
+ if not params:
+ return url
+ parts = urlsplit(url)
+ existing = parse_qsl(parts.query, keep_blank_values=True)
+ new_query = urlencode([*existing, *params.items()])
+ return urlunsplit(parts._replace(query=new_query))
+
+ @property
+ def _is_stopped(self) -> bool:
+ """True if a stop_event was supplied and has been set."""
+ return self.stop_event is not None and self.stop_event.is_set()
+
+ def _pre_validate(self, url: str, url_filter: URLFilter) -> bool:
+ """Return True if the URL is fit to request. Logs the reason and returns False otherwise."""
+ if not is_valid_url(url):
+ logger.debug(f"Skipped requested URL {url!r} because the URL is malformed")
+ return False
+ if url_filter(url):
+ logger.debug(f"Skipped requested URL {url!r} because of URL filter")
+ return False
+ user_agent = self.publisher.request_header.get("user-agent", "*")
+ if self.robots is not None and not self.robots.can_fetch(user_agent, url):
+ logger.debug(f"Skipped requested URL {url!r} because of robots.txt")
+ return False
+ return True
+
+ def _request(self, url: str) -> Optional[Response]:
+ """Sleep on the pacer, then GET the URL. Returns None on request error."""
+ session = session_handler.get_session(self._impersonate_profile)
+ pacer = self.pacer
+ if pacer is None:
+ pacer = self.pacer = self._build_pacer()
+ pacer()
+ try:
+ return session.get_with_interrupt(url, headers=self.publisher.request_header)
+ except (HTTPError, ConnectionError, ReadTimeout) as error:
+ logger.warning(f"Skipped requested URL {url!r} because of {error!r}")
+ return None
+
+ @staticmethod
+ def _post_validate(response: Response, url_filter: URLFilter) -> bool:
+ """Return True if the response should be kept. Logs the reason and returns False otherwise."""
+ if url_filter(str(response.url)):
+ logger.debug(f"Skipped responded URL {str(response.url)!r} because of URL filter")
+ return False
+ return True
+
+ def _build_html(self, requested_url: str, response: Response) -> HTML:
+ """Assemble the HTML record from a successful response."""
+ return HTML(
+ requested_url=requested_url,
+ responded_url=str(response.url),
+ content=response.text,
+ crawl_date=datetime.now(),
+ source_info=self.source_info,
+ )
+
+ def _fetch_one(self, url: str, url_filter: URLFilter) -> Optional[HTML]:
+ """Run the full per-URL pipeline: pre-validate, request, post-validate, build. None if skipped."""
+ if not self._pre_validate(url, url_filter):
+ return None
+ url = self._apply_query_parameters(url, self.query_parameters)
+ response = self._request(url)
+ if response is None:
+ return None
+ if not self._post_validate(response, url_filter):
+ return None
+ return self._build_html(url, response)
+
+ def _iter_urls(self) -> Iterator[str]:
+ """Yield URLs from the configured source, swallowing iterator crashes with a warning."""
+ if isinstance(self.url_source, URLSource):
+ source_iter: Iterator[str] = self.url_source.fetch(
+ session_handler.get_session(self._impersonate_profile),
+ self.publisher.request_header,
+ )
+ else:
+ source_iter = iter(self.url_source)
+ while True:
+ try:
+ url = next(source_iter, None)
+ except Exception as error:
+ logger.error(
+ f"Warning! URLSource {self.url_source!r} crashed because of an unexpected error: {error!r}"
+ )
+ return
+ if url is None:
+ return
+ yield url
+
+ def _build_url_filter(self, url_filter: Optional[URLFilter]) -> URLFilter:
+ """Combine source-level and per-call URL filters with logical OR. Returns a pass-through if both are None."""
+ combined: List[URLFilter] = ([self.url_filter] if self.url_filter else []) + (
+ [url_filter] if url_filter else []
+ )
+
+ def combined_url_filter(url: str) -> bool:
+ return any(f(url) for f in combined)
+
+ return combined_url_filter
+
+ def fetch(self, url_filter: Optional[URLFilter] = None) -> Iterator[HTML]:
+ """Stream HTML records by iterating url_source and fetching each URL.
+
+ Each URL is gated by the combined URL filter (source-level OR per-call), robots.txt, and
+ rate-limited by the configured delay. Per-URL request errors (HTTP / connection / timeout)
+ are logged and skipped; the iteration continues. If stop_event is set, iteration short-
+ circuits at the next boundary.
+
+ Args:
+ url_filter (Optional[URLFilter]): Per-call URL filter, OR-combined with the source's
+ own url_filter. Truthy means skip the URL.
+
+ Yields:
+ HTML: One record per successfully fetched URL.
+
+ """
+ combined_filter = self._build_url_filter(url_filter)
+ url_iterator = self._iter_urls()
+ # Check the stop event BEFORE advancing the iterator: pulling the next URL from a
+ # URLSource triggers its feed/sitemap download, so a stopped source must return without
+ # ever touching its iterator — otherwise every remaining source is fetched after stop.
+ while not self._is_stopped:
+ url = next(url_iterator, None)
+ if url is None:
+ return
+ try:
+ if html := self._fetch_one(url, combined_filter):
+ yield html
+ except Exception as error:
+ logger.error(f"Warning! Skipped requested URL {url!r} because of an unexpected error {error!r}")
diff --git a/src/fundus/scraping/scraper.py b/src/fundus/scraping/scraper.py
deleted file mode 100644
index b11a0d2b4..000000000
--- a/src/fundus/scraping/scraper.py
+++ /dev/null
@@ -1,114 +0,0 @@
-from typing import Dict, Iterator, List, Literal, Optional, Type
-
-import more_itertools
-
-from fundus.logging import create_logger
-from fundus.parser import ParserProxy
-from fundus.publishers.base_objects import Publisher
-from fundus.scraping.article import Article
-from fundus.scraping.delay import Delay
-from fundus.scraping.filter import (
- ExtractionFilter,
- FilterResultWithMissingAttributes,
- URLFilter,
-)
-from fundus.scraping.html import CCNewsSource, HTMLSource, WebSource
-from fundus.scraping.url import URLSource
-
-logger = create_logger(__name__)
-
-
-class BaseScraper:
- def __init__(self, *sources: HTMLSource, parser_mapping: Dict[str, ParserProxy]):
- self.sources = sources
- self.parser_mapping = parser_mapping
-
- def scrape(
- self,
- error_handling: Literal["suppress", "catch", "raise"],
- extraction_filter: Optional[ExtractionFilter] = None,
- url_filter: Optional[URLFilter] = None,
- language_filter: Optional[List[str]] = None,
- ) -> Iterator[Article]:
- for source in self.sources:
- for html in source.fetch(url_filter=url_filter):
- parser = self.parser_mapping[html.source_info.publisher]
-
- try:
- extraction = parser(html.crawl_date).parse(html.content, error_handling)
-
- except Exception as error:
- if error_handling == "raise":
- error_message = f"Run into an error processing article {html.requested_url!r}"
- logger.error(error_message)
- error.args = (str(error) + "\n\n" + error_message,)
- raise error
- elif error_handling == "catch":
- yield Article(html=html, exception=error)
- elif error_handling == "suppress":
- logger.info(f"Skipped article at {html.requested_url!r} because of: {error!r}")
- else:
- raise ValueError(f"Unknown value {error_handling!r} for parameter '")
-
- else:
- if extraction_filter and (filter_result := extraction_filter(extraction)):
- if isinstance(filter_result, FilterResultWithMissingAttributes):
- logger.debug(
- f"Skipped article at {html.requested_url!r} because attribute(s) "
- f"{', '.join(filter_result.missing_attributes)!r} is(are) missing"
- )
- else:
- logger.debug(f"Skipped article at {html.requested_url!r} because of extraction filter")
- else:
- article = Article(html=html, **extraction)
- if language_filter and article.lang not in language_filter:
- logger.debug(
- f"Skipped article at {html.requested_url!r} because article language: "
- f"{article.lang!r} is not in allowed languages: {language_filter!r}"
- )
- else:
- yield article
-
-
-class WebScraper(BaseScraper):
- def __init__(
- self,
- publisher: Publisher,
- restrict_sources_to: Optional[List[Type[URLSource]]] = None,
- delay: Optional[Delay] = None,
- ignore_robots: bool = False,
- ignore_crawl_delay: bool = False,
- impersonate: bool = False,
- ):
- if restrict_sources_to:
- url_sources = tuple(
- more_itertools.flatten(
- publisher.source_mapping[source_type]
- for source_type in restrict_sources_to
- if source_type in publisher.source_mapping
- )
- )
- else:
- url_sources = tuple(more_itertools.flatten(publisher.source_mapping.values()))
-
- html_sources = [
- WebSource(
- url_source=url_source,
- publisher=publisher,
- delay=delay,
- url_filter=publisher.url_filter,
- query_parameters=publisher.query_parameter,
- ignore_robots=ignore_robots,
- ignore_crawl_delay=ignore_crawl_delay,
- impersonate=impersonate,
- )
- for url_source in url_sources
- ]
- parser_mapping: Dict[str, ParserProxy] = {publisher.name: publisher.parser}
- super().__init__(*html_sources, parser_mapping=parser_mapping)
-
-
-class CCNewsScraper(BaseScraper):
- def __init__(self, source: CCNewsSource):
- parser_mapping: Dict[str, ParserProxy] = {publisher.name: publisher.parser for publisher in source.publishers}
- super().__init__(source, parser_mapping=parser_mapping)
diff --git a/src/fundus/scraping/session.py b/src/fundus/scraping/session.py
index 444ae7071..dbe4d96b1 100644
--- a/src/fundus/scraping/session.py
+++ b/src/fundus/scraping/session.py
@@ -1,8 +1,12 @@
from __future__ import annotations
+import random
import re
import threading
+import time
from contextlib import contextmanager
+from datetime import datetime, timezone
+from email.utils import parsedate_to_datetime
from queue import Empty, Queue
from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Union
from urllib.parse import urljoin
@@ -23,12 +27,14 @@
class CrashThread(BaseException):
- """Is raised to end a thread without relying on the thread ending naturally"""
+ """Raised to terminate a thread without waiting for it to exit naturally."""
pass
class _RequestTask(NamedTuple):
+ """A unit of work handed to the session worker thread: the URL, request kwargs, and a queue for the result."""
+
url: str
kwargs: Any
result_queue: Queue[Union[curl_cffi.requests.Response, Exception]]
@@ -69,14 +75,34 @@ class InterruptableSession(curl_cffi.requests.Session[curl_cffi.requests.Respons
The daemon thread owns the curl handle for the lifetime of the session, enabling
connection reuse across requests. get_with_interrupt() submits work to the daemon
thread and polls for a stop event every second, raising CrashThread if interrupted.
+ 5xx responses are retried in place with interruptable exponential backoff (see
+ get_with_interrupt); an exhausted retry surfaces as a normal HTTPError.
"""
- def __init__(self, **kwargs: Any) -> None:
- # use_thread_local_curl=True gives the worker thread its own curl handle, separate
- # from the caller thread's handle closed in close(). Prevents close() from touching
- # a handle that may still be in use by the worker.
+ def __init__(
+ self,
+ *,
+ max_retries: int = 3,
+ retry_backoff_base: float = 1.0,
+ retry_backoff_cap: float = 30.0,
+ **kwargs: Any,
+ ) -> None:
+ """Start the persistent worker thread; forwards kwargs to curl_cffi.Session.
+
+ use_thread_local_curl is forced on so the worker thread gets its own curl handle,
+ separate from the caller thread's handle that close() tears down; otherwise close()
+ could touch a handle still in use by the worker.
+
+ Args:
+ max_retries (int): Number of additional attempts for 5xx responses (0 disables retrying).
+ retry_backoff_base (float): Base for the full-jitter exponential backoff between retries (seconds).
+ retry_backoff_cap (float): Upper bound on a single backoff wait, including Retry-After (seconds).
+ """
kwargs.pop("use_thread_local_curl", None)
super().__init__(use_thread_local_curl=True, **kwargs)
+ self.max_retries = max_retries
+ self.retry_backoff_base = retry_backoff_base
+ self.retry_backoff_cap = retry_backoff_cap
self._closed = False
self._task_queue: Queue[Optional[_RequestTask]] = Queue()
self._worker_thread = threading.Thread(target=self._worker_loop, name=f"session-worker-{id(self)}", daemon=True)
@@ -84,6 +110,7 @@ def __init__(self, **kwargs: Any) -> None:
@staticmethod
def _log_response(response: curl_cffi.requests.Response) -> None:
+ """Debug-log the request method, any redirect chain, the final status, and elapsed time."""
history: List[curl_cffi.requests.Response] = object.__getattribute__(response, "_history")
method = getattr(getattr(response, "request", None), "method", "GET")
if history:
@@ -117,6 +144,7 @@ def _follow_redirects(self, url: str, **kwargs: Any) -> curl_cffi.requests.Respo
raise TooManyRedirects(f"Exceeded {self.max_redirects} maximum redirects following {url!r}")
def _worker_loop(self) -> None:
+ """Pull tasks off the queue and run each request, returning the response or the raised error; exit on the None sentinel."""
while True:
task = self._task_queue.get()
if task is None:
@@ -126,17 +154,44 @@ def _worker_loop(self) -> None:
except Exception as error:
task.result_queue.put(error)
- def get_with_interrupt(self, url: str, **kwargs: Any) -> curl_cffi.requests.Response:
- """Interruptable GET request.
+ @staticmethod
+ def _parse_retry_after(value: str) -> Optional[float]:
+ """Parse a Retry-After header (delta-seconds or HTTP-date) into seconds from now, or None if unparseable."""
+ value = value.strip()
+ if value.isdigit():
+ return float(value)
+ try:
+ # Raises TypeError (py<3.10) or ValueError (py>=3.10) on unparseable input.
+ parsed = parsedate_to_datetime(value)
+ except (TypeError, ValueError):
+ return None
+ if parsed.tzinfo is None:
+ parsed = parsed.replace(tzinfo=timezone.utc)
+ return max(0.0, (parsed - datetime.now(timezone.utc)).total_seconds())
+
+ def _retry_backoff(self, response: curl_cffi.requests.Response, attempt: int) -> float:
+ """Seconds to wait before retrying: a valid Retry-After (capped) if present, else full-jitter exponential backoff."""
+ retry_after = response.headers.get("retry-after")
+ if retry_after is not None and (parsed := self._parse_retry_after(retry_after)) is not None:
+ return min(parsed, self.retry_backoff_cap)
+ window = min(self.retry_backoff_cap, self.retry_backoff_base * 2**attempt)
+ return random.uniform(0.0, window)
- Submits the request to the persistent daemon thread and polls every second
- for a stop event. Raises CrashThread if interrupted. When impersonating a
- browser, kwargs are dropped so curl_cffi can apply the full browser
- fingerprint unmodified.
+ @staticmethod
+ def _sleep_with_interrupt(seconds: float, url: str) -> None:
+ """Sleep up to `seconds`, waking every second to honor the stop event (raises CrashThread if set)."""
+ deadline = time.monotonic() + seconds
+ while (remaining := deadline - time.monotonic()) > 0:
+ if __EVENTS__.is_event_set("stop"):
+ logger.debug(f"Interrupt backoff before retrying {url!r}")
+ raise CrashThread(f"Backoff before retrying {url} was interrupted by stop event")
+ time.sleep(min(1.0, remaining))
+
+ def _submit_and_wait(self, url: str, request_kwargs: Dict[str, Any]) -> curl_cffi.requests.Response:
+ """Submit one request to the worker thread and block until a result, polling the stop event each second.
+
+ Raises any exception the worker raised, or CrashThread if the stop event fires while waiting.
"""
- if self._closed:
- raise RuntimeError("Session is closed")
- request_kwargs: Dict[str, Any] = {} if self.impersonate else kwargs
response_queue: Queue[Union[curl_cffi.requests.Response, Exception]] = Queue()
self._task_queue.put(_RequestTask(url, request_kwargs, response_queue))
@@ -150,10 +205,44 @@ def get_with_interrupt(self, url: str, **kwargs: Any) -> curl_cffi.requests.Resp
else:
if isinstance(response, Exception):
raise response
- self._log_response(response)
- response.raise_for_status()
return response
+ def get_with_interrupt(self, url: str, **kwargs: Any) -> curl_cffi.requests.Response:
+ """Interruptable GET request with in-place 5xx retry.
+
+ Submits the request to the persistent daemon thread and polls every second
+ for a stop event. Raises CrashThread if interrupted. When impersonating a
+ browser, kwargs are dropped so curl_cffi can apply the full browser
+ fingerprint unmodified.
+
+ A 5xx response is retried up to max_retries times with interruptable
+ exponential backoff (honoring Retry-After); once retries are exhausted the
+ status surfaces as a normal HTTPError via raise_for_status.
+ """
+ if self._closed:
+ raise RuntimeError("Session is closed")
+ request_kwargs: Dict[str, Any] = {} if self.impersonate else kwargs
+
+ # Hand-rolled rather than curl_cffi's retry=/RetryStrategy: that only retries transport
+ # exceptions (not 5xx), ignores Retry-After, and sleeps with a blocking time.sleep that the
+ # stop event can't interrupt.
+ for attempt in range(self.max_retries + 1):
+ response = self._submit_and_wait(url, request_kwargs)
+ self._log_response(response)
+ if response.status_code >= 500 and attempt < self.max_retries:
+ backoff = self._retry_backoff(response, attempt)
+ logger.debug(
+ f"Server error {response.status_code} for {url!r}; "
+ f"retry {attempt + 1}/{self.max_retries} in {backoff:.2f}s"
+ )
+ self._sleep_with_interrupt(backoff, url)
+ continue
+ response.raise_for_status()
+ return response
+
+ # Unreachable: the loop either returns or raises on its final iteration.
+ raise AssertionError("retry loop exited without returning")
+
def close(self) -> None:
"""Signal the worker thread to exit and close this thread's curl handle.
@@ -174,9 +263,15 @@ class SessionHandler:
session, the old session is closed and replaced.
"""
- DEFAULT_SESSION_KWARGS: Dict[str, Any] = {"timeout": 30}
+ DEFAULT_SESSION_KWARGS: Dict[str, Any] = {
+ "timeout": 30,
+ "max_retries": 3,
+ "retry_backoff_base": 1.0,
+ "retry_backoff_cap": 30.0,
+ }
def __init__(self) -> None:
+ """Initialize the per-thread session registry with the default session kwargs."""
self._session_kwargs: Dict[str, Any] = dict(self.DEFAULT_SESSION_KWARGS)
self._context_lock = threading.RLock()
self._sessions: Dict[int, InterruptableSession] = {}
diff --git a/src/fundus/scraping/url.py b/src/fundus/scraping/url.py
index f38fdb498..cf3b79f66 100644
--- a/src/fundus/scraping/url.py
+++ b/src/fundus/scraping/url.py
@@ -4,24 +4,25 @@
import lzma
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
-from functools import cached_property, partial
from typing import (
+ Any,
Callable,
ClassVar,
Dict,
Iterable,
Iterator,
List,
+ NamedTuple,
Optional,
Pattern,
Set,
+ Tuple,
)
-from urllib.parse import unquote, urlparse
+from urllib.parse import unquote, urljoin, urlparse
import feedparser
-import lxml.html
from curl_cffi.requests.exceptions import ConnectionError, HTTPError, ReadTimeout
-from lxml.etree import XMLParser, XPath
+from lxml.etree import XMLParser, XPath, fromstring
from fundus.logging import create_logger
from fundus.scraping.filter import URLFilter, inverse
@@ -30,89 +31,43 @@
logger = create_logger(__name__)
-class CompressionFormat:
- def __init__(
- self, name: str, decompression: Optional[Callable[[bytes], bytes]] = None, *, byte_mask: Optional[bytes] = None
- ) -> None:
- self.name = name
- self.decompression = decompression
- self.byte_mask = byte_mask
-
- def match(self, compressed_content: bytes) -> bool:
- if self.byte_mask:
- return compressed_content.startswith(self.byte_mask)
- return False
-
- def __call__(self, compressed_content: bytes) -> bytes:
- if self.decompression is None:
- raise NotImplementedError(f"Decompression not implemented for {self.name!r}")
- return self.decompression(compressed_content)
-
- def __repr__(self):
- if self.decompression is None:
- return f"{self.name} -- Not implemented"
- return self.name
-
-
-class CompressionFormats:
- GZIP = CompressionFormat("gzip", gzip.decompress, byte_mask=b"\x1f\x8b")
- BZ2 = CompressionFormat("bz2", bz2.decompress, byte_mask=b"\x42\x5a")
- ZIP = CompressionFormat("zip", byte_mask=b"PK\x03\x04")
- LZMA = CompressionFormat("lzma", lzma.decompress, byte_mask=b"\x28\xb5\x2f\xfd")
-
- @classmethod
- def iter_formats(cls) -> Iterator[CompressionFormat]:
- for obj in cls.__dict__.values():
- if isinstance(obj, CompressionFormat):
- yield obj
-
- @classmethod
- def identify(cls, compressed_content: bytes) -> Optional[CompressionFormat]:
- for compression_format in cls.iter_formats():
- if compression_format.match(compressed_content):
- return compression_format
- return None
-
-
-class _ArchiveDecompressor:
- def __init__(self):
- self.archive_mapping: Dict[str, Callable[[bytes], bytes]] = {
- "application/octet-stream": self._decompress_octet_stream,
- "application/x-gzip": CompressionFormats.GZIP,
- "gzip": CompressionFormats.GZIP,
- }
-
- def _decompress_octet_stream(self, compressed_content: bytes) -> bytes:
- if (compression_format := CompressionFormats.identify(compressed_content)) is None:
- logger.debug("Could not identify compression format")
- raise NotImplementedError
-
- return compression_format(compressed_content)
-
- def decompress(self, content: bytes, file_format: "str") -> bytes:
- decompress_function = self.archive_mapping[file_format]
- return decompress_function(content)
-
- @cached_property
- def supported_file_formats(self) -> List[str]:
- return list(self.archive_mapping.keys())
-
-
def is_valid_url(url: str) -> bool:
+ """True if the URL has an http/https scheme and a non-empty network location."""
parsed = urlparse(url)
return bool(parsed.scheme in ("http", "https") and parsed.netloc)
-def clean_url(url: str) -> str:
- return unquote(url)
+def strip_query_and_fragment(url: str) -> str:
+ """Return the URL with its query string and fragment removed.
+
+ Intended for *identity* use (dedup keys, equality probes), not for fetching:
+ the result may resolve to a different resource than the input on servers that
+ rely on query parameters for routing.
+ """
+ if any(indicator in url for indicator in ("?", "#")):
+ return urljoin(url, urlparse(url).path)
+ return url
@dataclass
class URLSource(Iterable[str], ABC):
+ """Abstract source of article URLs for a single feed/sitemap endpoint.
+
+ Concrete subclasses (RSSFeed, Sitemap, NewsMap) implement fetch() to stream URLs from
+ the endpoint at . Iterating the source directly (__iter__) uses a default session
+ and headers for standalone use; production scraping calls fetch() through WebSource with
+ publisher-specific session and headers.
+
+ Attributes:
+ url (str): The feed/sitemap URL to pull article URLs from.
+ languages (Set[str]): Language codes the source is known to serve, if any.
+ """
+
url: str
languages: Set[str] = field(default_factory=set)
def __post_init__(self):
+ """Warn (but don't fail) if the configured URL is malformed."""
if not is_valid_url(self.url):
logger.error(f"{type(self).__name__} initialized with invalid URL {self.url}")
@@ -154,6 +109,8 @@ def get_urls(self, max_urls: Optional[int] = None) -> Iterator[str]:
@dataclass
class RSSFeed(URLSource):
+ """URLSource that yields article links from an RSS/Atom feed."""
+
def fetch(self, session: InterruptableSession, headers: Dict[str, str]) -> Iterator[str]:
try:
response = session.get_with_interrupt(self.url, headers=headers)
@@ -173,79 +130,158 @@ def fetch(self, session: InterruptableSession, headers: Dict[str, str]) -> Itera
else:
urls = filter(bool, (entry.get("link") for entry in rss_feed["entries"]))
for url in urls:
- yield clean_url(url)
+ # Some publishers emit URLs with percent-encoded path separators
+ # (e.g. `https://example.com%2Farticle.html`); see PR #753.
+ yield unquote(url)
+
+
+class _Codec(NamedTuple):
+ """A supported compression format: its name, leading magic bytes, and decompress function."""
+
+ name: str
+ magic: bytes
+ decompress: Callable[[bytes], bytes]
+
+
+# Identified by magic-byte sniff rather than headers: the formats we support all carry
+# unambiguous signatures, and sniffing handles misadvertised or header-less payloads alike.
+_CODECS: Tuple[_Codec, ...] = (
+ _Codec("gzip", b"\x1f\x8b", gzip.decompress),
+ _Codec("bzip2", b"BZh", bz2.decompress),
+ _Codec("xz", b"\xfd7zXZ\x00", lzma.decompress),
+)
+
+
+def decompress(content: bytes) -> bytes:
+ """Decompress content if its leading bytes match a known codec, else return unchanged."""
+ for codec in _CODECS:
+ if content.startswith(codec.magic):
+ return codec.decompress(content)
+ return content
+
+
+def _default_sitemap_filter(url: str) -> bool:
+ """Default sitemap_filter: drop empty/falsy entries, keep everything else."""
+ return not bool(url)
@dataclass
class Sitemap(URLSource):
+ """URLSource that yields article links from an XML sitemap, descending into sitemap indexes.
+
+ Attributes:
+ recursive (bool): If True, follow nested references in a sitemap index.
+ Defaults to True.
+ reverse (bool): If True, yield URLs (and descend into sub-sitemaps) in reverse order.
+ Defaults to False.
+ sitemap_filter (URLFilter): Filter applied to sub-sitemap values; a truthy result
+ drops the entry. Defaults to dropping only empty values.
+ sort_predicate (Optional[Pattern[str]]): If set, sub-sitemap URLs are sorted (descending)
+ by the matched substring of this pattern; the pattern must match every URL.
+ """
+
recursive: bool = True
reverse: bool = False
- sitemap_filter: URLFilter = lambda url: not bool(url)
+ sitemap_filter: URLFilter = _default_sitemap_filter
sort_predicate: Optional[Pattern[str]] = None
- _decompressor: ClassVar[_ArchiveDecompressor] = _ArchiveDecompressor()
_sitemap_selector: ClassVar[XPath] = XPath("//*[local-name()='sitemap']/*[local-name()='loc']")
_url_selector: ClassVar[XPath] = XPath("//*[local-name()='url']/*[local-name()='loc']")
- _parser = XMLParser(strip_cdata=False, recover=True)
+
+ @staticmethod
+ def _fetch_bytes(
+ sitemap_url: str,
+ session: InterruptableSession,
+ headers: Dict[str, str],
+ ) -> Optional[bytes]:
+ """Fetch sitemap bytes, decompressing if needed. Returns None on any failure.
+
+ Handles HTTP errors, decompression failures, and empty bodies. Each failure
+ mode is logged at its point of occurrence; callers just check for None.
+ """
+ if not is_valid_url(sitemap_url):
+ logger.info(f"Skipped sitemap {sitemap_url!r} because the URL is malformed")
+ return None
+ try:
+ response = session.get_with_interrupt(url=sitemap_url, headers=headers)
+ except (HTTPError, ConnectionError, ReadTimeout) as error:
+ logger.warning(f"Warning! Couldn't reach sitemap {sitemap_url!r} because of {error!r}")
+ return None
+ except Exception as error:
+ logger.error(f"Warning! Couldn't reach sitemap {sitemap_url!r} because of an unexpected error {error!r}")
+ return None
+
+ content = response.content.strip()
+ try:
+ content = decompress(content)
+ except Exception as error:
+ logger.warning(f"Decompression failed for {sitemap_url!r}: {error!r}")
+ return None
+ if not content:
+ logger.warning(f"Warning! Empty sitemap at {sitemap_url!r}")
+ return None
+ return content
+
+ def _ordered_sub_locs(self, tree: Any) -> List[str]:
+ """Extract sub-sitemap values, sorted by sort_predicate and filtered."""
+ locs = [node.text for node in self._sitemap_selector(tree)]
+
+ if self.sort_predicate is not None:
+ pattern = self.sort_predicate
+
+ def key(text: str) -> str:
+ if match := pattern.search(text):
+ return match.group()
+ raise NotImplementedError(" must match in all sitemap URLs")
+
+ locs = sorted(locs, key=key, reverse=True)
+
+ return list(filter(inverse(self.sitemap_filter), locs))
+
+ def _yield_from_sitemap(
+ self,
+ sitemap_url: str,
+ session: InterruptableSession,
+ headers: Dict[str, str],
+ parser: XMLParser,
+ ) -> Iterator[str]:
+ # Download (and decompress) the sitemap bytes.
+ content = self._fetch_bytes(sitemap_url, session, headers)
+ if content is None:
+ return
+
+ # Parse the bytes into an XML tree.
+ tree = fromstring(content, parser=parser)
+ if tree is None:
+ logger.warning(f"Warning! Couldn't parse sitemap {sitemap_url!r}") # type: ignore[unreachable]
+ return
+
+ # Yield the article URLs contained in this sitemap, if any.
+ urls = [node.text for node in self._url_selector(tree)]
+ if urls:
+ for new_url in reversed(urls) if self.reverse else urls:
+ yield unquote(new_url)
+ return
+
+ # Otherwise descend into nested sitemap-index references.
+ if not self.recursive:
+ return
+ locs = self._ordered_sub_locs(tree)
+ for loc in reversed(locs) if self.reverse else locs:
+ yield from self._yield_from_sitemap(loc, session, headers, parser)
def fetch(self, session: InterruptableSession, headers: Dict[str, str]) -> Iterator[str]:
- def yield_recursive(sitemap_url: str) -> Iterator[str]:
- if not is_valid_url(sitemap_url):
- logger.info(f"Skipped sitemap {sitemap_url!r} because the URL is malformed")
- try:
- response = session.get_with_interrupt(url=sitemap_url, headers=headers)
-
- except (HTTPError, ConnectionError, ReadTimeout) as error:
- logger.warning(f"Warning! Couldn't reach sitemap {sitemap_url!r} because of {error!r}")
- return
- except Exception as error:
- logger.error(
- f"Warning! Couldn't reach sitemap {sitemap_url!r} because of an unexpected error {error!r}"
- )
- return
-
- content = response.content.strip()
- if (content_type := response.headers.get("content-type")) in self._decompressor.supported_file_formats:
- try:
- content = self._decompressor.decompress(content, content_type)
- except NotImplementedError:
- logger.warning(f"No matching decompression found for {sitemap_url!r}")
- return
- if not content:
- logger.warning(f"Warning! Empty sitemap at {sitemap_url!r}")
- return
- tree = lxml.etree.fromstring(content, parser=self._parser)
- if tree is None:
- # in case we somehow end up with non xml content
- logger.warning(f"Warning! Couldn't parse sitemap {sitemap_url!r}") # type: ignore[unreachable]
- return
- urls = [node.text for node in self._url_selector(tree)]
- if urls:
- for new_url in reversed(urls) if self.reverse else urls:
- yield clean_url(new_url)
- elif self.recursive:
- sitemap_locs = [node.text for node in self._sitemap_selector(tree)]
-
- if self.sort_predicate is not None:
-
- def _extract_predicate(text: str, pattern: Pattern[str]) -> str:
- if match := pattern.search(text):
- return match.group()
- raise NotImplementedError(" must match in all sitemap URLs")
-
- sitemap_locs = sorted(
- sitemap_locs,
- key=partial(_extract_predicate, pattern=self.sort_predicate),
- reverse=True,
- )
-
- filtered_locs = list(filter(inverse(self.sitemap_filter), sitemap_locs))
- for loc in reversed(filtered_locs) if self.reverse else filtered_locs:
- yield from yield_recursive(loc)
-
- yield from yield_recursive(self.url)
+ # lxml parsers serialize access across threads; construct one per fetch() so
+ # concurrent sitemap fetches don't contend. Each fetch() generator is consumed
+ # by a single thread, so the parser stays single-threaded for its lifetime.
+ parser = XMLParser(strip_cdata=False, recover=True)
+ yield from self._yield_from_sitemap(self.url, session, headers, parser)
@dataclass
class NewsMap(Sitemap):
- pass
+ """Marker subclass for Google-News-style sitemaps (recent articles only).
+
+ Parsing is identical to Sitemap; the distinct type lets the scraper prioritize
+ news sitemaps over full archive sitemaps via __SOURCE_ORDER__ in base_objects.py.
+ """
diff --git a/src/fundus/utils/concurrency.py b/src/fundus/utils/concurrency.py
new file mode 100644
index 000000000..30d46ab95
--- /dev/null
+++ b/src/fundus/utils/concurrency.py
@@ -0,0 +1,88 @@
+import contextlib
+import multiprocessing
+from functools import lru_cache
+from multiprocessing.managers import BaseManager
+from threading import current_thread
+from typing import Callable, Generic, Iterator, Optional, Tuple, TypeVar, cast
+
+import dill
+from tqdm import tqdm
+from typing_extensions import ParamSpec
+
+_T = TypeVar("_T")
+_P = ParamSpec("_P")
+
+
+def get_execution_context() -> Tuple[str, Optional[int]]:
+ """Return the name and identifier of the current execution context.
+
+ If running inside a non-main process, returns that process's name and PID; otherwise
+ returns the current thread's name and thread id.
+
+ Returns:
+ Tuple[str, Optional[int]]: The context's name and its integer identifier.
+ """
+ if multiprocessing.current_process().name != "MainProcess":
+ process = multiprocessing.current_process()
+ return process.name, process.ident
+ else:
+ thread = current_thread()
+ return thread.name, thread.ident
+
+
+class TQDMManager(BaseManager):
+ """multiprocessing manager exposing a shared tqdm proxy so worker processes drive one progress bar."""
+
+ def __init__(self, *args, **kwargs):
+ """Initialize the manager and register tqdm so it can be created behind a proxy."""
+ super().__init__(*args, **kwargs)
+ self.register("_tqdm", tqdm)
+
+ def tqdm(self, *args, **kwargs) -> tqdm:
+ """Create and return a manager-hosted (proxied) tqdm instance from the given tqdm args."""
+ return getattr(self, "_tqdm")(*args, **kwargs)
+
+
+@contextlib.contextmanager
+def get_proxy_tqdm(*args, **kwargs) -> Iterator[tqdm]:
+ """Yield a manager-backed tqdm proxy that can be shared across processes.
+
+ Init args are forwarded verbatim and are the same as for any other tqdm instance. The
+ backing manager is started on entry and shut down on exit.
+
+ Args:
+ *args: Positional tqdm arguments.
+ **kwargs: Keyword tqdm arguments.
+
+ Yields:
+ tqdm: A self-managed, proxied tqdm instance.
+ """
+ manager = TQDMManager()
+ try:
+ manager.start()
+ yield manager.tqdm(*args, **kwargs)
+ finally:
+ manager.shutdown()
+
+
+class dill_wrapper(Generic[_P, _T]):
+ """Callable wrapper that dill-serializes its target so it survives multiprocessing pickling."""
+
+ def __init__(self, target: Callable[_P, _T]):
+ """Wraps function in dill serialization.
+
+ This is in order to use unpickable functions within multiprocessing.
+
+ Args:
+ target: The function to wrap.
+ """
+ self._serialized_target: bytes = dill.dumps(target)
+
+ @lru_cache
+ def _deserialize(self) -> Callable[_P, _T]:
+ """Deserialize and cache the wrapped target on first use (once per process)."""
+ return cast(Callable[_P, _T], dill.loads(self._serialized_target))
+
+ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
+ """Deserialize the target (cached) and invoke it with the given arguments."""
+ return self._deserialize()(*args, **kwargs)
diff --git a/src/fundus/utils/events.py b/src/fundus/utils/events.py
index a9b40eb27..27d86df8f 100644
--- a/src/fundus/utils/events.py
+++ b/src/fundus/utils/events.py
@@ -7,12 +7,60 @@
from fundus.logging import create_logger
+# TODO (planned redesign): replace __EVENTS__ with explicit CancellationToken objects.
+#
+# Current state. __EVENTS__ is a global registry that maps a string alias (publisher
+# name, "main-thread") to a dict of named threading.Event objects, plus a bidict
+# linking aliases to thread ids so callers running inside a thread context can resolve
+# `key=None` to "their own" events. It mashes three concerns into one mechanism:
+# 1. cooperative cancellation (per-publisher stop signal)
+# 2. shutdown propagation (system-wide stop via set_for_all(future=True))
+# 3. post-mortem queryability (main thread asking "did publisher X already stop?"
+# after its worker exited, hence aliases-persist-after-thread-exit)
+#
+# Pain points: implicit thread-id resolution, string-keyed events (only "stop" exists
+# in practice), the `future=True` hack for shutdown, leaky test setup (every test that
+# touches WebSource/CCNewsSource needs __EVENTS__.context("test") aliasing), and an
+# unclear seam for multiprocessing (threading.Event does not cross process boundaries).
+#
+# Planned shape:
+#
+# class CancellationToken:
+# def __init__(self) -> None:
+# self._event = threading.Event()
+# self._children: list[CancellationToken] = []
+# def cancel(self) -> None: ...
+# def is_cancelled(self) -> bool: ...
+# def wait(self, timeout: float) -> bool: ...
+# def child(self) -> "CancellationToken": ... # cancelled when parent is
+#
+# Mapping current usage onto tokens:
+# - Source classes (WebSource, CCNewsSource): receive a CancellationToken via
+# constructor instead of reading from __EVENTS__.
+# - Crawler: holds a `dict[Publisher, CancellationToken]`. On per-publisher limit
+# reached, calls `tokens[publisher].cancel()`. Replaces __EVENTS__.set_event(
+# "stop", publisher_name) and __EVENTS__.is_event_set(...) at the same site.
+# - Shutdown: a root token; each publisher's token is `root.child()`. Cancelling
+# root cancels all children — replaces set_for_all(future=True) / clear_for_all.
+# - queueing.enqueue_results: takes the shutdown token, replaces the
+# __EVENTS__.is_event_set("stop", __MAIN_THREAD_ALIAS__) probe in _delivered.
+# - Session.get_with_interrupt: takes a CancellationToken, polls
+# token.is_cancelled() instead of __EVENTS__.
+#
+# What disappears: aliases, the thread-id bidict, main_context_lock, string event
+# keys, default_events, future=True / clear_for_all, the test-context fixture, and
+# every `from fundus.utils.events import __EVENTS__` in non-orchestrator code.
+
_T = TypeVar("_T")
logger = create_logger(__name__)
__DEFAULT_EVENTS__: List[str] = ["stop"]
+# Alias under which the main thread registers its context in __EVENTS__; crawlers set/probe
+# the "stop" event against it to drive cooperative shutdown.
+__MAIN_THREAD_ALIAS__ = "main-thread"
+
_sentinel = object()
diff --git a/src/fundus/utils/serialization.py b/src/fundus/utils/serialization.py
index 0b15da0a4..9d3286976 100644
--- a/src/fundus/utils/serialization.py
+++ b/src/fundus/utils/serialization.py
@@ -1,10 +1,13 @@
import inspect
import json
from dataclasses import asdict, fields, is_dataclass
+from datetime import datetime
from typing import (
Any,
Callable,
Dict,
+ Optional,
+ Protocol,
Sequence,
Type,
TypeVar,
@@ -12,6 +15,7 @@
get_args,
get_origin,
get_type_hints,
+ runtime_checkable,
)
from typing_extensions import TypeAlias
@@ -21,6 +25,43 @@
JSONVal: TypeAlias = Union[None, bool, str, float, int, Sequence["JSONVal"], Dict[str, "JSONVal"]]
+@runtime_checkable
+class Serializable(Protocol):
+ """Anything that knows how to convert itself into a JSON-compatible structure.
+
+ Implementing types opt into the export path used by Article.to_json.
+ """
+
+ def serialize(self) -> JSONVal: ...
+
+
+def serialize_value(value: Any, field_name: Optional[str] = None) -> JSONVal:
+ """Recursively convert a value to JSON-compatible form.
+
+ Args:
+ value: The value to serialize.
+ field_name: Optional originating field name, used only for error messages.
+
+ Returns:
+ A JSON-serializable structure.
+
+ Raises:
+ TypeError: If the value's type has no defined serialization.
+ """
+ if value is None or isinstance(value, (str, int, float, bool)):
+ return value
+ if isinstance(value, datetime):
+ return value.isoformat()
+ if isinstance(value, (list, tuple)):
+ return [serialize_value(item, field_name) for item in value]
+ if isinstance(value, dict):
+ return {str(k): serialize_value(v, field_name) for k, v in value.items()}
+ if isinstance(value, Serializable):
+ return value.serialize()
+ location = f"field {field_name!r}" if field_name else "value"
+ raise TypeError(f"Cannot serialize {location} of type {type(value).__name__}")
+
+
def is_jsonable(x):
try:
json.dumps(x)
diff --git a/src/fundus/utils/timeout.py b/src/fundus/utils/timeout.py
index 92b6e7d48..3e7d3121e 100644
--- a/src/fundus/utils/timeout.py
+++ b/src/fundus/utils/timeout.py
@@ -4,91 +4,67 @@
import time
from typing import Callable, Iterator, Optional
-from typing_extensions import ParamSpec
-P = ParamSpec("P")
-
-
-class Stopwatch:
- def __init__(self):
- self._start = time.time()
+def _interrupt_handler() -> None:
+ thread.interrupt_main()
- @property
- def time(self) -> float:
- return max(0.0, time.time() - self._start)
- def reset(self):
- self._start = time.time()
+class ResettableTimer:
+ class _Stopwatch:
+ def __init__(self) -> None:
+ self._start = time.time()
+ @property
+ def elapsed(self) -> float:
+ return max(0.0, time.time() - self._start)
-class ResettableTimer(threading.Thread):
- def __init__(
- self,
- seconds: float,
- func: Callable[P, None],
- interval: float = 0.1,
- args: P.args = tuple(),
- kwargs: P.kwargs = None,
- ) -> None:
- """Resettable timer executing after seconds, checking every .
+ def reset(self) -> None:
+ self._start = time.time()
- Args:
- seconds: Time to pass in seconds.
- func: Callable to execute when has passed.
- interval: Check every seconds if condition is met (reduce workload on CPU).
- *args: Arguments to .
- **kwargs: Keyword arguments to .
- """
- super().__init__(target=func, args=args, kwargs=kwargs)
+ def __init__(self, seconds: float, func: Callable[[], None], interval: float = 0.1) -> None:
self.seconds = seconds
self.interval = interval
- self.watch = Stopwatch()
+ self._func = func
+ self._watch = self._Stopwatch()
self._canceled = threading.Event()
+ self._thread = threading.Thread(target=self._run, daemon=True)
- def run(self) -> None:
- self.watch.reset()
- while True and self.watch.time < self.seconds:
+ def _run(self) -> None:
+ self._watch.reset()
+ while self._watch.elapsed < self.seconds:
time.sleep(self.interval)
if self._canceled.is_set():
return
- # noinspection PyUnresolvedReferences
- self._target(*self._args, **self._kwargs) # type: ignore[attr-defined]
+ self._func()
+
+ def start(self) -> None:
+ self._thread.start()
def reset(self) -> None:
- self.watch.reset()
+ self._watch.reset()
def cancel(self) -> None:
self._canceled.set()
-def _interrupt_handler() -> None:
- thread.interrupt_main()
-
-
# noinspection PyPep8Naming
@contextlib.contextmanager
def Timeout(
- seconds: float, silent: bool = False, callback: Optional[Callable[[], None]] = None, disable: bool = False
+ seconds: Optional[float], silent: bool = False, callback: Optional[Callable[[], None]] = None
) -> Iterator[ResettableTimer]:
"""Context manager applying a resettable timeout.
- Contextmanager implementation of timeout which does not relly on a function.
- If enter the context manager will time out after seconds.
- See docstring of 'timeout' for more information
-
Args:
- seconds: The time after which to timout in seconds. If set to <= 0, set timer never start.
- silent: If True, the KeyboardInterrupt will be silently ignored and None returned instead.
- Defaults to False.
+ seconds: The time after which to timeout in seconds. If None, the timeout is disabled.
+ silent: If True, the KeyboardInterrupt will be silently ignored. Defaults to False.
callback: If given, will be called instead of raising KeyboardInterrupt. Defaults to None.
- disable: If True, the timer will never start effectively disable the timeout.
Returns:
ResettableTimer: A timer to reset or cancel the timeout.
"""
- timer = ResettableTimer(seconds, callback or _interrupt_handler)
+ timer = ResettableTimer(seconds or 0, callback or _interrupt_handler)
try:
- if not disable:
+ if seconds is not None:
timer.start()
yield timer
except KeyboardInterrupt as err:
@@ -96,4 +72,3 @@ def Timeout(
raise TimeoutError from err
finally:
timer.cancel()
- del timer
diff --git a/src/fundus/utils/timing.py b/src/fundus/utils/timing.py
new file mode 100644
index 000000000..9fc6d99a9
--- /dev/null
+++ b/src/fundus/utils/timing.py
@@ -0,0 +1,20 @@
+import random
+import time
+from functools import wraps
+from typing import Callable, Tuple, TypeVar
+
+from typing_extensions import ParamSpec
+
+_T = TypeVar("_T")
+_P = ParamSpec("_P")
+
+
+def random_sleep(func: Callable[_P, _T], between: Tuple[float, float]) -> Callable[_P, _T]:
+ """Wrap func so each invocation first sleeps a random duration within the (low, high) interval (seconds)."""
+
+ @wraps(func)
+ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
+ time.sleep(random.uniform(*between))
+ return func(*args, **kwargs)
+
+ return wrapper
diff --git a/tests/README.md b/tests/README.md
new file mode 100644
index 000000000..2fb99cb8e
--- /dev/null
+++ b/tests/README.md
@@ -0,0 +1,114 @@
+# Test suite
+
+A quick orientation for contributors. The suite is plain `pytest`; everything below is
+convention, not magic.
+
+## Running
+
+```bash
+python -m pytest # full suite
+python -m pytest -m "not integration" # fast inner loop — skip the slow tests
+python -m pytest tests/scraping/pipeline/source/test_web.py # one module
+```
+
+Note `pyproject.toml` sets `filterwarnings = ["error"]`: any warning emitted during a test
+fails it.
+
+## Layout — mirror the source tree
+
+A test's **location** mirrors the package it covers. The test for
+`src/fundus//.py` lives at `tests//test_.py`, so "where's the
+test for `web.py`?" has exactly one answer.
+
+```
+tests/
+├── conftest.py # auto-loads fixture_*.py as plugins; autouse __EVENTS__ reset
+├── utility.py # shared helpers, imported as `tests.utility`
+├── exceptions.py # and `tests.exceptions`
+├── fixtures/ # builders, fakes, and @pytest.fixture wrappers (see below)
+├── resources/ # recorded HTML, parser test data, frozen snapshots
+├── parser/ # mirrors src/fundus/parser/
+├── publishers/ # mirrors src/fundus/publishers/
+├── scraping/ # mirrors src/fundus/scraping/ (pipeline/, crawler/, ...)
+└── utils/ # mirrors src/fundus/utils/
+```
+
+Every directory under `tests/` except `fixtures/` has an `__init__.py`, so same-named modules
+in different packages — `scraping/crawler/test_web.py` vs.
+`scraping/pipeline/source/test_web.py` — get distinct dotted names and don't collide.
+(`fixtures/` is reached through the `fixture_*` plugin glob, not as a package, so it needs
+none.)
+
+**Location is by subject, never by speed.** A file's speed class drifts as it grows; the
+module it tests does not. Cost/scope rides on **markers**, not directories (see below). We
+deliberately rejected a `unit/component/integration/` directory split — it would tear
+cohesive single-module files in half and add a permanent "which bucket?" tax for a benefit
+the marker already delivers.
+
+## Markers
+
+Only one marker today, defined in `pyproject.toml`:
+
+- **`integration`** — slow, multi-component tests with mocked I/O (may spawn threads or
+ processes), e.g. `scraping/crawler/test_integration.py`. Select with `-m integration` or
+ skip with `-m "not integration"`.
+
+## Test data helpers: builders vs. fixtures vs. fakes vs. doubles
+
+All shared test-data machinery lives in `tests/fixtures/builders.py`, where **the prefix tells
+you what you get back**:
+
+- **`make_*`** → a **real** domain object (`make_publisher`, `make_html`, `make_article`, ...).
+- **`stub_*`** → a hand-rolled **stub** standing in for a real type (`stub_publisher`).
+- **`mock_*`** → a **`MagicMock`** for a collaborator that isn't worth (or can't be) built real
+ (`mock_response`, `mock_robots`).
+
+Reach for them in this order:
+
+**Builders (`make_*`)** — the default way to construct a domain object. One keyword-only
+builder per type, each with sensible defaults. Builders nest, so a test about one layer
+needn't know how to assemble the layers beneath it (`make_article` → `make_html` →
+`make_source_info`). For a non-default value, compose at the call site so the object graph
+stays visible:
+
+```python
+make_article(html=make_html(requested_url="...", publisher="pub_a"))
+```
+
+Don't add caller-specific shortcut kwargs to the global builders; if one file repeats the
+same composition many times *and* it hurts readability, add a small local helper in that file.
+
+**Fixtures (`fixture_*.py`)** — thin `@pytest.fixture` wrappers, mostly default no-arg
+builder calls, for inject-by-name convenience (`publisher`, `parser_proxy_with_version`, the
+publisher-group fixtures, `patched_web_session_handler`, ...). Use a fixture when the
+constructed object is ceremony the test never inspects; construct inline (via a builder)
+when the test asserts on the *specific values you put in*.
+
+**Fakes (`fakes.py`)** — behavior-correct simplified subclasses of production classes
+(e.g. `FakeCrawler`) for tests that need a real-ish object, not a stand-in.
+
+**Doubles (`stub_*` / `mock_*`)** — the fallback when you can't (or needn't) use a real or
+fake object. Default to a hand-rolled **`stub_*`**; reach for **`mock_*`** (`MagicMock`) only
+when the stub can't do the job:
+
+- **`stub_*`** — a small hand-rolled class. Clearer about its interface, picklable, and honest
+ under `isinstance`. Use it for data-shaped collaborators a test just threads through
+ (`stub_publisher`: scraping tests that carry a publisher without exercising the real one).
+- **`mock_*`** — a `MagicMock`. Use it only when the test needs call-recording
+ (`mock.foo.assert_called_with(...)`), the real surface is wide and unpredictable
+ (`mock_response` → `curl_cffi.Response`), or the collaborator is behavioral rather than
+ data-shaped (`mock_robots` → `Robots`, with its `can_fetch` / `crawl_delay`).
+
+> **Plugin rule:** `conftest.py` auto-registers every `tests/fixtures/fixture_*.py` as a
+> pytest plugin, and plugin modules must contain **only** `@pytest.fixture` callables. Bare
+> helper functions and classes belong in non-plugin modules (`builders.py`, `fakes.py`).
+> Mixing the two trips pytest's assertion-rewrite ordering.
+
+## A couple of conventions worth knowing
+
+- **Assert behavior, not structure.** Avoid `isinstance`/type/shape assertions — a test that
+ still passes with the feature ripped out isn't testing the feature.
+- **`xfail` pins live bugs.** When the bug is fixed, remove the `xfail` rather than the test.
+- **`__EVENTS__` is process-global** cancellation state. `conftest.py` has an autouse fixture
+ (`_reset_events_registry`) that calls `__EVENTS__.reset()` after every test, so tests may
+ exercise the registry freely without leaking into the next one.
\ No newline at end of file
diff --git a/tests/conftest.py b/tests/conftest.py
index f831f6b5f..6ca90c03d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,8 +1,9 @@
from pathlib import Path
-from typing import List
+from typing import Iterator, List
-# noinspection PyUnresolvedReferences
-import pytest # noqa: F401
+import pytest
+
+from fundus.utils.events import __EVENTS__
def path_to_plugin(path: Path) -> str:
@@ -15,3 +16,17 @@ def path_to_plugin(path: Path) -> str:
# Documentation on the `pytest_plugins` variable:
# https://docs.pytest.org/en/latest/reference/reference.html#globalvar-pytest_plugins
pytest_plugins: List[str] = [path_to_plugin(fixture) for fixture in Path("tests/fixtures").glob("fixture_*.py")]
+
+
+@pytest.fixture(autouse=True)
+def _reset_events_registry() -> Iterator[None]:
+ """Clear the process-global ``__EVENTS__`` registry after every test.
+
+ ``__EVENTS__`` holds alias→event mappings that persist across tests by design.
+ Without an explicit reset, a test that sets the ``"stop"``event or registers
+ an alias leaks that state into the next test. ``CrawlerBase.crawl``already
+ resets on exit via ``main_context``, but tests that touch ``WebSource`` /
+ ``CCNewsSource`` / the registry directly bypass that.
+ """
+ yield
+ __EVENTS__.reset()
diff --git a/tests/fixtures/builders.py b/tests/fixtures/builders.py
new file mode 100644
index 000000000..3d1c609f0
--- /dev/null
+++ b/tests/fixtures/builders.py
@@ -0,0 +1,171 @@
+"""Canonical test-data builders. See tests/README.md for conventions."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import Any, Callable, Dict, List, Optional, Type, Union, cast
+from unittest.mock import MagicMock
+
+from curl_cffi.requests import BrowserTypeLiteral
+from curl_cffi.requests.exceptions import HTTPError
+
+from fundus.parser import BaseParser, ParserProxy
+from fundus.publishers.base_objects import Publisher, PublisherGroup, Robots
+from fundus.scraping.article import Article
+from fundus.scraping.filter import URLFilter
+from fundus.scraping.html import HTML, SourceInfo
+from fundus.scraping.url import NewsMap, URLSource
+
+_DEFAULT_REQUEST_HEADER: Dict[str, str] = {"user-agent": "test-agent"}
+
+
+class _DefaultParserProxy(ParserProxy):
+ # Module-level so a Publisher carrying this parser pickles by qualified name.
+ class Default(BaseParser):
+ pass
+
+
+def make_publisher(
+ *,
+ name: str = "test_pub",
+ domain: str = "https://test.com/",
+ sources: Optional[List[URLSource]] = None,
+ parser: Type[ParserProxy] = _DefaultParserProxy,
+ request_header: Optional[Dict[str, str]] = None,
+ url_filter: Optional[URLFilter] = None,
+ impersonate: Optional[BrowserTypeLiteral] = None,
+ suppress_robots: bool = False,
+ disallows_training: bool = False,
+) -> Publisher:
+ return Publisher(
+ name=name,
+ domain=domain,
+ sources=sources if sources is not None else [NewsMap("https://test.com/test_news_map")],
+ parser=parser,
+ request_header=request_header if request_header is not None else dict(_DEFAULT_REQUEST_HEADER),
+ url_filter=url_filter,
+ impersonate=impersonate,
+ suppress_robots=suppress_robots,
+ disallows_training=disallows_training,
+ )
+
+
+def make_publisher_group(
+ *,
+ name: str = "TestGroup",
+ default_language: Optional[str] = None,
+ **members: Union[Publisher, PublisherGroup],
+) -> PublisherGroup:
+ """Build a PublisherGroup from named Publisher/PublisherGroup members.
+
+ Inline equivalent of ``class (metaclass=PublisherGroup): ...``, so a test can construct
+ exactly the group it asserts against right next to the assertion instead of reaching for a
+ distant fixture. Member kwargs become the group's attributes (``eng=...`` -> ``group.eng``),
+ and ``default_language`` propagates to member sources that declare no languages of their own —
+ exactly as the metaclass does for real publisher groups.
+ """
+ namespace: Dict[str, object] = {}
+ if default_language is not None:
+ namespace["default_language"] = default_language
+ namespace.update(members)
+ return PublisherGroup(name, (), namespace)
+
+
+@dataclass
+class _PublisherStub:
+ """Picklable stand-in for Publisher exposing only the attributes consumers read.
+
+ Lives behind ``stub_publisher`` which casts it to ``Publisher`` for the type checker.
+ Tests should not reference this class directly.
+ """
+
+ name: str = "test_pub"
+ domain: str = "https://example.com/"
+ impersonate: Optional[str] = None
+ request_header: Dict[str, str] = field(default_factory=lambda: dict(_DEFAULT_REQUEST_HEADER))
+ robots: Optional[Any] = None
+ url_filter: Optional[Callable[[str], bool]] = None
+
+ def serialize(self) -> str:
+ return self.name
+
+
+def stub_publisher(
+ *,
+ name: str = "test_pub",
+ domain: str = "https://example.com/",
+ impersonate: Optional[str] = None,
+ request_header: Optional[Dict[str, str]] = None,
+ robots: Optional[Any] = None,
+ url_filter: Optional[Callable[[str], bool]] = None,
+) -> Publisher:
+ stub = _PublisherStub(
+ name=name,
+ domain=domain,
+ impersonate=impersonate,
+ request_header=request_header if request_header is not None else dict(_DEFAULT_REQUEST_HEADER),
+ robots=robots,
+ url_filter=url_filter,
+ )
+ return cast(Publisher, stub)
+
+
+def make_source_info(*, publisher: str = "test_pub") -> SourceInfo:
+ return SourceInfo(publisher=publisher)
+
+
+def make_html(
+ *,
+ requested_url: str = "https://example.com/article",
+ responded_url: Optional[str] = None,
+ content: str = " ",
+ crawl_date: Optional[datetime] = None,
+ publisher: str = "test_pub",
+) -> HTML:
+ return HTML(
+ requested_url=requested_url,
+ responded_url=responded_url if responded_url is not None else requested_url,
+ content=content,
+ crawl_date=crawl_date if crawl_date is not None else datetime(2024, 1, 1),
+ source_info=make_source_info(publisher=publisher),
+ )
+
+
+def make_article(*, html: Optional[HTML] = None, **extraction: Any) -> Article:
+ return Article(html=html if html is not None else make_html(), **extraction)
+
+
+def make_http_error(*, status_code: int) -> HTTPError:
+ """Real curl_cffi HTTPError carrying a MagicMock response with the given status_code."""
+ return HTTPError("boom", response=MagicMock(status_code=status_code))
+
+
+# --- test doubles ---
+# Everything below fabricates a stand-in, not a real domain object. The prefix says which:
+# ``mock_*`` returns a MagicMock; ``stub_publisher`` (above) returns a hand-rolled stub.
+
+
+def mock_response(
+ *,
+ text: str = " ",
+ url: str = "https://example.com/article",
+ history: Optional[List[Any]] = None,
+) -> MagicMock:
+ """MagicMock for curl_cffi Response — wide surface, only a few fields tests touch."""
+ response = MagicMock()
+ response.text = text
+ response.url = url
+ response.history = history if history is not None else []
+ return response
+
+
+def mock_robots(*, can_fetch: bool = True, crawl_delay: Optional[float] = None) -> MagicMock:
+ """MagicMock for Robots — a behavioral collaborator (can_fetch / crawl_delay).
+
+ Defaults are permissive: fetching allowed, no crawl-delay.
+ """
+ robots = MagicMock(spec=Robots)
+ robots.can_fetch.return_value = can_fetch
+ robots.crawl_delay.return_value = crawl_delay
+ return robots
diff --git a/tests/fixtures/fakes.py b/tests/fixtures/fakes.py
new file mode 100644
index 000000000..6fca5f27a
--- /dev/null
+++ b/tests/fixtures/fakes.py
@@ -0,0 +1,36 @@
+"""Simplified real-subclass implementations of internal interfaces.
+
+A *fake* is a working, behavior-correct simplified implementation — distinct from a stub
+(dumb data holder) and a mock (call-recorder). Use a fake when the real method dispatch
+matters but the production implementation is expensive or pulls in external dependencies.
+
+This module deliberately does NOT match the ``fixture_*.py`` glob picked up by conftest.
+"""
+
+from __future__ import annotations
+
+from typing import Iterator, List, Optional, Sequence, Tuple
+
+from fundus.publishers.base_objects import Publisher
+from fundus.scraping.article import Article
+from fundus.scraping.crawler.base import CrawlerBase, PublisherType
+from fundus.scraping.filter import ExtractionFilter, URLFilter
+
+
+class FakeCrawler(CrawlerBase):
+ """CrawlerBase subclass that yields a fixed sequence of pre-built articles."""
+
+ def __init__(self, *publishers: PublisherType, articles: Sequence[Article] = ()) -> None:
+ super().__init__(*publishers)
+ self._articles = list(articles)
+
+ def _build_article_iterator(
+ self,
+ publishers: Tuple[Publisher, ...],
+ raise_on_error: bool,
+ extraction_filter: Optional[ExtractionFilter],
+ url_filter: Optional[URLFilter],
+ language_filter: Optional[List[str]],
+ skip_publishers_disallowing_training: bool = False,
+ ) -> Iterator[Article]:
+ yield from self._articles
diff --git a/tests/fixtures/fixture_collection.py b/tests/fixtures/fixture_collection.py
index 12c4d9c51..fe14925fe 100644
--- a/tests/fixtures/fixture_collection.py
+++ b/tests/fixtures/fixture_collection.py
@@ -1,7 +1,8 @@
import pytest
from fundus import NewsMap, RSSFeed, Sitemap
-from fundus.publishers.base_objects import Publisher, PublisherGroup
+from fundus.publishers.base_objects import PublisherGroup
+from tests.fixtures.builders import make_publisher, make_publisher_group
@pytest.fixture
@@ -23,9 +24,7 @@ class GroupWithEmptyPublisherSubgroup(metaclass=PublisherGroup):
@pytest.fixture
def publisher_group_with_news_map(parser_proxy_with_version):
class PubGroup(metaclass=PublisherGroup):
- value = Publisher(
- name="test_pub",
- domain="https://test.com/",
+ value = make_publisher(
sources=[NewsMap("https://test.com/test_news_map")],
parser=parser_proxy_with_version,
)
@@ -36,9 +35,7 @@ class PubGroup(metaclass=PublisherGroup):
@pytest.fixture
def publisher_group_with_rss_feeds(parser_proxy_with_version):
class PubGroup(metaclass=PublisherGroup):
- value = Publisher(
- name="test_pub",
- domain="https://test.com/",
+ value = make_publisher(
sources=[RSSFeed("https://test.com/test_feed")],
parser=parser_proxy_with_version,
)
@@ -49,9 +46,7 @@ class PubGroup(metaclass=PublisherGroup):
@pytest.fixture
def publisher_group_with_sitemaps(parser_proxy_with_version):
class PubGroup(metaclass=PublisherGroup):
- value = Publisher(
- name="test_pub",
- domain="https://test.com/",
+ value = make_publisher(
sources=[Sitemap("https://test.com/test_sitemap")],
parser=parser_proxy_with_version,
)
@@ -70,17 +65,13 @@ class CollectionWithValidatePublisherEnum(metaclass=PublisherGroup):
@pytest.fixture
def group_with_two_valid_publisher_subgroups(parser_proxy_with_version):
class PubGroupNews(metaclass=PublisherGroup):
- news = Publisher(
- name="test_pub",
- domain="https://test.com/",
+ news = make_publisher(
sources=[NewsMap("https://test.com/test_newsmap")],
parser=parser_proxy_with_version,
)
class PubGroupSitemap(metaclass=PublisherGroup):
- sitemap = Publisher(
- name="test_pub",
- domain="https://test.com/",
+ sitemap = make_publisher(
sources=[Sitemap("https://test.com/test_sitemap")],
parser=parser_proxy_with_version,
)
@@ -97,23 +88,20 @@ def publisher_group_with_languages(parser_proxy_with_version):
class LangPubGroup(metaclass=PublisherGroup):
default_language = "en"
- eng = Publisher(
+ eng = make_publisher(
name="test_pub_eng",
- domain="https://test.com/",
sources=[NewsMap("https://test.com/test_sitemap")],
parser=parser_proxy_with_version,
)
- ger = Publisher(
+ ger = make_publisher(
name="test_pub_ger",
- domain="https://test.com/",
sources=[Sitemap("https://test.com/test_sitemap", languages={"de"})],
parser=parser_proxy_with_version,
)
- mixed = Publisher(
+ mixed = make_publisher(
name="test_pub_mixed",
- domain="https://test.com/",
sources=[
RSSFeed("https://test.com/test_feed", languages={"es", "pl"}),
NewsMap("https://test.com/test_newsmap", languages={"es"}),
@@ -123,3 +111,21 @@ class LangPubGroup(metaclass=PublisherGroup):
)
return LangPubGroup
+
+
+@pytest.fixture
+def publisher_group_with_versioned_attrs(proxy_with_two_versions_and_different_attrs):
+ return make_publisher_group(
+ value=make_publisher(
+ parser=proxy_with_two_versions_and_different_attrs, sources=[NewsMap("https://test.com/test_news_map")]
+ )
+ )
+
+
+@pytest.fixture
+def publisher_group_with_deprecated_attrs(proxy_with_two_deprecated_attributes):
+ return make_publisher_group(
+ value=make_publisher(
+ parser=proxy_with_two_deprecated_attributes, sources=[NewsMap("https://test.com/test_news_map")]
+ )
+ )
diff --git a/tests/fixtures/fixture_server.py b/tests/fixtures/fixture_server.py
new file mode 100644
index 000000000..3aa996e71
--- /dev/null
+++ b/tests/fixtures/fixture_server.py
@@ -0,0 +1,46 @@
+"""Pytest fixtures providing real loopback servers for timeout integration tests.
+
+These let a test drive a genuine curl_cffi timeout through the stack instead of feeding
+a hand-picked exception class to a mocked session — the only way to confirm the code
+catches the exception curl_cffi actually raises.
+"""
+
+import socket
+import threading
+from typing import Iterator, List
+
+import pytest
+
+
+@pytest.fixture
+def hanging_url() -> Iterator[str]:
+ """Yield a URL whose server accepts the TCP connection but never sends a response.
+
+ A request to it connects fine but then times out reading, so the caller sees the
+ real curl_cffi timeout exception. Bound on an ephemeral loopback port; torn down
+ after the test.
+ """
+ server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ server.bind(("127.0.0.1", 0))
+ server.listen(1)
+ accepted: List[socket.socket] = [] # hold connections open so the peer isn't reset
+
+ def accept_and_hang() -> None:
+ while True:
+ try:
+ connection, _ = server.accept()
+ except OSError:
+ return
+ accepted.append(connection)
+
+ worker = threading.Thread(target=accept_and_hang, daemon=True)
+ worker.start()
+ host, port = server.getsockname()
+ try:
+ yield f"http://{host}:{port}/"
+ finally:
+ server.close()
+ for connection in accepted:
+ connection.close()
+ worker.join(timeout=1)
diff --git a/tests/fixtures/fixture_source.py b/tests/fixtures/fixture_source.py
new file mode 100644
index 000000000..1f7273989
--- /dev/null
+++ b/tests/fixtures/fixture_source.py
@@ -0,0 +1,28 @@
+"""Pytest fixtures for tests of pipeline source classes (WebSource, CCNewsSource, ...).
+
+Parameterized builders live in ``tests.fixtures.builders`` — this module only wraps the
+default no-arg builder calls in ``@pytest.fixture`` decorators for the common-case
+injection-by-name pattern.
+"""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from fundus.publishers.base_objects import Publisher
+from tests.fixtures.builders import mock_response, stub_publisher
+
+
+@pytest.fixture
+def publisher() -> Publisher:
+ return stub_publisher()
+
+
+@pytest.fixture
+def patched_web_session_handler():
+ """Patch the session_handler used by WebSource; yield the session mock."""
+ with patch("fundus.scraping.pipeline.source.web.session_handler") as sh:
+ session = MagicMock()
+ session.get_with_interrupt.return_value = mock_response()
+ sh.get_session.return_value = session
+ yield session
diff --git a/tests/parser/__init__.py b/tests/parser/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_data.py b/tests/parser/test_data.py
similarity index 100%
rename from tests/test_data.py
rename to tests/parser/test_data.py
diff --git a/tests/test_parser.py b/tests/parser/test_parser.py
similarity index 61%
rename from tests/test_parser.py
rename to tests/parser/test_parser.py
index 174aec728..b60473e75 100644
--- a/tests/test_parser.py
+++ b/tests/parser/test_parser.py
@@ -1,12 +1,9 @@
import datetime
-import pickle
from typing import Any, Dict, List, Optional, Tuple, Union
-import lxml.html
import pytest
from fundus.parser.base_parser import (
- Attribute,
AttributeCollection,
BaseParser,
ParserProxy,
@@ -14,25 +11,6 @@
attribute,
)
from fundus.parser.utility import generic_author_parsing
-from fundus.publishers import PublisherCollection
-from fundus.publishers.base_objects import Publisher
-from tests.resources import attribute_annotations_mapping
-from tests.utility import (
- get_meta_info_file,
- load_html_test_file_mapping,
- load_supported_publishers_markdown,
- load_test_case_data,
-)
-
-
-def test_supported_publishers_table():
- root = lxml.html.fromstring(load_supported_publishers_markdown())
- parsed_names: List[str] = root.xpath("//table[contains(@class,'publishers')]//td[1]/code/text()")
- for publisher in PublisherCollection:
- assert publisher.__name__ in parsed_names, (
- f"Publisher {publisher.name} is not included in docs/supported_news.md. "
- f"Run 'python -m scripts.generate_tables'"
- )
class TestBaseParser:
@@ -196,84 +174,6 @@ def get_initialized_attrs(parser: BaseParser) -> List[RegisteredFunction]:
assert parser3 != parser2 != parser1
-# enforce test coverage for test parsing
-# because this is also used for the generate_parser_test_files script we export it here
-attributes_required_to_cover = {"title", "authors", "topics", "publishing_date", "body", "images"}
-
-attributes_parsers_are_required_to_cover = {"body"}
-
-
-@pytest.mark.parametrize(
- "publisher", list(PublisherCollection), ids=[publisher.__name__ for publisher in PublisherCollection]
-)
-class TestParser:
- def test_annotations(self, publisher: Publisher) -> None:
- parser_proxy = publisher.parser
- for versioned_parser in parser_proxy:
- assert attributes_parsers_are_required_to_cover.issubset(
- set(versioned_parser.attributes().validated.names)
- ), f"{versioned_parser.__name__!r} should implement at least {attributes_parsers_are_required_to_cover!r}"
- for attr in versioned_parser.attributes().validated:
- if annotation := attribute_annotations_mapping[attr.__name__]:
- assert attr.__annotations__.get("return") == annotation, (
- f"Attribute {attr.__name__!r} for {versioned_parser.__name__!r} is of wrong type. "
- f"{attr.__annotations__.get('return')} != {annotation}"
- )
- else:
- raise KeyError(f"Unsupported attribute {attr.__name__!r}")
-
- def test_parsing(self, publisher: Publisher) -> None:
- comparative_data = load_test_case_data(publisher)
- html_mapping = load_html_test_file_mapping(publisher)
-
- for versioned_parser in publisher.parser:
- # validate json
- version_name = versioned_parser.__name__
- assert (version_data := comparative_data.get(version_name)), (
- f"Missing test data for parser version {version_name!r}"
- )
-
- # validate test HTML
- assert (html := html_mapping.get(versioned_parser)), (
- f"Missing test HTML for parser version {version_name} of publisher {publisher.name}"
- )
-
- # re-instantiate parser to address deprecated attributes
- timestamp_instantiated_parser = publisher.parser(html.crawl_date)
-
- for key, value in version_data.items():
- if not value:
- raise ValueError(
- f"There is no value set for key {key!r} in the test JSON. "
- f"Only complete articles should be used as test cases"
- )
-
- # test coverage
- supported_attrs = set(timestamp_instantiated_parser.registered_attributes.names)
- missing_attrs = attributes_required_to_cover & supported_attrs - set(version_data.keys())
- assert not missing_attrs, (
- f"Test JSON for {version_name} of publisher {publisher.name} does not cover the following attribute(s): {missing_attrs}"
- )
-
- assert list(version_data.keys()) == sorted(attributes_required_to_cover & supported_attrs), (
- f"Test JSON for {version_name} is not in alphabetical order"
- )
-
- # compare data
- extraction = timestamp_instantiated_parser.parse(html.content, "raise")
- for key, value in version_data.items():
- assert value == extraction[key], f"{key!r} is not equal"
-
- # check if extraction is pickable
- pickle.dumps(extraction)
-
- def test_reserved_attribute_names(self, publisher: Publisher):
- parser = publisher.parser
- for attr in attribute_annotations_mapping.keys():
- if value := getattr(parser, attr, None):
- assert isinstance(value, Attribute), f"The name {attr!r} is reserved for attributes only."
-
-
class TestUtility:
def test_generic_author_parsing(self):
# type None
@@ -308,14 +208,3 @@ def test_generic_author_parsing(self):
[{"name": "Peter Funny"}, {"name": "Funny Peter"}, {"this": "is not a pipe"}, {}] # type: ignore
) == ["Peter Funny", "Funny Peter"]
assert generic_author_parsing([{}]) == generic_author_parsing([{}, {"wrong": "key"}]) == [] # type: ignore
-
-
-class TestMetaInfo:
- def test_order(self):
- for cc in PublisherCollection.get_subgroup_mapping().values():
- meta_file = get_meta_info_file(cc)
- meta_info = meta_file.load()
- assert meta_info, f"Meta info file {meta_file.path} is missing"
- assert sorted(meta_info.keys()) == list(meta_info.keys()), (
- f"Meta info file {meta_file.path} isn't ordered properly."
- )
diff --git a/tests/publishers/__init__.py b/tests/publishers/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/publishers/conftest.py b/tests/publishers/conftest.py
new file mode 100644
index 000000000..5dbff297c
--- /dev/null
+++ b/tests/publishers/conftest.py
@@ -0,0 +1,32 @@
+from typing import cast
+
+import pytest
+
+from fundus import PublisherCollection
+from fundus.publishers import Publisher, PublisherGroup
+
+
+@pytest.fixture(params=list(PublisherCollection), ids=lambda publisher: publisher.__name__)
+def publisher(request) -> Publisher:
+ """Fan a test out over every publisher in the live ``PublisherCollection``.
+
+ Any test under ``tests/publishers/`` that declares a ``publisher`` argument is
+ parametrized across the whole collection, with the publisher's ``__name__`` as the
+ test id. Scoped to this directory (rather than a global ``fixture_*`` plugin) so it
+ overrides the stub ``publisher`` fixture from ``tests/fixtures/fixture_source.py``
+ only here, where every test wants the real collection.
+ """
+ return cast(Publisher, request.param)
+
+
+@pytest.fixture(
+ params=list(PublisherCollection.get_subgroup_mapping().values()),
+ ids=lambda region: region.__name__,
+)
+def region(request) -> PublisherGroup:
+ """Fan a test out over every region (country subgroup) in the live ``PublisherCollection``.
+
+ Mirror of the ``publisher`` fixture for tests that operate per subgroup rather than
+ per publisher, with the region's ``__name__`` as the test id.
+ """
+ return cast(PublisherGroup, request.param)
diff --git a/tests/publishers/test_base_objects.py b/tests/publishers/test_base_objects.py
new file mode 100644
index 000000000..c5cc5fef6
--- /dev/null
+++ b/tests/publishers/test_base_objects.py
@@ -0,0 +1,372 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from curl_cffi.requests.exceptions import ConnectionError, ReadTimeout
+
+from fundus import NewsMap, RSSFeed, Sitemap
+from fundus.publishers.base_objects import CustomRobotFileParser, FilteredPublisher, Robots
+from fundus.scraping.session import session_handler
+from tests.fixtures.builders import make_http_error, make_publisher, make_publisher_group, mock_response
+
+
+class TestPublisherSupports:
+ """`supports` answers: is there a *single* source matching the given (type AND language)?"""
+
+ def test_unconstrained_matches_publisher_with_sources(self):
+ publisher = make_publisher(sources=[NewsMap("https://x/news", languages={"en"})])
+ assert publisher.supports() is True
+
+ @pytest.mark.parametrize("source_type, expected", [(NewsMap, True), (Sitemap, False)])
+ def test_filters_by_source_type(self, source_type, expected):
+ publisher = make_publisher(sources=[NewsMap("https://x/news", languages={"en"})])
+ assert publisher.supports(source_types=[source_type]) is expected
+
+ @pytest.mark.parametrize("language, expected", [("es", True), ("de", True), ("fr", False)])
+ def test_filters_by_language(self, language, expected):
+ publisher = make_publisher(
+ sources=[
+ RSSFeed("https://x/feed", languages={"es"}),
+ Sitemap("https://x/map", languages={"de"}),
+ ]
+ )
+ assert publisher.supports(languages=[language]) is expected
+
+ @pytest.mark.parametrize(
+ "source_type, language, expected",
+ [
+ (NewsMap, "es", True), # the NewsMap source carries 'es'
+ (RSSFeed, "pl", True), # the RSSFeed source carries 'pl'
+ (NewsMap, "pl", False), # 'pl' exists, but only on the RSSFeed — no single source is both
+ (RSSFeed, "es", False), # 'es' exists, but only on the NewsMap
+ ],
+ )
+ def test_requires_one_source_matching_both_type_and_language(self, source_type, language, expected):
+ publisher = make_publisher(
+ sources=[
+ RSSFeed("https://x/feed", languages={"pl"}),
+ NewsMap("https://x/news", languages={"es"}),
+ ]
+ )
+ assert publisher.supports(source_types=[source_type], languages=[language]) is expected
+
+
+class TestPublisherGroup:
+ def test_len(
+ self,
+ empty_publisher_group,
+ group_with_empty_publisher_subgroup,
+ group_with_valid_publisher_subgroup,
+ group_with_two_valid_publisher_subgroups,
+ ):
+ assert len(empty_publisher_group) == 0
+ assert len(group_with_empty_publisher_subgroup) == 0
+ assert len(group_with_valid_publisher_subgroup) == 1
+ assert len(group_with_two_valid_publisher_subgroups) == 2
+
+ def test_iter_empty_group(self, empty_publisher_group):
+ assert list(empty_publisher_group) == []
+
+ def test_iter_group_with_empty_publisher_subgroup(self, group_with_empty_publisher_subgroup):
+ assert list(group_with_empty_publisher_subgroup) == []
+
+ def test_iter_group_with_publisher_subgroup(self, group_with_valid_publisher_subgroup):
+ assert list(group_with_valid_publisher_subgroup) == [group_with_valid_publisher_subgroup.pub.value]
+
+ def test_string_representation_nests_subgroups_and_publishers(self):
+ group = make_publisher_group(
+ name="Root",
+ news=make_publisher_group(name="News", a=make_publisher(name="pub_a")),
+ sitemap=make_publisher_group(name="Sitemap", b=make_publisher(name="pub_b")),
+ )
+ assert str(group) == "\n\t\n\t\tpub_a\n\t\n\t\tpub_b"
+
+ def test_string_representation_lists_direct_publishers(self):
+ group = make_publisher_group(name="PubGroup", value=make_publisher(name="test_pub"))
+ assert str(group) == "\n\ttest_pub"
+
+
+@pytest.mark.filterwarnings("ignore::UserWarning") # searches that match nothing call warn()
+class TestPublisherGroupSearch:
+ @pytest.mark.parametrize("args", [(), ([],), ([], [])])
+ def test_requires_at_least_one_criterion(self, args):
+ with pytest.raises(ValueError):
+ make_publisher_group().search(*args)
+
+ def test_matches_only_attributes_of_the_active_parser_version(
+ self, publisher_group_with_versioned_attrs, proxy_with_two_versions_and_different_attrs
+ ):
+ # the publisher's active parser is the latest version, so only its attributes are searchable
+ current, superseded = proxy_with_two_versions_and_different_attrs().attribute_mapping.values()
+ assert len(publisher_group_with_versioned_attrs.search(current.names)) == 1
+ assert publisher_group_with_versioned_attrs.search(superseded.names) == []
+
+ @pytest.mark.parametrize("source_types, expected", [([NewsMap], 1), ([Sitemap], 0)])
+ def test_combines_attribute_match_with_source_constraint(
+ self, publisher_group_with_versioned_attrs, proxy_with_two_versions_and_different_attrs, source_types, expected
+ ):
+ # attributes match, but the publisher only has a NewsMap source
+ current, _ = proxy_with_two_versions_and_different_attrs().attribute_mapping.values()
+ assert len(publisher_group_with_versioned_attrs.search(current.names, source_types=source_types)) == expected
+
+ def test_excludes_deprecated_attributes_by_default(
+ self, publisher_group_with_deprecated_attrs, proxy_with_two_deprecated_attributes
+ ):
+ (attributes,) = proxy_with_two_deprecated_attributes().attribute_mapping.values()
+ assert publisher_group_with_deprecated_attrs.search([attributes.deprecated.names[0]]) == []
+
+ def test_includes_deprecated_attributes_when_requested(
+ self, publisher_group_with_deprecated_attrs, proxy_with_two_deprecated_attributes
+ ):
+ (attributes,) = proxy_with_two_deprecated_attributes().attribute_mapping.values()
+ result = publisher_group_with_deprecated_attrs.search(
+ [attributes.deprecated.names[0]], include_deprecated_attributes=True
+ )
+ assert len(result) == 1
+
+ def test_returns_every_matching_publisher(self):
+ group = make_publisher_group(
+ default_language="en", # eng's source declares no language, so it inherits "en"
+ eng=make_publisher(name="eng", sources=[NewsMap("https://x/news")]),
+ ger=make_publisher(name="ger", sources=[Sitemap("https://x/map", languages={"de"})]),
+ )
+ # search collects every publisher in the group that matches
+ assert len(group.search(languages=["en"])) == 1
+ assert len(group.search(languages=["de"])) == 1
+ assert len(group.search(languages=["en", "de"])) == 2
+
+ def test_returns_results_narrowed_to_search_criteria(self):
+ group = make_publisher_group(value=make_publisher(sources=[NewsMap("https://x/n"), Sitemap("https://x/s")]))
+ (result,) = group.search(source_types=[NewsMap])
+ assert set(result.source_mapping) == {NewsMap} # Sitemap dropped; result is a narrowed FilteredPublisher
+
+
+@pytest.fixture
+def robots_session():
+ """Patch the session_handler that CustomRobotFileParser.read() reaches for; yield the session mock."""
+ with patch("fundus.publishers.base_objects.session_handler") as handler:
+ session = MagicMock()
+ handler.get_session.return_value = session
+ yield session
+
+
+class TestCustomRobotFileParser:
+ @pytest.mark.parametrize(
+ "lines, expected",
+ [
+ (["# we allow machine learning training"], True), # keyword inside a comment
+ (["# just an ordinary comment"], False), # comment, no keyword
+ (["User-agent: *", "Disallow: /machine"], False), # keyword token, but not a comment line
+ ],
+ )
+ def test_parse_detects_disallow_training_in_comments(self, lines, expected):
+ parser = CustomRobotFileParser("https://x/robots.txt")
+ parser.parse(lines)
+ assert parser.disallows_training is expected
+
+ @pytest.mark.parametrize("status", [401, 403])
+ def test_read_auth_error_disallows_all(self, robots_session, status):
+ robots_session.get_with_interrupt.side_effect = make_http_error(status_code=status)
+ parser = CustomRobotFileParser("https://x/robots.txt")
+ parser.read()
+ assert parser.disallow_all is True
+
+ @pytest.mark.parametrize("status", [400, 404, 429, 500, 503])
+ def test_read_error_defaults_to_allow_all(self, robots_session, status):
+ # any non-401/403 HTTP error (missing robots.txt or a server error) → treat as no restrictions
+ robots_session.get_with_interrupt.side_effect = make_http_error(status_code=status)
+ parser = CustomRobotFileParser("https://x/robots.txt")
+ parser.read()
+ assert parser.allow_all is True
+
+ def test_read_success_scans_body_for_disallow_training(self, robots_session):
+ robots_session.get_with_interrupt.return_value = mock_response(text="# trained for machine learning")
+ parser = CustomRobotFileParser("https://x/robots.txt")
+ parser.read()
+ assert parser.disallows_training is True
+
+ def test_read_success_enforces_parsed_rules(self, robots_session):
+ robots_session.get_with_interrupt.return_value = mock_response(text="User-agent: *\nDisallow: /private")
+ parser = CustomRobotFileParser("https://x/robots.txt")
+ parser.read()
+ assert parser.can_fetch("*", "https://x/private") is False
+ assert parser.can_fetch("*", "https://x/public") is True
+
+
+class TestRobots:
+ @pytest.mark.parametrize("raw, expected", [(5, 5.0), (2.5, 2.5), (None, None)])
+ def test_crawl_delay_is_coerced_to_float(self, raw, expected):
+ robots = Robots("https://x/robots.txt")
+ robots.robots_file_parser = MagicMock()
+ robots.robots_file_parser.crawl_delay.return_value = raw
+ result = robots.crawl_delay("*")
+ assert result == expected
+ if expected is not None:
+ assert isinstance(result, float)
+
+ def test_robots_is_read_once_across_calls(self):
+ robots = Robots("https://x/robots.txt")
+ robots.robots_file_parser = MagicMock()
+ robots.can_fetch("*", "https://x/a")
+ robots.crawl_delay("*")
+ robots.disallow_all()
+ robots.robots_file_parser.read.assert_called_once()
+
+ @pytest.mark.parametrize("error", [ConnectionError("boom"), ReadTimeout("boom")])
+ def test_read_failure_is_swallowed_and_allows_all(self, error):
+ robots = Robots("https://x/robots.txt")
+ robots.robots_file_parser = MagicMock()
+ robots.robots_file_parser.read.side_effect = error
+ robots.ensure_ready() # must not raise
+ assert robots.robots_file_parser.allow_all is True
+ assert robots.ready is True
+
+ @pytest.mark.integration
+ @pytest.mark.xfail(
+ reason="_read catches ReadTimeout, not the base Timeout curl_cffi raises on a real timeout, so the "
+ "timeout propagates out of ensure_ready instead of defaulting to allow-all. Fixed by flairNLP/fundus#939.",
+ strict=True,
+ )
+ def test_real_timeout_is_swallowed_and_allows_all(self, hanging_url):
+ robots = Robots(hanging_url)
+ with session_handler.context(timeout=0.3):
+ robots.ensure_ready() # must not raise
+ assert robots.robots_file_parser.allow_all is True
+ assert robots.ready is True
+
+ def test_can_fetch_delegates_to_parser(self):
+ robots = Robots("https://x/robots.txt")
+ robots.robots_file_parser = MagicMock()
+ robots.robots_file_parser.can_fetch.return_value = False
+ assert robots.can_fetch("bot", "https://x/p") is False
+ robots.robots_file_parser.can_fetch.assert_called_once_with("bot", "https://x/p")
+
+
+class TestPublisherConstruction:
+ @pytest.mark.parametrize("missing", [{"name": ""}, {"domain": ""}, {"sources": []}])
+ def test_requires_mandatory_fields(self, missing):
+ with pytest.raises(ValueError):
+ make_publisher(**missing)
+
+ def test_rejects_non_urlsource_sources(self):
+ with pytest.raises(TypeError):
+ make_publisher(sources=["https://x/not-a-source"]) # type: ignore[list-item]
+
+ def test_rejects_unknown_impersonate(self):
+ with pytest.raises(ValueError):
+ make_publisher(impersonate="definitely-not-a-browser") # type: ignore[arg-type]
+
+ def test_accepts_valid_impersonate(self):
+ assert make_publisher(impersonate="chrome").impersonate
+
+ @pytest.mark.parametrize("domain", ["https://x.com/", "https://x.com"])
+ def test_robots_url_appends_path(self, domain):
+ assert make_publisher(domain=domain).robots.url == "https://x.com/robots.txt"
+
+ def test_orders_sources_rss_newsmap_sitemap(self):
+ publisher = make_publisher(
+ sources=[
+ Sitemap("https://x/sitemap"),
+ NewsMap("https://x/news"),
+ RSSFeed("https://x/feed"),
+ ]
+ )
+ assert list(publisher.source_mapping) == [RSSFeed, NewsMap, Sitemap]
+
+ def test_suppress_robots_sets_allow_all(self):
+ publisher = make_publisher(suppress_robots=True)
+ assert publisher.robots.robots_file_parser.allow_all is True
+
+
+class TestPublisherProperties:
+ def test_languages_unions_all_sources(self):
+ publisher = make_publisher(
+ sources=[
+ RSSFeed("https://x/feed", languages={"en", "de"}),
+ Sitemap("https://x/map", languages={"fr"}),
+ ]
+ )
+ assert publisher.languages == {"en", "de", "fr"}
+
+ def test_source_types_reflects_present_types(self):
+ publisher = make_publisher(sources=[RSSFeed("https://x/feed"), NewsMap("https://x/news")])
+ assert publisher.source_types == {RSSFeed, NewsMap}
+
+ def test_disallows_training_short_circuits_on_flag(self):
+ publisher = make_publisher(disallows_training=True)
+ publisher.robots = MagicMock()
+ assert publisher.disallows_training is True
+ publisher.robots.disallows_training.assert_not_called()
+
+ def test_disallows_training_falls_back_to_robots(self):
+ publisher = make_publisher(disallows_training=False)
+ publisher.robots = MagicMock()
+ publisher.robots.disallows_training.return_value = True
+ assert publisher.disallows_training is True
+
+ def test_disallows_training_false_when_neither(self):
+ publisher = make_publisher(disallows_training=False)
+ publisher.robots = MagicMock()
+ publisher.robots.disallows_training.return_value = False
+ assert publisher.disallows_training is False
+
+
+class TestPublisherEquality:
+ def test_hash_is_name_based(self):
+ assert hash(make_publisher(name="x")) == hash("x")
+ assert hash(make_publisher(name="x")) == hash(make_publisher(name="x"))
+
+ def test_differs_by_name(self):
+ assert make_publisher(name="a") != make_publisher(name="b")
+
+ def test_not_equal_to_non_publisher(self):
+ assert make_publisher() != "not a publisher"
+
+ def test_equal_to_itself(self):
+ publisher = make_publisher()
+ assert publisher == publisher
+
+ @pytest.mark.xfail(
+ reason="Publisher.__eq__ compares self.parser by identity (ParserProxy defines no __eq__), "
+ "so two value-equal publishers never compare equal.",
+ strict=True,
+ )
+ def test_value_equal_publishers_compare_equal(self):
+ assert make_publisher(name="x") == make_publisher(name="x")
+
+
+class TestFilteredPublisher:
+ def test_no_filter_exposes_all_sources(self):
+ publisher = make_publisher(sources=[NewsMap("https://x/n"), Sitemap("https://x/s")])
+ filtered = FilteredPublisher.from_publisher(publisher)
+ assert filtered.source_mapping == publisher.source_mapping
+
+ def test_narrows_by_source_type(self):
+ publisher = make_publisher(sources=[NewsMap("https://x/n"), Sitemap("https://x/s")])
+ filtered = FilteredPublisher.from_publisher(publisher, source_types={NewsMap})
+ assert set(filtered.source_mapping) == {NewsMap}
+
+ def test_narrows_by_language(self):
+ publisher = make_publisher(
+ sources=[RSSFeed("https://x/r", languages={"es"}), Sitemap("https://x/s", languages={"de"})]
+ )
+ filtered = FilteredPublisher.from_publisher(publisher, languages={"es"})
+ assert set(filtered.source_mapping) == {RSSFeed} # German Sitemap dropped
+
+ def test_combines_source_type_and_language_filters(self):
+ publisher = make_publisher(
+ sources=[
+ RSSFeed("https://x/r", languages={"es"}),
+ NewsMap("https://x/n", languages={"es"}),
+ Sitemap("https://x/s", languages={"es"}),
+ ]
+ )
+ filtered = FilteredPublisher.from_publisher(publisher, source_types={NewsMap, Sitemap}, languages={"es"})
+ assert set(filtered.source_mapping) == {NewsMap, Sitemap} # RSSFeed excluded by source-type filter
+
+ def test_language_filter_is_exposed(self):
+ filtered = FilteredPublisher.from_publisher(make_publisher(), languages={"es"})
+ assert filtered.language_filter == {"es"}
+
+ def test_carries_over_publisher_identity(self):
+ filtered = FilteredPublisher.from_publisher(make_publisher(name="orig"))
+ assert filtered.name == "orig"
diff --git a/tests/publishers/test_init.py b/tests/publishers/test_init.py
new file mode 100644
index 000000000..808b22748
--- /dev/null
+++ b/tests/publishers/test_init.py
@@ -0,0 +1,57 @@
+from typing import Set
+
+import more_itertools
+import pytest
+
+from fundus.publishers import Publisher, PublisherCollectionMeta, PublisherGroup
+from tests.fixtures.builders import make_publisher, make_publisher_group
+from tests.resources import __module_path__ as resources_path
+from tests.utility import get_meta_info_file
+
+# ISO 639-1 codes, frozen from the "List of ISO 639 language codes" Wikipedia table
+# (en.wikipedia.org/wiki/List_of_ISO_639_language_codes, table id="Table", td/@id) so the
+# data-hygiene checks below run offline. If a new publisher needs a code not listed here,
+# re-snapshot that table into iso_639_codes.txt.
+language_codes: Set[str] = set((resources_path / "iso_639_codes.txt").read_text(encoding="utf-8").split())
+
+
+class TestPublisherCollection:
+ def test_default_language(self, region: PublisherGroup):
+ assert hasattr(region, "default_language"), f"Region {region.__name__!r} has no default language set"
+
+ default_language = getattr(region, "default_language")
+
+ assert default_language in language_codes, (
+ f"Default language {default_language!r} isn't a ISO 639 language code"
+ )
+
+ def test_source_languages(self, publisher: Publisher):
+ for source in more_itertools.flatten(publisher.source_mapping.values()):
+ assert source.languages.issubset(language_codes)
+
+
+class TestPublisherCollectionMeta:
+ def test_rejects_duplicate_publisher_across_subgroups(self):
+ with pytest.raises(ValueError):
+ PublisherCollectionMeta(
+ "C",
+ (),
+ {"a": make_publisher_group(Foo=make_publisher()), "b": make_publisher_group(Foo=make_publisher())},
+ )
+
+ def test_rejects_non_publisher_attribute(self):
+ with pytest.raises(TypeError):
+ PublisherCollectionMeta("C", (), {"x": "not a publisher"})
+
+ def test_accepts_valid_collection(self):
+ PublisherCollectionMeta("C", (), {"a": make_publisher_group(Foo=make_publisher())}) # must not raise
+
+
+class TestMetaInfo:
+ def test_order(self, region: PublisherGroup):
+ meta_file = get_meta_info_file(region)
+ meta_info = meta_file.load()
+ assert meta_info, f"Meta info file {meta_file.path} is missing"
+ assert sorted(meta_info.keys()) == list(meta_info.keys()), (
+ f"Meta info file {meta_file.path} isn't ordered properly."
+ )
diff --git a/tests/publishers/test_parser_coverage.py b/tests/publishers/test_parser_coverage.py
new file mode 100644
index 000000000..e0a295e27
--- /dev/null
+++ b/tests/publishers/test_parser_coverage.py
@@ -0,0 +1,114 @@
+import pickle
+from typing import List
+
+import lxml.html
+
+from fundus import PublisherCollection
+from fundus.parser.base_parser import Attribute
+from fundus.publishers import Publisher
+from tests.resources import attribute_annotations_mapping
+from tests.utility import (
+ load_html_test_file_mapping,
+ load_supported_publishers_markdown,
+ load_test_case_data,
+)
+
+
+def test_supported_publishers_table():
+ root = lxml.html.fromstring(load_supported_publishers_markdown())
+ parsed_names: List[str] = root.xpath("//table[contains(@class,'publishers')]//td[1]/code/text()")
+ for publisher in PublisherCollection:
+ assert publisher.__name__ in parsed_names, (
+ f"Publisher {publisher.name} is not included in docs/supported_news.md. "
+ f"Run 'python -m scripts.generate_tables'"
+ )
+
+
+# enforce test coverage for test parsing
+# because this is also used for the generate_parser_test_files script we export it here
+attributes_required_to_cover = {"title", "authors", "topics", "publishing_date", "body", "images"}
+
+attributes_parsers_are_required_to_cover = {"body"}
+
+
+class TestPublisherParsers:
+ def test_annotations(self, publisher: Publisher) -> None:
+ parser_proxy = publisher.parser
+ for versioned_parser in parser_proxy:
+ assert attributes_parsers_are_required_to_cover.issubset(
+ set(versioned_parser.attributes().validated.names)
+ ), f"{versioned_parser.__name__!r} should implement at least {attributes_parsers_are_required_to_cover!r}"
+ for attr in versioned_parser.attributes().validated:
+ annotation = attribute_annotations_mapping.get(attr.__name__)
+ assert annotation, (
+ f"Attribute {attr.__name__!r} has no registered annotation in attribute_annotations_mapping"
+ )
+ assert attr.__annotations__.get("return") == annotation, (
+ f"Attribute {attr.__name__!r} for {versioned_parser.__name__!r} is of wrong type. "
+ f"{attr.__annotations__.get('return')} != {annotation}"
+ )
+
+ def test_test_data_wellformed(self, publisher: Publisher) -> None:
+ """Validate the test fixture: it exists, is complete, and matches the parser's required attributes."""
+ comparative_data = load_test_case_data(publisher)
+ html_mapping = load_html_test_file_mapping(publisher)
+
+ for versioned_parser in publisher.parser:
+ version_name = versioned_parser.__name__
+
+ assert (version_data := comparative_data.get(version_name)), (
+ f"Missing test data for parser version {version_name!r}"
+ )
+ assert (html := html_mapping.get(versioned_parser)), (
+ f"Missing test HTML for parser version {version_name} of publisher {publisher.name}"
+ )
+
+ # only complete articles should be used as test cases
+ for key, value in version_data.items():
+ assert value, (
+ f"There is no value set for key {key!r} in the test JSON. "
+ f"Only complete articles should be used as test cases"
+ )
+
+ # the fixture must cover the parser's required attributes, in alphabetical order;
+ # re-instantiate parser to address deprecated attributes
+ timestamp_instantiated_parser = publisher.parser(html.crawl_date)
+ supported_attrs = set(timestamp_instantiated_parser.registered_attributes.names)
+ missing_attrs = attributes_required_to_cover & supported_attrs - set(version_data.keys())
+ assert not missing_attrs, (
+ f"Test JSON for {version_name} of publisher {publisher.name} does not cover the following attribute(s): {missing_attrs}"
+ )
+ assert list(version_data.keys()) == sorted(attributes_required_to_cover & supported_attrs), (
+ f"Test JSON for {version_name} is not in alphabetical order"
+ )
+
+ def test_extraction_matches(self, publisher: Publisher) -> None:
+ """Validate the parser: its extraction matches the expected fixture data and is picklable."""
+ comparative_data = load_test_case_data(publisher)
+ html_mapping = load_html_test_file_mapping(publisher)
+
+ for versioned_parser in publisher.parser:
+ version_name = versioned_parser.__name__
+
+ assert (version_data := comparative_data.get(version_name)), (
+ f"Missing test data for parser version {version_name!r}"
+ )
+ assert (html := html_mapping.get(versioned_parser)), (
+ f"Missing test HTML for parser version {version_name} of publisher {publisher.name}"
+ )
+
+ # re-instantiate parser to address deprecated attributes
+ timestamp_instantiated_parser = publisher.parser(html.crawl_date)
+
+ extraction = timestamp_instantiated_parser.parse(html.content)
+ for key, value in version_data.items():
+ assert value == extraction[key], f"{key!r} is not equal"
+
+ # check if extraction is pickable
+ pickle.dumps(extraction)
+
+ def test_reserved_attribute_names(self, publisher: Publisher):
+ parser = publisher.parser
+ for attr in attribute_annotations_mapping.keys():
+ if value := getattr(parser, attr, None):
+ assert isinstance(value, Attribute), f"The name {attr!r} is reserved for attributes only."
diff --git a/tests/resources/__init__.py b/tests/resources/__init__.py
index 002944b8c..c7dab3acd 100644
--- a/tests/resources/__init__.py
+++ b/tests/resources/__init__.py
@@ -1,3 +1,7 @@
+import pathlib
+
from tests.resources.parser.attribute_annotations import attribute_annotations_mapping
-__all__ = ["attribute_annotations_mapping"]
+__module_path__ = pathlib.Path(__file__).parent
+
+__all__ = ["attribute_annotations_mapping", "__module_path__"]
diff --git a/tests/resources/iso_639_codes.txt b/tests/resources/iso_639_codes.txt
new file mode 100644
index 000000000..fa25ed473
--- /dev/null
+++ b/tests/resources/iso_639_codes.txt
@@ -0,0 +1,182 @@
+aa
+ab
+ae
+af
+ak
+am
+an
+ar
+as
+av
+ay
+az
+ba
+be
+bg
+bi
+bm
+bn
+bo
+br
+bs
+ca
+ce
+ch
+co
+cr
+cs
+cu
+cv
+cy
+da
+de
+dv
+dz
+ee
+el
+en
+eo
+es
+et
+eu
+fa
+ff
+fi
+fj
+fo
+fr
+fy
+ga
+gd
+gl
+gn
+gu
+gv
+ha
+he
+hi
+ho
+hr
+ht
+hu
+hy
+hz
+ia
+id
+ie
+ig
+ik
+io
+is
+it
+iu
+ja
+jv
+ka
+kg
+ki
+kj
+kk
+kl
+km
+kn
+ko
+kr
+ks
+ku
+kv
+kw
+ky
+la
+lb
+lg
+li
+ln
+lo
+lt
+lu
+lv
+mg
+mh
+mi
+mk
+ml
+mn
+mr
+ms
+mt
+my
+na
+nb
+nd
+ne
+ng
+nl
+nn
+no
+nr
+nv
+ny
+oc
+oj
+om
+or
+os
+pa
+pi
+pl
+ps
+pt
+qu
+rm
+rn
+ro
+ru
+rw
+sa
+sc
+sd
+se
+sg
+si
+sk
+sl
+sm
+sn
+so
+sq
+sr
+ss
+st
+su
+sv
+sw
+ta
+te
+tg
+th
+ti
+tk
+tl
+tn
+to
+tr
+ts
+tt
+tw
+ty
+ug
+uk
+ur
+uz
+ve
+vi
+vo
+wa
+wo
+xh
+yi
+yo
+za
+zh
+zu
diff --git a/tests/scraping/__init__.py b/tests/scraping/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/scraping/crawler/__init__.py b/tests/scraping/crawler/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/scraping/crawler/conftest.py b/tests/scraping/crawler/conftest.py
new file mode 100644
index 000000000..fd3103e08
--- /dev/null
+++ b/tests/scraping/crawler/conftest.py
@@ -0,0 +1,16 @@
+from typing import Iterator
+
+import pytest
+
+from fundus.utils.events import __EVENTS__, __MAIN_THREAD_ALIAS__
+
+
+@pytest.fixture
+def main_thread_context() -> Iterator[None]:
+ """Register the main-thread alias, mirroring the context ``CrawlerBase.crawl`` sets up.
+
+ Without this context, tests that probe events hit an unregistered alias and
+ raise ``KeyError``.
+ """
+ with __EVENTS__.main_context(__MAIN_THREAD_ALIAS__):
+ yield
diff --git a/tests/scraping/crawler/test_base.py b/tests/scraping/crawler/test_base.py
new file mode 100644
index 000000000..fd47e0f5a
--- /dev/null
+++ b/tests/scraping/crawler/test_base.py
@@ -0,0 +1,109 @@
+from __future__ import annotations
+
+import json
+
+from fundus.scraping.crawler.base import CrawlerBase, _CrawlState
+from fundus.scraping.filter import Requires, RequiresAll
+from tests.fixtures.builders import make_article, make_html
+from tests.fixtures.fakes import FakeCrawler
+
+
+class TestCrawlState:
+ def test_accept_new_url_returns_true(self):
+ state = _CrawlState(only_unique=True, track_articles=False)
+ article = make_article(html=make_html(requested_url="https://example.com/1", publisher="pub"))
+ assert state.accept(article) is True
+
+ def test_accept_duplicate_returns_false(self):
+ state = _CrawlState(only_unique=True, track_articles=False)
+ article = make_article(html=make_html(requested_url="https://example.com/1", publisher="pub"))
+ state.accept(article)
+ assert state.accept(article) is False
+
+ def test_allows_duplicate_when_not_unique(self):
+ state = _CrawlState(only_unique=False, track_articles=False)
+ article = make_article(html=make_html(requested_url="https://example.com/1", publisher="pub"))
+ assert state.accept(article) is True
+ assert state.accept(article) is True
+
+ def test_counts_per_publisher_and_total(self):
+ state = _CrawlState(only_unique=False, track_articles=False)
+ state.accept(make_article(html=make_html(requested_url="https://example.com/1", publisher="pub_a")))
+ state.accept(make_article(html=make_html(requested_url="https://example.com/2", publisher="pub_a")))
+ state.accept(make_article(html=make_html(requested_url="https://example.com/3", publisher="pub_b")))
+ assert state.total_count == 3
+ assert state.article_count["pub_a"] == 2
+ assert state.article_count["pub_b"] == 1
+
+ def test_tracks_articles_when_enabled(self):
+ state = _CrawlState(only_unique=False, track_articles=True)
+ article = make_article(html=make_html(requested_url="https://example.com/1", publisher="pub"))
+ state.accept(article)
+ assert article in state.crawled_articles["pub"]
+
+ def test_does_not_track_when_disabled(self):
+ state = _CrawlState(only_unique=False, track_articles=False)
+ state.accept(make_article(html=make_html(requested_url="https://example.com/1", publisher="pub")))
+ assert len(state.crawled_articles) == 0
+
+
+class TestBuildExtractionFilter:
+ def test_false_returns_none(self):
+ assert CrawlerBase._build_extraction_filter(False) is None
+
+ def test_true_returns_requires_all(self):
+ assert isinstance(CrawlerBase._build_extraction_filter(True), RequiresAll)
+
+ def test_filter_passed_through(self):
+ f = Requires("title")
+ assert CrawlerBase._build_extraction_filter(f) is f
+
+
+class TestFilterPublishers:
+ def test_no_extraction_filter_returns_all_publishers(self, publisher_group_with_news_map):
+ crawler = FakeCrawler(publisher_group_with_news_map)
+ result = crawler._filter_publishers(extraction_filter=None, language_filter=None)
+ assert len(result) == len(crawler.publishers)
+
+ def test_unsupported_attribute_filters_out_publisher(self, publisher_group_with_news_map):
+ crawler = FakeCrawler(publisher_group_with_news_map)
+ # parser_proxy_with_version has no @attribute methods, so any Requires removes it
+ result = crawler._filter_publishers(extraction_filter=Requires("nonexistent_attr_xyz"), language_filter=None)
+ assert result == []
+
+
+class TestCrawl:
+ def test_max_articles_zero_yields_nothing(self, publisher_group_with_news_map):
+ articles = [make_article(html=make_html(requested_url=f"https://example.com/{i}")) for i in range(5)]
+ crawler = FakeCrawler(publisher_group_with_news_map, articles=articles)
+ assert list(crawler.crawl(max_articles=0, only_complete=False)) == []
+
+ def test_max_articles_limits_output(self, publisher_group_with_news_map):
+ articles = [make_article(html=make_html(requested_url=f"https://example.com/{i}")) for i in range(10)]
+ crawler = FakeCrawler(publisher_group_with_news_map, articles=articles)
+ assert len(list(crawler.crawl(max_articles=3, only_complete=False))) == 3
+
+ def test_only_unique_deduplicates_by_url(self, publisher_group_with_news_map):
+ articles = [make_article(html=make_html(requested_url="https://example.com/same"))] * 3
+ crawler = FakeCrawler(publisher_group_with_news_map, articles=articles)
+ assert len(list(crawler.crawl(only_complete=False, only_unique=True))) == 1
+
+ def test_not_unique_passes_duplicates(self, publisher_group_with_news_map):
+ articles = [make_article(html=make_html(requested_url="https://example.com/same"))] * 3
+ crawler = FakeCrawler(publisher_group_with_news_map, articles=articles)
+ assert len(list(crawler.crawl(only_complete=False, only_unique=False))) == 3
+
+ def test_max_articles_per_publisher(self, publisher_group_with_news_map):
+ articles = [make_article(html=make_html(requested_url=f"https://example.com/{i}")) for i in range(5)]
+ crawler = FakeCrawler(publisher_group_with_news_map, articles=articles)
+ result = list(crawler.crawl(max_articles_per_publisher=2, only_complete=False))
+ assert len(result) == 2
+
+ def test_save_to_file_writes_json(self, publisher_group_with_news_map, tmp_path):
+ articles = [make_article(html=make_html(requested_url="https://example.com/1"))]
+ crawler = FakeCrawler(publisher_group_with_news_map, articles=articles)
+ path = tmp_path / "out.json"
+ list(crawler.crawl(only_complete=False, save_to_file=str(path)))
+ assert path.exists()
+ data = json.loads(path.read_text())
+ assert "test_pub" in data
diff --git a/tests/scraping/crawler/test_ccnews.py b/tests/scraping/crawler/test_ccnews.py
new file mode 100644
index 000000000..56206bbd4
--- /dev/null
+++ b/tests/scraping/crawler/test_ccnews.py
@@ -0,0 +1,149 @@
+from __future__ import annotations
+
+import gzip
+from datetime import datetime
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from fundus.scraping.crawler import CCNewsCrawler
+from fundus.scraping.pipeline.source.ccnews import WarcFileLoadError
+from tests.fixtures.builders import make_article, make_html
+
+
+class TestCCNewsCrawlerInit:
+ def test_raises_when_start_equals_end(self, publisher_group_with_news_map):
+ date = datetime(2020, 1, 1)
+ with pytest.raises(ValueError, match="Start date has to be < end date"):
+ CCNewsCrawler(publisher_group_with_news_map, start=date, end=date)
+
+ def test_raises_when_start_after_end(self, publisher_group_with_news_map):
+ with pytest.raises(ValueError, match="Start date has to be < end date"):
+ CCNewsCrawler(
+ publisher_group_with_news_map,
+ start=datetime(2021, 1, 1),
+ end=datetime(2020, 1, 1),
+ )
+
+ def test_raises_when_start_before_minimum(self, publisher_group_with_news_map):
+ with pytest.raises(ValueError, match="2016/08/01"):
+ CCNewsCrawler(publisher_group_with_news_map, start=datetime(2016, 7, 31))
+
+ def test_raises_when_end_in_future(self, publisher_group_with_news_map):
+ with pytest.raises(ValueError, match="future"):
+ CCNewsCrawler(publisher_group_with_news_map, end=datetime(2099, 1, 1))
+
+ def test_default_end_is_evaluated_at_construction_time(self, publisher_group_with_news_map):
+ crawler = CCNewsCrawler(publisher_group_with_news_map, start=datetime(2020, 1, 1))
+ assert crawler.end <= datetime.now()
+
+
+class TestFetchArticles:
+ def test_retries_on_warc_file_load_error(self, publisher_group_with_news_map):
+ crawler = CCNewsCrawler(
+ publisher_group_with_news_map,
+ start=datetime(2020, 1, 1),
+ end=datetime(2021, 1, 1),
+ retries=2,
+ )
+ publishers = tuple(crawler.publishers)
+
+ with patch("fundus.scraping.crawler.ccnews.Pipeline") as MockPipeline, patch("time.sleep"):
+ MockPipeline.return_value.run.side_effect = WarcFileLoadError("test")
+ list(crawler._fetch_articles("fake/path.warc.gz", publishers, False))
+
+ assert MockPipeline.call_count == 3 # initial attempt + 2 retries
+
+ def test_stops_immediately_on_success(self, publisher_group_with_news_map):
+ crawler = CCNewsCrawler(
+ publisher_group_with_news_map,
+ start=datetime(2020, 1, 1),
+ end=datetime(2021, 1, 1),
+ retries=3,
+ )
+ publishers = tuple(crawler.publishers)
+ fake = make_article(html=make_html(requested_url="https://example.com/1"))
+
+ with patch("fundus.scraping.crawler.ccnews.Pipeline") as MockPipeline:
+ MockPipeline.return_value.run.return_value = iter([fake])
+ result = list(crawler._fetch_articles("fake/path.warc.gz", publishers, False))
+
+ assert result == [fake]
+ assert MockPipeline.call_count == 1
+
+
+class TestGetWarcPaths:
+ def test_filters_paths_by_date_range(self, publisher_group_with_news_map):
+ # single month so requests.Session.get is called exactly once
+ crawler = CCNewsCrawler(
+ publisher_group_with_news_map,
+ start=datetime(2020, 6, 1),
+ end=datetime(2020, 6, 30),
+ processes=0,
+ )
+ paths = [
+ "crawl-data/CC-NEWS/2020/06/CC-NEWS-20200615000000-00001.warc.gz", # in range
+ "crawl-data/CC-NEWS/2020/05/CC-NEWS-20200531000000-00001.warc.gz", # before start
+ "crawl-data/CC-NEWS/2020/07/CC-NEWS-20200701000000-00001.warc.gz", # after end
+ ]
+ mock_response = MagicMock()
+ mock_response.content = gzip.compress("\n".join(paths).encode())
+
+ with patch("requests.Session.get", return_value=mock_response):
+ result = crawler._get_warc_paths()
+
+ assert len(result) == 1
+ assert result[0] == f"{crawler.server_address}{paths[0]}"
+
+ def test_results_sorted_newest_first(self, publisher_group_with_news_map):
+ crawler = CCNewsCrawler(
+ publisher_group_with_news_map,
+ start=datetime(2020, 6, 1),
+ end=datetime(2020, 6, 30),
+ processes=0,
+ )
+ paths = [
+ "crawl-data/CC-NEWS/2020/06/CC-NEWS-20200610000000-00001.warc.gz",
+ "crawl-data/CC-NEWS/2020/06/CC-NEWS-20200620000000-00001.warc.gz",
+ "crawl-data/CC-NEWS/2020/06/CC-NEWS-20200601000000-00001.warc.gz",
+ ]
+ mock_response = MagicMock()
+ mock_response.content = gzip.compress("\n".join(paths).encode())
+
+ with patch("requests.Session.get", return_value=mock_response):
+ result = crawler._get_warc_paths()
+
+ assert "20200620" in result[0]
+ assert "20200601" in result[-1]
+
+
+class TestDispatch:
+ def test_single_crawl_when_processes_zero(self, publisher_group_with_news_map):
+ crawler = CCNewsCrawler(
+ publisher_group_with_news_map,
+ start=datetime(2020, 1, 1),
+ end=datetime(2021, 1, 1),
+ processes=0,
+ )
+ with patch.object(crawler, "_get_warc_paths", return_value=["path1"]), patch.object(
+ crawler, "_single_crawl"
+ ) as mock_single, patch("fundus.scraping.crawler.ccnews.get_proxy_tqdm"):
+ mock_single.return_value = iter([])
+ list(crawler._build_article_iterator(tuple(crawler.publishers), False, None, None, None))
+
+ mock_single.assert_called_once()
+
+ def test_parallel_crawl_when_processes_nonzero(self, publisher_group_with_news_map):
+ crawler = CCNewsCrawler(
+ publisher_group_with_news_map,
+ start=datetime(2020, 1, 1),
+ end=datetime(2021, 1, 1),
+ processes=1,
+ )
+ with patch.object(crawler, "_get_warc_paths", return_value=["path1"]), patch.object(
+ crawler, "_parallel_crawl"
+ ) as mock_parallel, patch("fundus.scraping.crawler.ccnews.get_proxy_tqdm"):
+ mock_parallel.return_value = iter([])
+ list(crawler._build_article_iterator(tuple(crawler.publishers), False, None, None, None))
+
+ mock_parallel.assert_called_once()
diff --git a/tests/scraping/crawler/test_integration.py b/tests/scraping/crawler/test_integration.py
new file mode 100644
index 000000000..7a249fbee
--- /dev/null
+++ b/tests/scraping/crawler/test_integration.py
@@ -0,0 +1,100 @@
+from __future__ import annotations
+
+import time
+from datetime import datetime
+from typing import Iterator
+from unittest.mock import patch
+
+import pytest
+
+from fundus import Crawler
+from fundus.scraping.article import Article
+from fundus.scraping.crawler import CCNewsCrawler
+from fundus.scraping.html import HTML, SourceInfo
+from fundus.scraping.pipeline import Pipeline
+from fundus.utils.timeout import _interrupt_handler
+from tests.fixtures.builders import make_article, make_html
+from tests.fixtures.fakes import FakeCrawler
+
+
+def _parallel_task(warc_path: str) -> Iterator[Article]:
+ yield Article(
+ html=HTML(
+ requested_url=f"https://example.com/{warc_path}",
+ responded_url=f"https://example.com/{warc_path}",
+ content="",
+ crawl_date=datetime(2020, 1, 1),
+ source_info=SourceInfo(publisher="test_pub"),
+ )
+ )
+
+
+@pytest.mark.integration
+class TestCrawlerThreadedIntegration:
+ def test_articles_flow_through_thread_pool(self, publisher_group_with_news_map):
+ fake_articles = [make_article(html=make_html(requested_url=f"https://example.com/{i}")) for i in range(3)]
+ crawler = Crawler(publisher_group_with_news_map, threading=True, ignore_robots=True)
+
+ def mock_run(self, *args, **kwargs):
+ yield from fake_articles
+
+ with patch.object(Pipeline, "run", mock_run):
+ result = list(crawler.crawl(max_articles=3, only_complete=False))
+
+ assert len(result) == 3
+
+
+@pytest.mark.integration
+class TestCCNewsCrawlerIntegration:
+ def test_single_process_full_pipeline(self, publisher_group_with_news_map):
+ crawler = CCNewsCrawler(
+ publisher_group_with_news_map,
+ start=datetime(2020, 1, 1),
+ end=datetime(2021, 1, 1),
+ processes=0,
+ )
+ fake = make_article(html=make_html(requested_url="https://example.com/1"))
+
+ with patch.object(crawler, "_get_warc_paths", return_value=["fake.warc.gz"]), patch(
+ "fundus.scraping.crawler.ccnews.Pipeline"
+ ) as MockPipeline:
+ MockPipeline.return_value.run.return_value = iter([fake])
+ result = list(crawler.crawl(max_articles=1, only_complete=False))
+
+ assert len(result) == 1
+
+ def test_parallel_process_articles_flow_through_queue(self, publisher_group_with_news_map, main_thread_context):
+ crawler = CCNewsCrawler(
+ publisher_group_with_news_map,
+ start=datetime(2020, 1, 1),
+ end=datetime(2021, 1, 1),
+ processes=1,
+ )
+ # patch random_sleep in the main process so no sleep is added to the serialized task
+ with patch("fundus.scraping.crawler.ccnews.random_sleep", side_effect=lambda f, _: f):
+ result = list(crawler._parallel_crawl(("path1", "path2"), _parallel_task))
+
+ assert len(result) == 2
+
+
+@pytest.mark.integration
+class TestTimeoutIntegration:
+ def test_crawl_terminates_on_timeout(self, publisher_group_with_news_map):
+ timeout = 0.3
+
+ class TimeoutCrawler(FakeCrawler):
+ def _on_timeout(self) -> None:
+ _interrupt_handler()
+
+ def _build_article_iterator(self, *args, **kwargs) -> Iterator[Article]:
+ yield make_article(html=make_html(requested_url="https://example.com/1"))
+ while True:
+ time.sleep(0.001) # short sleeps so Windows processes the pending KeyboardInterrupt
+
+ crawler = TimeoutCrawler(publisher_group_with_news_map)
+ start = time.time()
+ result = list(crawler.crawl(timeout=timeout, only_complete=False))
+ elapsed = time.time() - start
+
+ assert len(result) == 1
+ assert elapsed < timeout + 0.3
diff --git a/tests/scraping/crawler/test_queueing.py b/tests/scraping/crawler/test_queueing.py
new file mode 100644
index 000000000..03933a3d2
--- /dev/null
+++ b/tests/scraping/crawler/test_queueing.py
@@ -0,0 +1,59 @@
+import multiprocessing
+from queue import Queue
+from typing import Union
+from unittest.mock import MagicMock
+
+import pytest
+
+from fundus.scraping.crawler.queueing import (
+ RemoteException,
+ iter_pool_results,
+)
+from fundus.utils.events import __EVENTS__, __MAIN_THREAD_ALIAS__
+
+
+def _never_ready_handle() -> MagicMock:
+ """Stand-in for a pool MapResult whose jobs never finish: get() always times out."""
+ handle = MagicMock()
+ handle.get.side_effect = multiprocessing.TimeoutError
+ return handle
+
+
+def _ready_handle() -> MagicMock:
+ """Stand-in for a finished pool MapResult: get() returns immediately."""
+ handle = MagicMock()
+ handle.get.return_value = None
+ return handle
+
+
+class TestPoolQueueIter:
+ def test_main_thread_stop_event_ends_iteration(self):
+ # Pool still running (handle never ready) and queue empty: setting the main-thread stop
+ # event must terminate the iterator instead of spinning forever waiting for results.
+ queue: Queue[Union[str, Exception]] = Queue()
+ with __EVENTS__.main_context(__MAIN_THREAD_ALIAS__):
+ __EVENTS__.set_event("stop", __MAIN_THREAD_ALIAS__)
+ assert list(iter_pool_results(_never_ready_handle(), queue)) == []
+
+ def test_stop_event_is_cleared_on_exit(self):
+ # Breaking on the stop event clears it, so a later crawl reusing the alias is not
+ # short-circuited by a stale flag.
+ queue: Queue[Union[str, Exception]] = Queue()
+ with __EVENTS__.main_context(__MAIN_THREAD_ALIAS__):
+ __EVENTS__.set_event("stop", __MAIN_THREAD_ALIAS__)
+ list(iter_pool_results(_never_ready_handle(), queue))
+ assert __EVENTS__.is_event_set("stop", __MAIN_THREAD_ALIAS__) is False
+
+ def test_drains_remaining_queue_when_pool_finished(self):
+ # Once the pool is done (handle ready), everything still buffered must be yielded.
+ queue: Queue[Union[str, Exception]] = Queue()
+ queue.put("a")
+ queue.put("b")
+ assert list(iter_pool_results(_ready_handle(), queue)) == ["a", "b"]
+
+ def test_raises_when_queue_yields_remote_exception(self):
+ # A RemoteException put on the queue by a worker is re-raised to the consumer.
+ queue: Queue[Exception] = Queue()
+ queue.put(RemoteException("boom"))
+ with pytest.raises(Exception, match="remote thread/process"):
+ list(iter_pool_results(_ready_handle(), queue))
diff --git a/tests/test_crawler.py b/tests/scraping/crawler/test_web.py
similarity index 51%
rename from tests/test_crawler.py
rename to tests/scraping/crawler/test_web.py
index 71fbb74bd..69c8befe4 100644
--- a/tests/test_crawler.py
+++ b/tests/scraping/crawler/test_web.py
@@ -1,8 +1,9 @@
import pytest
from fundus import Crawler, NewsMap, RSSFeed
-from fundus.publishers.base_objects import Publisher
-from fundus.scraping.html import WebSource
+from fundus.scraping.pipeline.source.web import WebSource
+from fundus.utils.events import __EVENTS__
+from tests.fixtures.builders import make_publisher
class TestPipeline:
@@ -60,32 +61,62 @@ def test_crawler_stores_impersonate_flag(self, group_with_valid_publisher_subgro
crawler = Crawler(group_with_valid_publisher_subgroup, impersonate=True)
assert crawler.impersonate is True
- def test_websource_disabled_drops_publisher_profile(self, parser_proxy_with_version):
- publisher = Publisher(
- name="impersonating",
- domain="https://test.com/",
- sources=[RSSFeed("https://test.com/feed")],
- parser=parser_proxy_with_version,
- impersonate="chrome",
- )
- source = WebSource(
- url_source=publisher.source_mapping[RSSFeed][0],
- publisher=publisher,
- impersonate=False,
- )
+ def test_websource_disabled_drops_publisher_profile(self):
+ publisher = make_publisher(impersonate="chrome")
+ source = WebSource(url_source=[], publisher=publisher, impersonate=False)
assert source._impersonate_profile is None
- def test_websource_enabled_uses_publisher_profile(self, parser_proxy_with_version):
- publisher = Publisher(
- name="impersonating",
- domain="https://test.com/",
- sources=[RSSFeed("https://test.com/feed")],
- parser=parser_proxy_with_version,
- impersonate="chrome",
- )
- source = WebSource(
- url_source=publisher.source_mapping[RSSFeed][0],
- publisher=publisher,
- impersonate=True,
- )
+ def test_websource_enabled_uses_publisher_profile(self):
+ publisher = make_publisher(impersonate="chrome")
+ source = WebSource(url_source=[], publisher=publisher, impersonate=True)
assert source._impersonate_profile == publisher.impersonate
+
+
+class TestCrawlerResolveDelay:
+ def test_none_returns_none(self):
+ assert Crawler._resolve_delay(None) is None
+
+ def test_float_returns_constant_callable(self):
+ delay = Crawler._resolve_delay(1.5)
+ assert callable(delay)
+ assert delay() == 1.5
+
+ def test_int_returns_constant_callable(self):
+ delay = Crawler._resolve_delay(2)
+ assert callable(delay)
+ assert delay() == 2
+
+ def test_callable_returned_as_is(self):
+ def fn() -> float:
+ return 0.5
+
+ assert Crawler._resolve_delay(fn) is fn
+
+ def test_invalid_type_raises(self):
+ with pytest.raises(TypeError):
+ Crawler._resolve_delay("1.0") # type: ignore[arg-type]
+
+
+class TestCrawlerBuildPipelines:
+ @pytest.fixture(autouse=True)
+ def _main_context(self):
+ # _build_pipelines constructs WebSource, which looks up __EVENTS__.get("stop")
+ # at construction time. Production always calls this inside main_context
+ # (threading mode adds a publisher context on top); mirror that here.
+ with __EVENTS__.main_context("test"):
+ yield
+
+ def test_returns_one_pipeline_per_source(self, publisher_group_with_news_map):
+ crawler = Crawler(publisher_group_with_news_map, ignore_robots=True)
+ pipelines = crawler._build_pipelines(crawler.publishers[0])
+ assert len(pipelines) == 1
+
+ def test_restrict_sources_excludes_non_matching_type(self, publisher_group_with_rss_feeds):
+ crawler = Crawler(publisher_group_with_rss_feeds, ignore_robots=True, restrict_sources_to=[NewsMap])
+ pipelines = crawler._build_pipelines(crawler.publishers[0])
+ assert pipelines == []
+
+ def test_no_restriction_includes_all_sources(self, publisher_group_with_rss_feeds):
+ crawler = Crawler(publisher_group_with_rss_feeds, ignore_robots=True)
+ pipelines = crawler._build_pipelines(crawler.publishers[0])
+ assert len(pipelines) == 1
diff --git a/tests/scraping/pipeline/__init__.py b/tests/scraping/pipeline/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/scraping/pipeline/source/__init__.py b/tests/scraping/pipeline/source/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/scraping/pipeline/source/test_ccnews.py b/tests/scraping/pipeline/source/test_ccnews.py
new file mode 100644
index 000000000..bb2ece727
--- /dev/null
+++ b/tests/scraping/pipeline/source/test_ccnews.py
@@ -0,0 +1,305 @@
+from __future__ import annotations
+
+from datetime import datetime
+from typing import Any, Dict, Iterable, List, Optional
+from unittest.mock import MagicMock, patch
+
+import pytest
+import requests
+import urllib3.exceptions
+from fastwarc.stream_io import StreamError
+
+from fundus.scraping.html import HTML
+from fundus.scraping.pipeline.source.ccnews import CCNewsSource, WarcFileLoadError, WarcSourceInfo
+from tests.fixtures.builders import stub_publisher
+
+# ---- helpers ---------------------------------------------------------------
+
+
+def make_warc_record(
+ target_url: str = "https://example.com/article",
+ body: bytes = b"hi",
+ http_charset: Optional[str] = "utf-8",
+ record_date: Optional[datetime] = datetime(2024, 1, 1),
+ record_id: str = "",
+ http_headers: Optional[Dict[str, str]] = None,
+ warc_headers: Optional[Dict[str, str]] = None,
+) -> MagicMock:
+ """Mock the fastwarc WarcRecord surface CCNewsSource reads."""
+ record = MagicMock()
+ record.reader.read.return_value = body
+ record.http_charset = http_charset
+ record.record_date = record_date
+ record.record_id = record_id
+ record.http_headers = http_headers or {"Content-Type": "text/html"}
+ headers = {"WARC-Target-URI": target_url}
+ if warc_headers:
+ headers.update(warc_headers)
+ record.headers = headers
+ return record
+
+
+def patch_archive_iterator(records: Iterable[Any]):
+ """Replace ArchiveIterator with a callable that returns the given records."""
+ return patch(
+ "fundus.scraping.pipeline.source.ccnews.ArchiveIterator",
+ return_value=iter(records),
+ )
+
+
+# ---- CCNewsSource.__init__ -------------------------------------------------
+
+
+class TestCCNewsSourceConstruction:
+ def test_stores_warc_path_and_publishers(self):
+ publisher = stub_publisher()
+ source = CCNewsSource(publisher, warc_path="https://commoncrawl.org/a.warc.gz")
+ assert source.warc_path == "https://commoncrawl.org/a.warc.gz"
+ assert source.publishers == (publisher,)
+
+ def test_default_headers_when_none(self):
+ from fundus.scraping.session import _default_header
+
+ source = CCNewsSource(stub_publisher(), warc_path="x")
+ assert source.headers == _default_header
+
+ def test_custom_headers_override_default(self):
+ headers = {"user-agent": "custom-agent"}
+ source = CCNewsSource(stub_publisher(), warc_path="x", headers=headers)
+ assert source.headers == headers
+
+ def test_publisher_mapping_keyed_by_netloc(self):
+ pub_a = stub_publisher(name="a", domain="https://a.example.com/")
+ pub_b = stub_publisher(name="b", domain="https://b.example.com/path")
+ source = CCNewsSource(pub_a, pub_b, warc_path="x")
+ assert source._publisher_mapping == {"a.example.com": pub_a, "b.example.com": pub_b}
+
+ def test_empty_publishers_yields_empty_mapping(self):
+ source = CCNewsSource(warc_path="x")
+ assert source._publisher_mapping == {}
+
+
+# ---- CCNewsSource._open_stream --------------------------------------------
+
+
+class TestOpenStream:
+ def test_returns_response_on_success(self):
+ source = CCNewsSource(stub_publisher(), warc_path="https://host/a.warc.gz")
+ response = MagicMock()
+ response.raise_for_status.return_value = None
+ with patch("requests.Session.get", return_value=response) as mock_get:
+ assert source._open_stream() is response
+ mock_get.assert_called_once_with("https://host/a.warc.gz", stream=True, headers=source.headers)
+
+ def test_wraps_http_error_as_warc_file_load_error(self):
+ source = CCNewsSource(stub_publisher(), warc_path="x")
+ response = MagicMock()
+ response.raise_for_status.side_effect = requests.HTTPError("404")
+ with patch("requests.Session.get", return_value=response):
+ with pytest.raises(WarcFileLoadError, match="404"):
+ source._open_stream()
+
+ def test_wraps_urllib3_error_as_warc_file_load_error(self):
+ source = CCNewsSource(stub_publisher(), warc_path="x")
+ with patch(
+ "requests.Session.get",
+ side_effect=urllib3.exceptions.ProtocolError("conn reset"),
+ ):
+ with pytest.raises(WarcFileLoadError, match="conn reset"):
+ source._open_stream()
+
+
+# ---- CCNewsSource._extract_content ----------------------------------------
+
+
+class TestExtractContent:
+ def test_decodes_with_declared_charset(self):
+ record = make_warc_record(body="héllo".encode("utf-8"), http_charset="utf-8")
+ assert CCNewsSource._extract_content(record, "https://x/a") == "héllo"
+
+ def test_falls_back_to_chardet_when_declared_charset_fails(self):
+ body = "héllo".encode("latin-1")
+ record = make_warc_record(body=body, http_charset="utf-8")
+ with patch(
+ "fundus.scraping.pipeline.source.ccnews.chardet.detect",
+ return_value={"encoding": "latin-1"},
+ ):
+ assert CCNewsSource._extract_content(record, "https://x/a") == "héllo"
+
+ def test_falls_back_to_chardet_when_charset_missing(self):
+ body = "ok".encode("utf-8")
+ record = make_warc_record(body=body, http_charset=None)
+ with patch(
+ "fundus.scraping.pipeline.source.ccnews.chardet.detect",
+ return_value={"encoding": "utf-8"},
+ ):
+ assert CCNewsSource._extract_content(record, "https://x/a") == "ok"
+
+ def test_returns_none_when_chardet_detects_nothing(self):
+ record = make_warc_record(body=b"\xff\xfe", http_charset=None)
+ with patch(
+ "fundus.scraping.pipeline.source.ccnews.chardet.detect",
+ return_value={"encoding": None},
+ ):
+ assert CCNewsSource._extract_content(record, "https://x/a") is None
+
+ def test_returns_none_when_chardet_encoding_still_fails(self):
+ body = b"\xff\xfe\xfa"
+ record = make_warc_record(body=body, http_charset="utf-8")
+ with patch(
+ "fundus.scraping.pipeline.source.ccnews.chardet.detect",
+ return_value={"encoding": "ascii"},
+ ):
+ assert CCNewsSource._extract_content(record, "https://x/a") is None
+
+
+# ---- CCNewsSource._validate -----------------------------------------------
+
+
+class TestValidate:
+ def test_returns_publisher_on_happy_path(self):
+ publisher = stub_publisher(domain="https://example.com/")
+ source = CCNewsSource(publisher, warc_path="x")
+ assert source._validate("https://example.com/article", None) is publisher
+
+ def test_returns_none_when_url_filter_blocks(self):
+ publisher = stub_publisher(domain="https://example.com/")
+ source = CCNewsSource(publisher, warc_path="x")
+ assert source._validate("https://example.com/skip", lambda u: "skip" in u) is None
+
+ def test_returns_none_for_unknown_publisher_netloc(self):
+ publisher = stub_publisher(domain="https://known.example.com/")
+ source = CCNewsSource(publisher, warc_path="x")
+ assert source._validate("https://unknown.example.com/article", None) is None
+
+ def test_returns_none_when_publisher_url_filter_blocks(self):
+ publisher = stub_publisher(
+ domain="https://example.com/",
+ url_filter=lambda u: "drop" in u,
+ )
+ source = CCNewsSource(publisher, warc_path="x")
+ assert source._validate("https://example.com/drop-this", None) is None
+
+ def test_publisher_url_filter_does_not_block_when_publisher_filter_is_none(self):
+ publisher = stub_publisher(domain="https://example.com/", url_filter=None)
+ source = CCNewsSource(publisher, warc_path="x")
+ assert source._validate("https://example.com/any", None) is publisher
+
+ def test_url_filter_none_passes_through(self):
+ publisher = stub_publisher(domain="https://example.com/")
+ source = CCNewsSource(publisher, warc_path="x")
+ assert source._validate("https://example.com/article", None) is publisher
+
+
+# ---- CCNewsSource._record_to_html -----------------------------------------
+
+
+class TestRecordToHtml:
+ def test_returns_html_on_happy_path(self):
+ publisher = stub_publisher(name="p", domain="https://example.com/")
+ source = CCNewsSource(publisher, warc_path="https://host/a.warc.gz")
+ record = make_warc_record(target_url="https://example.com/article", body=b"hi")
+ html = source._record_to_html(record, url_filter=None)
+ assert isinstance(html, HTML)
+ assert html.requested_url == "https://example.com/article"
+ assert html.responded_url == "https://example.com/article"
+ assert html.content == "hi"
+ assert html.crawl_date == datetime(2024, 1, 1)
+
+ def test_html_carries_warc_source_info(self):
+ publisher = stub_publisher(name="p", domain="https://example.com/")
+ source = CCNewsSource(publisher, warc_path="https://host/a.warc.gz")
+ record = make_warc_record(
+ target_url="https://example.com/article",
+ http_headers={"Content-Type": "text/html"},
+ )
+ html = source._record_to_html(record, url_filter=None)
+ assert html is not None
+ info = html.source_info
+ assert isinstance(info, WarcSourceInfo)
+ assert info.publisher == publisher.name
+ assert info.warc_path == "https://host/a.warc.gz"
+ assert info.warc_headers == {"WARC-Target-URI": "https://example.com/article"}
+ assert info.http_headers == {"Content-Type": "text/html"}
+
+ def test_returns_none_when_record_date_is_none(self):
+ publisher = stub_publisher(domain="https://example.com/")
+ source = CCNewsSource(publisher, warc_path="x")
+ record = make_warc_record(target_url="https://example.com/article", record_date=None)
+ assert source._record_to_html(record, url_filter=None) is None
+
+ def test_returns_none_when_validate_rejects(self):
+ # Validate paths are covered exhaustively in TestValidate; here we only check that
+ # _record_to_html threads the rejection through to a None return.
+ publisher = stub_publisher(domain="https://known.example.com/")
+ source = CCNewsSource(publisher, warc_path="x")
+ record = make_warc_record(target_url="https://unknown.example.com/article")
+ assert source._record_to_html(record, url_filter=None) is None
+
+ def test_returns_none_when_content_cannot_be_decoded(self):
+ publisher = stub_publisher(domain="https://example.com/")
+ source = CCNewsSource(publisher, warc_path="x")
+ record = make_warc_record(target_url="https://example.com/a")
+ with patch.object(CCNewsSource, "_extract_content", return_value=None):
+ assert source._record_to_html(record, url_filter=None) is None
+
+
+# ---- CCNewsSource._iter_warc_records --------------------------------------
+
+
+class TestIterWarcRecords:
+ def test_yields_records_from_archive_iterator(self):
+ records = [make_warc_record(target_url=f"https://example.com/{i}") for i in range(3)]
+ with patch_archive_iterator(records):
+ result = list(CCNewsSource._iter_warc_records(MagicMock()))
+ assert result == records
+
+ def test_wraps_stream_error_as_warc_file_load_error(self):
+ with patch(
+ "fundus.scraping.pipeline.source.ccnews.ArchiveIterator",
+ side_effect=StreamError("corrupt"),
+ ):
+ with pytest.raises(WarcFileLoadError, match="corrupt"):
+ list(CCNewsSource._iter_warc_records(MagicMock()))
+
+
+# ---- CCNewsSource.fetch ---------------------------------------------------
+
+
+class TestFetch:
+ def test_pipes_open_stream_into_iter_warc_records(self):
+ publisher = stub_publisher(domain="https://example.com/")
+ source = CCNewsSource(publisher, warc_path="https://host/a.warc.gz")
+ record = make_warc_record(target_url="https://example.com/article")
+ sentinel_response = MagicMock()
+ with patch.object(source, "_open_stream", return_value=sentinel_response) as mock_open, patch_archive_iterator(
+ [record]
+ ) as mock_iter:
+ results = list(source.fetch())
+ assert len(results) == 1
+ mock_open.assert_called_once_with()
+ mock_iter.assert_called_once()
+ # ArchiveIterator should be called with the raw stream from _open_stream's response
+ assert mock_iter.call_args[0][0] is sentinel_response.raw
+
+ def test_passes_url_filter_through(self):
+ publisher = stub_publisher(domain="https://example.com/")
+ source = CCNewsSource(publisher, warc_path="x")
+ record = make_warc_record(target_url="https://example.com/skip-me")
+ with patch.object(source, "_open_stream", return_value=MagicMock()), patch_archive_iterator([record]):
+ assert list(source.fetch(url_filter=lambda u: "skip" in u)) == []
+
+ def test_processes_multiple_records_independently(self):
+ publisher = stub_publisher(domain="https://example.com/")
+ source = CCNewsSource(publisher, warc_path="x")
+ records: List[Any] = [
+ make_warc_record(target_url="https://example.com/a"),
+ make_warc_record(target_url="https://other.com/b"), # filtered out by netloc
+ make_warc_record(target_url="https://example.com/c"),
+ ]
+ with patch.object(source, "_open_stream", return_value=MagicMock()), patch_archive_iterator(records):
+ results = list(source.fetch())
+ assert [html.requested_url for html in results] == [
+ "https://example.com/a",
+ "https://example.com/c",
+ ]
diff --git a/tests/scraping/pipeline/source/test_web.py b/tests/scraping/pipeline/source/test_web.py
new file mode 100644
index 000000000..ad8045665
--- /dev/null
+++ b/tests/scraping/pipeline/source/test_web.py
@@ -0,0 +1,386 @@
+import threading
+from typing import List
+from unittest.mock import MagicMock, patch
+
+import pytest
+from curl_cffi.requests.exceptions import ConnectionError, HTTPError, ReadTimeout
+
+from fundus.scraping.html import HTML, SourceInfo
+from fundus.scraping.pipeline.source.web import WebSource, WebSourceInfo, _Pacer
+from fundus.scraping.session import session_handler
+from fundus.scraping.url import RSSFeed
+from tests.fixtures.builders import make_http_error, mock_response, mock_robots, stub_publisher
+
+
+class _RecordingIterable:
+ """Iterable of URLs that counts how many times it is advanced (i.e. how many URLs were pulled)."""
+
+ def __init__(self, urls: List[str]) -> None:
+ self._it = iter(urls)
+ self.pulled = 0
+
+ def __iter__(self) -> "_RecordingIterable":
+ return self
+
+ def __next__(self) -> str:
+ self.pulled += 1
+ return next(self._it)
+
+
+@pytest.fixture
+def source(publisher, patched_web_session_handler):
+ """A WebSource with robots disabled and pacer stubbed to allow calls through immediately."""
+ s = WebSource(url_source=[], publisher=publisher, ignore_robots=True)
+ s.pacer = MagicMock(return_value=True)
+ return s
+
+
+# ---- _Pacer ----------------------------------------------------------------
+
+
+class TestPacer:
+ def test_no_delay_never_sleeps(self):
+ sleeps: List[float] = []
+ pacer = _Pacer(delay=None, sleep=sleeps.append)
+ pacer()
+ pacer()
+ assert sleeps == []
+
+ def test_warm_start_first_call_does_not_sleep(self):
+ sleeps: List[float] = []
+ pacer = _Pacer(delay=lambda: 5.0, sleep=sleeps.append, warm_start=True)
+ pacer()
+ assert sleeps == []
+
+ def test_without_warm_start_first_call_sleeps_for_delay(self):
+ sleeps: List[float] = []
+ pacer = _Pacer(delay=lambda: 0.5, sleep=sleeps.append, warm_start=False)
+ pacer()
+ assert len(sleeps) == 1
+ assert 0.4 <= sleeps[0] <= 0.5
+
+ def test_reset_makes_next_call_sleep(self):
+ sleeps: List[float] = []
+ pacer = _Pacer(delay=lambda: 1.0, sleep=sleeps.append, warm_start=True)
+ pacer()
+ pacer.reset()
+ pacer()
+ assert len(sleeps) == 1
+
+
+# ---- WebSource.__init__ ----------------------------------------------------
+
+
+class TestWebSourceConstruction:
+ def test_ignore_robots_leaves_robots_none(self, publisher):
+ source = WebSource(url_source=[], publisher=publisher, ignore_robots=True)
+ assert source.robots is None
+
+ def test_uses_publisher_robots_by_default(self):
+ robots = mock_robots(crawl_delay=None)
+ publisher = stub_publisher(robots=robots)
+ source = WebSource(url_source=[], publisher=publisher)
+ assert source.robots is robots
+
+ def test_crawl_delay_resolution_deferred_to_first_request(self, patched_web_session_handler):
+ # Resolving the crawl-delay may read robots.txt, so the pacer is deferred out of
+ # construction: robots is untouched at build time and consulted on the first request.
+ patched_web_session_handler.get_with_interrupt.return_value = mock_response()
+ robots = mock_robots(can_fetch=True, crawl_delay=None)
+ source = WebSource(url_source=[], publisher=stub_publisher(robots=robots))
+ robots.crawl_delay.assert_not_called()
+ source._fetch_one("https://example.com/article", lambda u: False)
+ robots.crawl_delay.assert_called_once()
+
+
+class TestApplyQueryParameters:
+ def test_no_params_returns_url_unchanged(self):
+ assert WebSource._apply_query_parameters("https://example.com/a", {}) == "https://example.com/a"
+
+ def test_appends_to_url_without_query(self):
+ assert (
+ WebSource._apply_query_parameters("https://example.com/a", {"foo": "bar"})
+ == "https://example.com/a?foo=bar"
+ )
+
+ def test_preserves_existing_query(self):
+ assert (
+ WebSource._apply_query_parameters("https://example.com/a?x=1", {"foo": "bar"})
+ == "https://example.com/a?x=1&foo=bar"
+ )
+
+ def test_url_encodes_special_characters(self):
+ result = WebSource._apply_query_parameters("https://example.com/a", {"q": "hello world&fish"})
+ assert result == "https://example.com/a?q=hello+world%26fish"
+
+ def test_url_encodes_unicode(self):
+ result = WebSource._apply_query_parameters("https://example.com/a", {"q": "café"})
+ assert result == "https://example.com/a?q=caf%C3%A9"
+
+
+def _supplied_delay() -> float:
+ return 1.0
+
+
+class TestResolveDelay:
+ def test_no_robots_returns_supplied(self):
+ assert WebSource._resolve_delay(None, "*", _supplied_delay, ignore_crawl_delay=False) is _supplied_delay
+
+ def test_ignore_crawl_delay_returns_supplied(self):
+ robots = mock_robots()
+ assert WebSource._resolve_delay(robots, "*", _supplied_delay, ignore_crawl_delay=True) is _supplied_delay
+
+ def test_no_robots_delay_returns_supplied(self):
+ robots = mock_robots(crawl_delay=None)
+ assert WebSource._resolve_delay(robots, "*", _supplied_delay, ignore_crawl_delay=False) is _supplied_delay
+
+ def test_robots_delay_overrides_supplied(self):
+ robots = mock_robots(crawl_delay=5.0)
+ resolved = WebSource._resolve_delay(robots, "*", lambda: 1.0, ignore_crawl_delay=False)
+ assert resolved is not None
+ assert resolved() == 5.0
+
+ def test_robots_delay_returned_when_no_supplied(self):
+ robots = mock_robots(crawl_delay=3.0)
+ resolved = WebSource._resolve_delay(robots, "*", None, ignore_crawl_delay=False)
+ assert resolved is not None
+ assert resolved() == 3.0
+
+
+# ---- WebSource._fetch_one -------------------------------------------------
+
+
+class TestFetchOne:
+ def test_returns_none_for_invalid_url(self, source):
+ assert source._fetch_one("not-a-url", lambda u: False) is None
+
+ def test_returns_none_when_url_filter_matches(self, source):
+ assert source._fetch_one("https://example.com/", lambda u: True) is None
+
+ def test_returns_none_when_robots_disallows(self, patched_web_session_handler):
+ robots = mock_robots(can_fetch=False, crawl_delay=None)
+ publisher = stub_publisher(robots=robots)
+ source = WebSource(url_source=[], publisher=publisher)
+ source.pacer = MagicMock(return_value=True)
+ assert source._fetch_one("https://example.com/", lambda u: False) is None
+
+ def test_returns_none_on_http_error(self, source, patched_web_session_handler):
+ patched_web_session_handler.get_with_interrupt.side_effect = HTTPError("boom")
+ assert source._fetch_one("https://example.com/", lambda u: False) is None
+
+ def test_returns_none_on_connection_error(self, source, patched_web_session_handler):
+ patched_web_session_handler.get_with_interrupt.side_effect = ConnectionError("boom")
+ assert source._fetch_one("https://example.com/", lambda u: False) is None
+
+ def test_returns_none_on_read_timeout(self, source, patched_web_session_handler):
+ patched_web_session_handler.get_with_interrupt.side_effect = ReadTimeout("boom")
+ assert source._fetch_one("https://example.com/", lambda u: False) is None
+
+ @pytest.mark.integration
+ @pytest.mark.xfail(
+ reason="_request catches ReadTimeout, not the base Timeout curl_cffi raises on a real timeout, "
+ "so the timeout propagates out of _fetch_one instead of being swallowed. Fixed by flairNLP/fundus#939.",
+ strict=True,
+ )
+ def test_returns_none_on_real_timeout(self, hanging_url):
+ source = WebSource(url_source=[], publisher=stub_publisher(), ignore_robots=True)
+ source.pacer = MagicMock(return_value=True)
+ with session_handler.context(timeout=0.3):
+ assert source._fetch_one(hanging_url, lambda u: False) is None
+
+ def test_returns_none_on_5xx_status_code(self, source, patched_web_session_handler):
+ patched_web_session_handler.get_with_interrupt.side_effect = make_http_error(status_code=503)
+ assert source._fetch_one("https://example.com/", lambda u: False) is None
+
+ def test_successful_fetch_returns_html(self, source, patched_web_session_handler):
+ patched_web_session_handler.get_with_interrupt.return_value = mock_response(
+ text="body", url="https://example.com/article"
+ )
+ result = source._fetch_one("https://example.com/article", lambda u: False)
+ assert isinstance(result, HTML)
+ assert result.requested_url == "https://example.com/article"
+ assert result.responded_url == "https://example.com/article"
+ assert result.content == "body"
+
+ def test_appends_query_parameters_without_existing_query(self, publisher, patched_web_session_handler):
+ source = WebSource(
+ url_source=[],
+ publisher=publisher,
+ ignore_robots=True,
+ query_parameters={"foo": "bar"},
+ )
+ source.pacer = MagicMock(return_value=True)
+ source._fetch_one("https://example.com/article", lambda u: False)
+ called_url = patched_web_session_handler.get_with_interrupt.call_args[0][0]
+ assert called_url == "https://example.com/article?foo=bar"
+
+ def test_appends_query_parameters_with_existing_query(self, publisher, patched_web_session_handler):
+ source = WebSource(
+ url_source=[],
+ publisher=publisher,
+ ignore_robots=True,
+ query_parameters={"foo": "bar"},
+ )
+ source.pacer = MagicMock(return_value=True)
+ source._fetch_one("https://example.com/article?x=1", lambda u: False)
+ called_url = patched_web_session_handler.get_with_interrupt.call_args[0][0]
+ assert called_url == "https://example.com/article?x=1&foo=bar"
+
+ def test_url_filter_applied_to_responded_url(self, source, patched_web_session_handler):
+ patched_web_session_handler.get_with_interrupt.return_value = mock_response(
+ url="https://redirected.example.com/"
+ )
+ assert source._fetch_one("https://example.com/article", lambda u: "redirected" in u) is None
+
+ def test_web_source_info_when_url_source_is_url_source(self, publisher, patched_web_session_handler):
+ feed = RSSFeed(url="https://example.com/feed.xml")
+ source = WebSource(url_source=feed, publisher=publisher, ignore_robots=True)
+ source.pacer = MagicMock(return_value=True)
+ result = source._fetch_one("https://example.com/article", lambda u: False)
+ assert isinstance(result, HTML)
+ assert isinstance(result.source_info, WebSourceInfo)
+ assert result.source_info.type == "RSSFeed"
+ assert result.source_info.url == "https://example.com/feed.xml"
+
+ def test_plain_source_info_when_url_source_is_iterable(self, source, patched_web_session_handler):
+ result = source._fetch_one("https://example.com/article", lambda u: False)
+ assert isinstance(result, HTML)
+ assert type(result.source_info) is SourceInfo # not the Web subclass
+
+
+# ---- WebSource._build_url_filter -------------------------------------------
+
+
+class TestBuildUrlFilter:
+ def test_no_filters_returns_pass_through(self, publisher):
+ source = WebSource(url_source=[], publisher=publisher, ignore_robots=True)
+ combined = source._build_url_filter(None)
+ assert combined("https://example.com/") is False
+
+ def test_instance_filter_only(self, publisher):
+ source = WebSource(
+ url_source=[],
+ publisher=publisher,
+ url_filter=lambda u: "blocked" in u,
+ ignore_robots=True,
+ )
+ combined = source._build_url_filter(None)
+ assert combined("https://example.com/blocked") is True
+ assert combined("https://example.com/ok") is False
+
+ def test_per_call_filter_only(self, publisher):
+ source = WebSource(url_source=[], publisher=publisher, ignore_robots=True)
+ combined = source._build_url_filter(lambda u: "blocked" in u)
+ assert combined("https://example.com/blocked") is True
+ assert combined("https://example.com/ok") is False
+
+ def test_combined_filters_via_any(self, publisher):
+ source = WebSource(
+ url_source=[],
+ publisher=publisher,
+ url_filter=lambda u: "first" in u,
+ ignore_robots=True,
+ )
+ combined = source._build_url_filter(lambda u: "second" in u)
+ assert combined("https://host/first") is True
+ assert combined("https://host/second") is True
+ assert combined("https://host/third") is False
+
+
+# ---- WebSource.fetch -------------------------------------------------------
+
+
+class TestFetch:
+ def test_iterates_plain_iterable_and_yields_html(self, publisher, patched_web_session_handler):
+ source = WebSource(
+ url_source=["https://example.com/a", "https://example.com/b"],
+ publisher=publisher,
+ ignore_robots=True,
+ )
+ source.pacer = MagicMock(return_value=True)
+ results = list(source.fetch())
+ assert len(results) == 2
+ assert all(isinstance(r, HTML) for r in results)
+
+ def test_iterates_url_source_passing_session_and_headers(self, publisher, patched_web_session_handler):
+ url_source = MagicMock(spec=RSSFeed)
+ url_source.fetch.return_value = iter(["https://example.com/a"])
+ url_source.url = "https://example.com/feed.xml"
+ source = WebSource(url_source=url_source, publisher=publisher, ignore_robots=True)
+ source.pacer = MagicMock(return_value=True)
+ list(source.fetch())
+ url_source.fetch.assert_called_once()
+ assert url_source.fetch.call_args[0][1] == publisher.request_header
+
+ def test_stop_event_set_mid_stream_halts_before_pulling_next_url(self, publisher, patched_web_session_handler):
+ # Setting the stop event after the first article must halt fetch at the next loop check,
+ # before the second URL is pulled — so no further feed/page is requested. The recording
+ # iterable proves the next URL was never pulled (pulled stays 1).
+ stop_event = threading.Event()
+ urls = _RecordingIterable(["https://example.com/a", "https://example.com/b"])
+ source = WebSource(
+ url_source=urls,
+ publisher=publisher,
+ ignore_robots=True,
+ stop_event=stop_event,
+ )
+ source.pacer = MagicMock(return_value=True)
+
+ gen = source.fetch()
+ next(gen) # pull and fetch the first URL's article
+ stop_event.set() # stop arrives between articles
+ list(gen) # fully drive the generator after the stop
+
+ assert urls.pulled == 1
+
+ def test_set_stop_event_does_not_pull_url_iterator(self, publisher, patched_web_session_handler):
+ # A stop event set at a source boundary must short-circuit BEFORE the URL iterator is
+ # advanced, so the feed/sitemap is never downloaded. "No HTML yielded" is not enough to
+ # prove this (a plain list advances for free), so we assert the URLSource is never
+ # fetched. Regression: fetch pulled the first URL before checking the stop event.
+ stop_event = threading.Event()
+ stop_event.set()
+ url_source = MagicMock(spec=RSSFeed)
+ url_source.fetch.return_value = iter(["https://example.com/a"])
+ url_source.url = "https://example.com/feed.xml"
+ source = WebSource(
+ url_source=url_source,
+ publisher=publisher,
+ ignore_robots=True,
+ stop_event=stop_event,
+ )
+ source.pacer = MagicMock(return_value=True)
+ assert list(source.fetch()) == []
+ url_source.fetch.assert_not_called()
+
+ def test_url_iterator_crash_terminates_fetch(self, publisher, patched_web_session_handler, caplog):
+ def crashing_iter():
+ yield "https://example.com/a"
+ raise RuntimeError("boom")
+
+ source = WebSource(url_source=crashing_iter(), publisher=publisher, ignore_robots=True)
+ source.pacer = MagicMock(return_value=True)
+ with caplog.at_level("ERROR"):
+ results = list(source.fetch())
+ assert len(results) == 1
+ assert any("crashed" in record.message for record in caplog.records)
+
+ def test_fetch_one_crash_continues_to_next_url(self, publisher, patched_web_session_handler, caplog):
+ source = WebSource(
+ url_source=["https://example.com/a", "https://example.com/b"],
+ publisher=publisher,
+ ignore_robots=True,
+ )
+ source.pacer = MagicMock(return_value=True)
+ sentinel_html = HTML(
+ requested_url="https://example.com/b",
+ responded_url="https://example.com/b",
+ content="",
+ crawl_date=__import__("datetime").datetime(2024, 1, 1),
+ source_info=SourceInfo(publisher="x"),
+ )
+ with patch.object(source, "_fetch_one", side_effect=[RuntimeError("boom"), sentinel_html]):
+ with caplog.at_level("ERROR"):
+ results = list(source.fetch())
+ assert results == [sentinel_html]
+ assert any("unexpected error" in record.message for record in caplog.records)
diff --git a/tests/scraping/test_article.py b/tests/scraping/test_article.py
new file mode 100644
index 000000000..08fb24cf1
--- /dev/null
+++ b/tests/scraping/test_article.py
@@ -0,0 +1,206 @@
+import pickle
+from unittest.mock import patch
+
+import langdetect
+import pytest
+
+from fundus import Article
+from tests.fixtures.builders import make_html
+
+html = make_html()
+
+
+class _StubBody:
+ """Minimal stand-in for ArticleBody so we can control str(body) without importing the parser."""
+
+ def __init__(self, text: str) -> None:
+ self._text = text
+
+ def __str__(self) -> str:
+ return self._text
+
+
+class TestConstructor:
+ def test_rejects_positional_extraction(self):
+ with pytest.raises(TypeError):
+ Article({"title": "t"}, html=html) # type: ignore[arg-type, misc]
+
+ def test_requires_html_keyword(self):
+ with pytest.raises(TypeError):
+ Article(title="t") # type: ignore[call-arg]
+
+ def test_accepts_empty_extraction(self):
+ Article(html=html)
+
+ def test_accepts_extraction_kwargs(self):
+ Article(html=html, title="t", authors=["A"])
+
+
+class TestProperties:
+ def test_defaults_when_extraction_is_empty(self):
+ article = Article(html=html)
+ assert article.title is None
+ assert article.body is None
+ assert article.authors == []
+ assert article.publishing_date is None
+ assert article.topics == []
+ assert article.free_access is False
+ assert article.images == []
+
+ def test_returns_values_from_extraction(self):
+ article = Article(html=html, title="", authors=["A", "B", "C"])
+ assert article.title == ""
+ assert article.authors == ["A", "B", "C"]
+
+ def test_publisher_comes_from_html_source_info(self):
+ article = Article(html=make_html(publisher="example.com"))
+ assert article.publisher == "example.com"
+
+
+class TestExtractionView:
+ """Arbitrary extraction kwargs are exposed as read-only attributes via AttributeView."""
+
+ def test_read_returns_extraction_value(self):
+ article = Article(html=html, custom="value")
+ assert article.custom == "value"
+
+ def test_read_reflects_extraction_mutation(self):
+ article = Article(html=html, custom="value")
+ article.__extraction__["custom"] = "mutated" # type: ignore[index]
+ assert article.custom == "mutated"
+
+ def test_write_raises_attribute_error(self):
+ article = Article(html=html, custom="value")
+ with pytest.raises(AttributeError):
+ article.custom = "new"
+
+
+class TestPickleProtocol:
+ """Article must survive the pickle protocol used by multiprocessing.Queue.
+
+ Production CCNewsCrawler workers return articles to the main process through a
+ multiprocessing.Queue, which serializes payloads with pickle. Pickling probes
+ ``hasattr(obj, "__setstate__")`` during unpickling before ``__extraction__`` is
+ restored — ``Article.__getattr__`` must short-circuit on the missing dict instead of
+ recursing infinitely.
+ """
+
+ def test_getattr_does_not_recurse_during_unpickle(self):
+ article = Article(html=html, title="t", custom="value")
+ restored = pickle.loads(pickle.dumps(article))
+ assert restored.title == "t"
+ assert restored.custom == "value"
+
+ def test_article_survives_pickle(self):
+ """An Article must round-trip through pickle — it is what CCNews returns across the
+ multiprocessing.Queue. Guards SourceInfo carrying lightweight publisher identity (the name)
+ rather than a live Publisher, so the Article never drags unpicklable parser/filter state.
+ """
+ article = Article(html=make_html(publisher="DerStandard"), title="t")
+ restored = pickle.loads(pickle.dumps(article))
+ assert restored.publisher == "DerStandard"
+ assert restored.title == "t"
+
+
+class TestPlaintext:
+ def test_returns_str_of_body(self):
+ article = Article(html=html, body=_StubBody("Article text."))
+ assert article.plaintext == "Article text."
+
+ def test_returns_none_when_str_body_is_empty(self):
+ article = Article(html=html, body=_StubBody(""))
+ assert article.plaintext is None
+
+ def test_returns_none_when_body_is_exception(self):
+ article = Article(html=html, body=ValueError("parse failed"))
+ assert article.plaintext is None
+
+ def test_returns_none_when_body_is_none(self):
+ article = Article(html=html)
+ assert article.plaintext is None
+
+
+class TestLang:
+ def test_detects_language_from_plaintext(self):
+ article = Article(html=html, body=_StubBody("Some article text"))
+ with patch("fundus.scraping.article.langdetect.detect", return_value="en"):
+ assert article.lang == "en"
+
+ def test_falls_back_to_html_lang_on_detect_exception(self):
+ article = Article(
+ html=make_html(content=' x'),
+ body=_StubBody("text"),
+ )
+ with patch(
+ "fundus.scraping.article.langdetect.detect",
+ side_effect=langdetect.LangDetectException(0, "fail"),
+ ):
+ assert article.lang == "de"
+
+ def test_falls_back_to_html_lang_when_detector_returns_unknown(self):
+ unknown = langdetect.detector_factory.Detector.UNKNOWN_LANG
+ article = Article(
+ html=make_html(content='x'),
+ body=_StubBody("text"),
+ )
+ with patch("fundus.scraping.article.langdetect.detect", return_value=unknown):
+ assert article.lang == "fr"
+
+ def test_strips_region_suffix_from_html_lang(self):
+ # no body → plaintext is None → detection is skipped, falling to html lang
+ article = Article(html=make_html(content='x'))
+ assert article.lang == "en"
+
+ def test_returns_none_when_no_plaintext_and_no_html_lang(self):
+ article = Article(html=make_html(content="x"))
+ assert article.lang is None
+
+
+class TestToJson:
+ def test_default_uses_default_export_fields(self):
+ article = Article(html=html, title="t")
+ result = article.to_json()
+ assert set(result.keys()) == set(Article.DEFAULT_EXPORT_FIELDS)
+ assert result["title"] == "t"
+
+ def test_default_does_not_include_arbitrary_extras(self):
+ article = Article(html=html, title="t", meta={"k": "v"}, ld={"k": "v"})
+ result = article.to_json()
+ assert "meta" not in result
+ assert "ld" not in result
+
+ def test_explicit_fields_filter_output(self):
+ article = Article(html=html, title="t", topics=["a", "b"])
+ result = article.to_json("title")
+ assert result == {"title": "t"}
+
+ def test_preserves_field_order(self):
+ article = Article(html=html, title="t", topics=["a"])
+ result = article.to_json("topics", "title")
+ assert list(result.keys()) == ["topics", "title"]
+
+ def test_can_export_arbitrary_extraction_key_by_request(self):
+ article = Article(html=html, meta={"k": "v"})
+ result = article.to_json("meta")
+ assert result == {"meta": {"k": "v"}}
+
+ def test_raises_key_error_on_unknown_field(self):
+ article = Article(html=html, title="t")
+ with pytest.raises(KeyError):
+ article.to_json("title", "nonexistent")
+
+
+class TestStr:
+ def test_renders_title_and_plaintext(self):
+ article = Article(html=html, title="The Title", body=_StubBody("The text."))
+ rendered = str(article)
+ assert "The Title" in rendered
+ assert "The text." in rendered
+
+ def test_marks_missing_title(self):
+ article = Article(html=html, body=_StubBody("text"))
+ assert "--missing title--" in str(article)
+
+ def test_marks_missing_plaintext(self):
+ article = Article(html=html, title="t")
+ assert "--missing plaintext--" in str(article)
diff --git a/tests/scraping/test_filter.py b/tests/scraping/test_filter.py
new file mode 100644
index 000000000..94127b46e
--- /dev/null
+++ b/tests/scraping/test_filter.py
@@ -0,0 +1,153 @@
+from fundus import Requires
+from fundus.scraping.filter import (
+ FilterResultWithMissingAttributes,
+ RequiresAll,
+ inverse,
+ land,
+ lor,
+ regex_filter,
+)
+
+
+class TestFilterResultWithMissingAttributes:
+ def test_false_when_no_missing_attributes(self):
+ assert not FilterResultWithMissingAttributes()
+
+ def test_true_when_has_missing_attributes(self):
+ assert FilterResultWithMissingAttributes("title")
+
+ def test_stores_all_missing_attributes(self):
+ result = FilterResultWithMissingAttributes("title", "body")
+ assert sorted(result.missing_attributes) == ["body", "title"]
+
+
+class TestRequires:
+ def test_passes_when_attribute_is_truthy(self):
+ assert not Requires("a")({"a": "text"})
+
+ def test_filtered_when_attribute_is_falsy(self):
+ assert Requires("a")({"a": []})
+
+ def test_filtered_when_attribute_is_missing(self):
+ assert Requires("a")({"b": "text"})
+
+ def test_filtered_when_boolean_false(self):
+ assert Requires("a")({"a": False})
+
+ def test_passes_when_boolean_false_and_eval_disabled(self):
+ assert not Requires("a", eval_booleans=False)({"a": False})
+
+ def test_reports_all_failing_attributes(self):
+ result = Requires("a", "b")({"a": [], "b": []})
+ assert sorted(result.missing_attributes) == ["a", "b"]
+
+ def test_without_arguments_evaluates_all_keys(self):
+ result = Requires()({"a": "text", "b": []})
+ assert result.missing_attributes == ("b",)
+
+
+class TestRequiresAll:
+ def test_reports_all_falsy_attributes_across_all_keys(self):
+ result = RequiresAll()({"a": [], "b": []})
+ assert sorted(result.missing_attributes) == ["a", "b"]
+
+ def test_skips_boolean_attributes_by_default(self):
+ assert not RequiresAll()({"a": "text", "b": False})
+
+ def test_evaluates_boolean_attributes_when_enabled(self):
+ assert RequiresAll(eval_booleans=True)({"a": "text", "b": False})
+
+
+class TestInverse:
+ def test_true_becomes_false(self):
+ assert not inverse(lambda url: True)("https://example.com")
+
+ def test_false_becomes_true(self):
+ assert inverse(lambda url: False)("https://example.com")
+
+ def test_double_inverse_preserves_result(self):
+ def starts_with_https(url: str) -> bool:
+ return url.startswith("https")
+
+ assert inverse(inverse(starts_with_https))("https://example.com") is True
+ assert inverse(inverse(starts_with_https))("http://example.com") is False
+
+
+class TestLor:
+ def test_true_if_any_filter_matches(self):
+ assert lor(lambda url: False, lambda url: True)("https://example.com")
+
+ def test_false_if_no_filter_matches(self):
+ assert not lor(lambda url: False, lambda url: False)("https://example.com")
+
+ def test_true_if_all_filters_match(self):
+ assert lor(lambda url: True, lambda url: True)("https://example.com")
+
+
+class TestLand:
+ def test_true_if_all_filters_match(self):
+ assert land(lambda url: True, lambda url: True)("https://example.com")
+
+ def test_false_if_any_filter_misses(self):
+ assert not land(lambda url: True, lambda url: False)("https://example.com")
+
+ def test_false_if_no_filter_matches(self):
+ assert not land(lambda url: False, lambda url: False)("https://example.com")
+
+
+class TestCombinatorNesting:
+ def test_inverse_of_lor(self):
+ # NOT (A OR B) — true only when both false
+ f = inverse(lor(lambda url: False, lambda url: False))
+ assert f("https://example.com")
+ f = inverse(lor(lambda url: True, lambda url: False))
+ assert not f("https://example.com")
+
+ def test_inverse_of_land(self):
+ # NOT (A AND B) — true when at least one is false
+ f = inverse(land(lambda url: True, lambda url: False))
+ assert f("https://example.com")
+ f = inverse(land(lambda url: True, lambda url: True))
+ assert not f("https://example.com")
+
+ def test_land_of_lors(self):
+ # (A OR B) AND (C OR D)
+ a_or_b = lor(lambda url: True, lambda url: False)
+ c_or_d = lor(lambda url: False, lambda url: False)
+ assert not land(a_or_b, c_or_d)("https://example.com")
+
+ a_or_b = lor(lambda url: True, lambda url: False)
+ c_or_d = lor(lambda url: False, lambda url: True)
+ assert land(a_or_b, c_or_d)("https://example.com")
+
+ def test_lor_of_lands(self):
+ # (A AND B) OR (C AND D)
+ a_and_b = land(lambda url: True, lambda url: False)
+ c_and_d = land(lambda url: True, lambda url: True)
+ assert lor(a_and_b, c_and_d)("https://example.com")
+
+ a_and_b = land(lambda url: False, lambda url: False)
+ c_and_d = land(lambda url: True, lambda url: False)
+ assert not lor(a_and_b, c_and_d)("https://example.com")
+
+
+class TestRegexFilter:
+ def test_matches_pattern(self):
+ assert regex_filter(r"/article/\d+")("https://example.com/article/123")
+
+ def test_no_match_returns_false(self):
+ assert not regex_filter(r"/article/\d+")("https://example.com/news/latest")
+
+ def test_partial_match_is_sufficient(self):
+ assert regex_filter(r"example")("https://example.com/some/deep/path")
+
+ def test_anchored_pattern_matches(self):
+ assert regex_filter(r"^https://")("https://example.com")
+
+ def test_anchored_pattern_rejects(self):
+ assert not regex_filter(r"^https://")("http://example.com")
+
+ def test_composable_with_inverse(self):
+ not_article = inverse(regex_filter(r"/article/"))
+ assert not_article("https://example.com/news/1")
+ assert not not_article("https://example.com/article/1")
diff --git a/tests/scraping/test_html.py b/tests/scraping/test_html.py
new file mode 100644
index 000000000..d87bbcd2f
--- /dev/null
+++ b/tests/scraping/test_html.py
@@ -0,0 +1,65 @@
+import datetime
+
+from fundus.scraping.html import HTML, SourceInfo
+from fundus.scraping.pipeline.source.ccnews import WarcSourceInfo
+from fundus.scraping.pipeline.source.web import WebSourceInfo
+
+
+class TestSourceInfoSerialize:
+ def test_base_class_serializes_publisher(self):
+ info = SourceInfo(publisher="example.com")
+ assert info.serialize() == {"publisher": "example.com"}
+
+ def test_web_subclass_includes_inherited_and_own_fields(self):
+ info = WebSourceInfo(publisher="example.com", type="rss", url="https://example.com/feed.xml")
+ assert info.serialize() == {
+ "publisher": "example.com",
+ "type": "rss",
+ "url": "https://example.com/feed.xml",
+ }
+
+ def test_warc_subclass_includes_inherited_and_own_fields(self):
+ info = WarcSourceInfo(
+ publisher="example.com",
+ warc_path="cc-news/2024/path.warc.gz",
+ warc_headers={"WARC-Type": "response"},
+ http_headers={"Content-Type": "text/html"},
+ )
+ assert info.serialize() == {
+ "publisher": "example.com",
+ "warc_path": "cc-news/2024/path.warc.gz",
+ "warc_headers": {"WARC-Type": "response"},
+ "http_headers": {"Content-Type": "text/html"},
+ }
+
+
+class TestHTMLSerialize:
+ def test_serializes_all_fields_with_isoformat_and_nested_source_info(self):
+ html = HTML(
+ requested_url="https://example.com/article",
+ responded_url="https://example.com/article",
+ content=" ",
+ crawl_date=datetime.datetime(2024, 1, 2, 3, 4, 5),
+ source_info=SourceInfo(publisher="example.com"),
+ )
+ assert html.serialize() == {
+ "requested_url": "https://example.com/article",
+ "responded_url": "https://example.com/article",
+ "content": " ",
+ "crawl_date": "2024-01-02T03:04:05",
+ "source_info": {"publisher": "example.com"},
+ }
+
+ def test_uses_subclass_source_info_serialize(self):
+ html = HTML(
+ requested_url="https://example.com/article",
+ responded_url="https://example.com/article",
+ content=" ",
+ crawl_date=datetime.datetime(2024, 1, 2, 3, 4, 5),
+ source_info=WebSourceInfo(publisher="example.com", type="rss", url="https://example.com/feed.xml"),
+ )
+ assert html.serialize()["source_info"] == {
+ "publisher": "example.com",
+ "type": "rss",
+ "url": "https://example.com/feed.xml",
+ }
diff --git a/tests/test_session_handler.py b/tests/scraping/test_session.py
similarity index 75%
rename from tests/test_session_handler.py
rename to tests/scraping/test_session.py
index 2f2b3513c..8e8107358 100644
--- a/tests/test_session_handler.py
+++ b/tests/scraping/test_session.py
@@ -2,12 +2,12 @@
import time
from queue import Queue
from threading import Thread
-from typing import Union
+from typing import Dict, Optional, Union
from unittest.mock import MagicMock, patch
import curl_cffi
import pytest
-from curl_cffi.requests.exceptions import HTTPError, TooManyRedirects
+from curl_cffi.requests.exceptions import HTTPError, Timeout, TooManyRedirects
from fundus.scraping.session import (
CrashThread,
@@ -19,7 +19,7 @@
from tests.exceptions import Success
-def _mock_response(status_code: int = 200) -> MagicMock:
+def _mock_response(status_code: int = 200, headers: Optional[Dict[str, str]] = None) -> MagicMock:
"""Build a mock curl_cffi response that satisfies InterruptableSession._log_response."""
response = MagicMock()
response._history = [] # object.__getattribute__ reads directly from __dict__
@@ -27,6 +27,7 @@ def _mock_response(status_code: int = 200) -> MagicMock:
response.url = "https://example.com"
response.elapsed = 0.01
response.request = None
+ response.headers = headers if headers is not None else {}
response.raise_for_status = MagicMock()
return response
@@ -335,6 +336,148 @@ def test_raise_for_status_propagates(self):
session.close()
+ @pytest.mark.integration
+ def test_session_timeout_raises_timeout(self, hanging_url):
+ """get_with_interrupt raises curl_cffi's Timeout when a request times out."""
+ session = InterruptableSession(timeout=0.3, max_retries=0)
+ try:
+ with pytest.raises(Timeout):
+ session.get_with_interrupt(hanging_url)
+ finally:
+ session.close()
+
+
+class TestRetry:
+ def test_5xx_retried_then_succeeds(self):
+ session = InterruptableSession(max_retries=3)
+ responses = [_mock_response(status_code=503), _mock_response(status_code=200)]
+
+ with patch.object(session, "_sleep_with_interrupt") as sleep:
+ with patch.object(session, "_follow_redirects", side_effect=responses):
+ result = session.get_with_interrupt("http://example.com")
+
+ assert result is responses[1]
+ assert sleep.call_count == 1
+ session.close()
+
+ def test_5xx_exhausted_raises_http_error(self):
+ session = InterruptableSession(max_retries=2)
+ error = _mock_response(status_code=503)
+ error.raise_for_status.side_effect = HTTPError("503")
+ follow = MagicMock(return_value=error)
+
+ with patch.object(session, "_sleep_with_interrupt"):
+ with patch.object(session, "_follow_redirects", side_effect=follow):
+ with pytest.raises(HTTPError):
+ session.get_with_interrupt("http://example.com")
+
+ # initial attempt + max_retries
+ assert follow.call_count == 3
+ session.close()
+
+ def test_no_retry_when_disabled(self):
+ session = InterruptableSession(max_retries=0)
+ error = _mock_response(status_code=503)
+ error.raise_for_status.side_effect = HTTPError("503")
+ follow = MagicMock(return_value=error)
+
+ with patch.object(session, "_sleep_with_interrupt") as sleep:
+ with patch.object(session, "_follow_redirects", side_effect=follow):
+ with pytest.raises(HTTPError):
+ session.get_with_interrupt("http://example.com")
+
+ assert follow.call_count == 1
+ sleep.assert_not_called()
+ session.close()
+
+ def test_4xx_not_retried(self):
+ session = InterruptableSession(max_retries=3)
+ error = _mock_response(status_code=404)
+ error.raise_for_status.side_effect = HTTPError("404")
+ follow = MagicMock(return_value=error)
+
+ with patch.object(session, "_sleep_with_interrupt") as sleep:
+ with patch.object(session, "_follow_redirects", side_effect=follow):
+ with pytest.raises(HTTPError):
+ session.get_with_interrupt("http://example.com")
+
+ assert follow.call_count == 1
+ sleep.assert_not_called()
+ session.close()
+
+ def test_backoff_passed_to_sleep_between_retries(self):
+ session = InterruptableSession(max_retries=1)
+ responses = [_mock_response(status_code=500), _mock_response(status_code=200)]
+
+ with patch.object(session, "_retry_backoff", return_value=2.5) as backoff:
+ with patch.object(session, "_sleep_with_interrupt") as sleep:
+ with patch.object(session, "_follow_redirects", side_effect=responses):
+ session.get_with_interrupt("http://example.com")
+
+ backoff.assert_called_once()
+ sleep.assert_called_once()
+ assert sleep.call_args.args[0] == 2.5
+ session.close()
+
+
+class TestRetryAfter:
+ def test_parse_delta_seconds(self):
+ assert InterruptableSession._parse_retry_after("120") == 120.0
+
+ def test_parse_invalid_returns_none(self):
+ assert InterruptableSession._parse_retry_after("not-a-date") is None
+
+ def test_parse_http_date_future(self):
+ from email.utils import formatdate
+
+ seconds = InterruptableSession._parse_retry_after(formatdate(time.time() + 100, usegmt=True))
+ assert seconds is not None
+ assert 90 <= seconds <= 100
+
+ def test_parse_http_date_past_clamped_to_zero(self):
+ from email.utils import formatdate
+
+ assert InterruptableSession._parse_retry_after(formatdate(time.time() - 100, usegmt=True)) == 0.0
+
+ def test_backoff_honors_retry_after(self):
+ session = InterruptableSession()
+ response = _mock_response(status_code=503, headers={"retry-after": "5"})
+ assert session._retry_backoff(response, attempt=0) == 5.0
+ session.close()
+
+ def test_backoff_caps_retry_after(self):
+ session = InterruptableSession()
+ response = _mock_response(status_code=503, headers={"retry-after": "9999"})
+ assert session._retry_backoff(response, attempt=0) == session.retry_backoff_cap
+ session.close()
+
+ def test_backoff_exponential_jitter_within_window(self):
+ session = InterruptableSession()
+ response = _mock_response(status_code=503, headers={})
+ for attempt in range(6):
+ backoff = session._retry_backoff(response, attempt)
+ window = min(session.retry_backoff_cap, session.retry_backoff_base * 2**attempt)
+ assert 0.0 <= backoff <= window
+ session.close()
+
+ def test_backoff_params_configurable(self):
+ session = InterruptableSession(retry_backoff_base=0.5, retry_backoff_cap=4.0)
+ response = _mock_response(status_code=503, headers={"retry-after": "9999"})
+ assert session._retry_backoff(response, attempt=0) == 4.0
+ session.close()
+
+
+class TestSleepWithInterrupt:
+ def test_completes_when_not_interrupted(self):
+ # zero duration returns immediately without consulting the stop event
+ InterruptableSession._sleep_with_interrupt(0.0, "http://example.com")
+
+ def test_raises_crash_thread_on_stop_event(self):
+ with __EVENTS__.context("test-stop-event"):
+ __EVENTS__.set_event("stop", "test-stop-event")
+ with pytest.raises(CrashThread):
+ InterruptableSession._sleep_with_interrupt(5.0, "http://example.com")
+
def _redirect_response(status_code: int, location: str) -> MagicMock:
response = MagicMock()
diff --git a/tests/scraping/test_url.py b/tests/scraping/test_url.py
new file mode 100644
index 000000000..4ecda391b
--- /dev/null
+++ b/tests/scraping/test_url.py
@@ -0,0 +1,270 @@
+from __future__ import annotations
+
+import bz2
+import gzip
+import lzma
+from typing import Dict, List, Optional
+from unittest.mock import MagicMock, patch
+
+import pytest
+from curl_cffi.requests.exceptions import ConnectionError, HTTPError
+
+from fundus.scraping.url import (
+ RSSFeed,
+ Sitemap,
+ decompress,
+ is_valid_url,
+)
+
+# ---- helpers ----------------------------------------------------------------
+
+_RSS_FEED = """\
+
+
+
+ Test
+ https://example.com
+ A1 https://example.com/article/1
+ A2 https://example.com/article/2
+
+ """
+
+_SITEMAP_URLS = b"""\
+
+
+ https://example.com/article/1
+ https://example.com/article/2
+ """
+
+_SITEMAP_INDEX = b"""\
+
+
+ https://example.com/sitemap-a.xml
+ https://example.com/sitemap-b.xml
+ """
+
+_SITEMAP_SINGLE_URL = b"""\
+
+
+ https://example.com/article/sub
+ """
+
+
+def _make_response(text: str = "", content: bytes = b"", headers: Optional[Dict[str, str]] = None) -> MagicMock:
+ response = MagicMock()
+ response.text = text
+ response.content = content or text.encode()
+ response.headers = headers or {}
+ return response
+
+
+def _make_session(response=None, side_effect=None) -> MagicMock:
+ session = MagicMock()
+ if side_effect is not None:
+ session.get_with_interrupt.side_effect = side_effect
+ else:
+ session.get_with_interrupt.return_value = response
+ return session
+
+
+class TestIsValidUrl:
+ def test_valid_http_url(self):
+ assert is_valid_url("http://example.com")
+
+ def test_valid_https_url(self):
+ assert is_valid_url("https://example.com")
+
+ def test_valid_url_with_path(self):
+ assert is_valid_url("https://example.com/some/path")
+
+ def test_valid_url_with_query(self):
+ assert is_valid_url("https://example.com/page?id=1")
+
+ def test_rejects_unsupported_scheme(self):
+ assert not is_valid_url("ftp://example.com")
+
+ def test_rejects_missing_netloc(self):
+ assert not is_valid_url("https://")
+
+ def test_rejects_empty_string(self):
+ assert not is_valid_url("")
+
+
+class TestDecompress:
+ def test_decompresses_gzip_by_magic_bytes(self):
+ data = b"hello world"
+ assert decompress(gzip.compress(data)) == data
+
+ def test_decompresses_bzip2_by_magic_bytes(self):
+ data = b"hello world"
+ assert decompress(bz2.compress(data)) == data
+
+ def test_decompresses_xz_by_magic_bytes(self):
+ data = b"hello world"
+ assert decompress(lzma.compress(data)) == data
+
+ def test_returns_uncompressed_content_unchanged(self):
+ data = b" "
+ assert decompress(data) == data
+
+ def test_returns_empty_unchanged(self):
+ assert decompress(b"") == b""
+
+ def test_raises_on_corrupt_gzip(self):
+ # leading magic bytes match gzip but the rest is garbage
+ with pytest.raises(Exception):
+ decompress(b"\x1f\x8b" + b"\x00" * 10)
+
+
+class TestURLSourceGetUrls:
+ def _make_source(self, urls: List[str]) -> RSSFeed:
+ class FixedSource(RSSFeed):
+ def __iter__(self_inner):
+ return iter(urls)
+
+ return FixedSource(url="https://example.com")
+
+ def test_limits_output_to_max_urls(self):
+ source = self._make_source([f"https://example.com/{i}" for i in range(10)])
+ assert len(list(source.get_urls(max_urls=3))) == 3
+
+ def test_returns_all_when_max_urls_is_none(self):
+ source = self._make_source([f"https://example.com/{i}" for i in range(5)])
+ assert len(list(source.get_urls())) == 5
+
+ def test_max_urls_zero_returns_nothing(self):
+ source = self._make_source([f"https://example.com/{i}" for i in range(5)])
+ assert list(source.get_urls(max_urls=0)) == []
+
+
+class TestRSSFeedFetch:
+ def test_yields_urls_from_valid_feed(self):
+ session = _make_session(_make_response(text=_RSS_FEED))
+ result = list(RSSFeed(url="https://example.com/feed.xml").fetch(session, {}))
+ assert result == ["https://example.com/article/1", "https://example.com/article/2"]
+
+ def test_decodes_percent_encoded_urls(self):
+ rss = _RSS_FEED.replace("https://example.com/article/1", "https://example.com/caf%C3%A9")
+ session = _make_session(_make_response(text=rss))
+ result = list(RSSFeed(url="https://example.com/feed.xml").fetch(session, {}))
+ assert "https://example.com/café" in result
+
+ def test_yields_nothing_on_http_error(self):
+ session = _make_session(side_effect=HTTPError("error"))
+ result = list(RSSFeed(url="https://example.com/feed.xml").fetch(session, {}))
+ assert result == []
+
+ def test_yields_nothing_on_connection_error(self):
+ session = _make_session(side_effect=ConnectionError("error"))
+ result = list(RSSFeed(url="https://example.com/feed.xml").fetch(session, {}))
+ assert result == []
+
+ def test_yields_nothing_on_bozo_exception(self):
+ session = _make_session(_make_response(text=_RSS_FEED))
+ with patch("fundus.scraping.url.feedparser.parse", return_value={"bozo_exception": Exception("bad")}):
+ result = list(RSSFeed(url="https://example.com/feed.xml").fetch(session, {}))
+ assert result == []
+
+ def test_skips_entries_without_link(self):
+ rss = """\
+
+
+
+ No link
+ Has link https://example.com/article/1
+
+ """
+ session = _make_session(_make_response(text=rss))
+ result = list(RSSFeed(url="https://example.com/feed.xml").fetch(session, {}))
+ assert result == ["https://example.com/article/1"]
+
+
+class TestSitemapFetch:
+ def test_yields_urls_from_urlset(self):
+ session = _make_session(_make_response(content=_SITEMAP_URLS))
+ result = list(Sitemap(url="https://example.com/sitemap.xml").fetch(session, {}))
+ assert result == ["https://example.com/article/1", "https://example.com/article/2"]
+
+ def test_follows_sub_sitemaps_recursively(self):
+ def side_effect(*args, **kwargs):
+ url = args[0] if args else kwargs["url"]
+ return _make_response(content=_SITEMAP_INDEX if "sitemap-index" in url else _SITEMAP_SINGLE_URL)
+
+ session = _make_session(side_effect=side_effect)
+ result = list(Sitemap(url="https://example.com/sitemap-index.xml").fetch(session, {}))
+ # _SITEMAP_INDEX has two sub-sitemaps, each yielding one URL
+ assert result == ["https://example.com/article/sub", "https://example.com/article/sub"]
+
+ def test_reverse_reverses_url_order(self):
+ session = _make_session(_make_response(content=_SITEMAP_URLS))
+ result = list(Sitemap(url="https://example.com/sitemap.xml", reverse=True).fetch(session, {}))
+ assert result == ["https://example.com/article/2", "https://example.com/article/1"]
+
+ def test_sitemap_filter_excludes_matching_sub_sitemaps(self):
+ def side_effect(*args, **kwargs):
+ url = args[0] if args else kwargs["url"]
+ return _make_response(content=_SITEMAP_SINGLE_URL if "sitemap-" in url else _SITEMAP_INDEX)
+
+ session = _make_session(side_effect=side_effect)
+ # filter out sitemap-b, keep sitemap-a
+ sitemap = Sitemap(
+ url="https://example.com/sitemap-index.xml",
+ sitemap_filter=lambda url: "sitemap-b" in url,
+ )
+ result = list(sitemap.fetch(session, {}))
+ assert result == ["https://example.com/article/sub"]
+
+ def test_yields_nothing_on_http_error(self):
+ session = _make_session(side_effect=HTTPError("error"))
+ result = list(Sitemap(url="https://example.com/sitemap.xml").fetch(session, {}))
+ assert result == []
+
+ def test_decompresses_gzip_content(self):
+ session = _make_session(
+ _make_response(content=gzip.compress(_SITEMAP_URLS), headers={"content-type": "application/x-gzip"})
+ )
+ result = list(Sitemap(url="https://example.com/sitemap.xml.gz").fetch(session, {}))
+ assert result == ["https://example.com/article/1", "https://example.com/article/2"]
+
+ def test_yields_nothing_on_empty_sitemap(self):
+ session = _make_session(_make_response(content=b""))
+ result = list(Sitemap(url="https://example.com/sitemap.xml").fetch(session, {}))
+ assert result == []
+
+ def test_sort_predicate_orders_sub_sitemaps_descending(self):
+ import re
+
+ sitemap_index = b"""\
+
+
+ https://example.com/sitemap-2019.xml
+ https://example.com/sitemap-2021.xml
+ https://example.com/sitemap-2020.xml
+ """
+
+ def _sub_sitemap(year: str) -> bytes:
+ return f"""\
+
+
+ https://example.com/{year}/article
+ """.encode()
+
+ def side_effect(*args, **kwargs):
+ url = args[0] if args else kwargs["url"]
+ if "sitemap-index" in url:
+ return _make_response(content=sitemap_index)
+ match = re.search(r"\d{4}", url)
+ assert match is not None
+ return _make_response(content=_sub_sitemap(match.group()))
+
+ session = _make_session(side_effect=side_effect)
+ sitemap = Sitemap(
+ url="https://example.com/sitemap-index.xml",
+ sort_predicate=re.compile(r"\d{4}"),
+ )
+ result = list(sitemap.fetch(session, {}))
+ assert result == [
+ "https://example.com/2021/article",
+ "https://example.com/2020/article",
+ "https://example.com/2019/article",
+ ]
diff --git a/tests/test_article.py b/tests/test_article.py
deleted file mode 100644
index 49d7d1f33..000000000
--- a/tests/test_article.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import datetime
-from typing import Any, Dict
-
-import pytest
-
-from fundus import Article
-from fundus.scraping.html import HTML, SourceInfo
-
-info = SourceInfo(publisher="")
-html = HTML(content="", responded_url="", requested_url="", crawl_date=datetime.datetime.now(), source_info=info)
-
-
-class TestArticle:
- def test_constructor(self):
- extraction = {"authors": ["Author"], "title": "title"}
-
- with pytest.raises(TypeError):
- Article(extraction, html=html) # type: ignore[arg-type, misc]
-
- with pytest.raises(TypeError):
- Article(**extraction) # type: ignore[arg-type]
-
- Article(**{}, html=html)
- Article(**extraction, html=html, exception=None)
- Article(html=html, **extraction, exception=None)
- Article(**extraction, html=html, exception=TypeError())
-
- def test_default_values(self):
- extraction: Dict[str, Any] = {}
-
- article = Article(**extraction, html=html)
-
- assert article.title is None
- assert article.body is None
- assert article.authors == []
- assert article.publishing_date is None
- assert article.topics == []
- assert article.free_access is False
-
- def test_view(self):
- extraction = {
- "authors": ["Author1", "Author2", "Author3"],
- "title": "",
- }
-
- article = Article(**extraction, html=html, exception=None)
-
- assert article.title == ""
- assert article.authors == ["Author1", "Author2", "Author3"]
-
- def test_extraction_view_getter(self):
- extraction = {"test_attribute": "test_value"}
-
- article = Article(**extraction, html=html, exception=None)
-
- assert article.test_attribute
- assert article.test_attribute == "test_value"
-
- article.__extraction__["test_attribute"] = "very_secret_stuff" # type: ignore[index]
-
- assert article.test_attribute == "very_secret_stuff"
-
- def test_extraction_view_setter(self):
- extraction = {"test_attribute": "test_value"}
-
- article = Article(**extraction, html=html, exception=None)
- with pytest.raises(AttributeError):
- article.test_attribute = "another_value"
diff --git a/tests/test_collection.py b/tests/test_collection.py
deleted file mode 100644
index 51840bd89..000000000
--- a/tests/test_collection.py
+++ /dev/null
@@ -1,114 +0,0 @@
-import pytest
-
-from fundus import NewsMap, RSSFeed, Sitemap
-
-
-class TestCollection:
- def test_len(
- self,
- empty_publisher_group,
- group_with_empty_publisher_subgroup,
- group_with_valid_publisher_subgroup,
- group_with_two_valid_publisher_subgroups,
- ):
- assert len(empty_publisher_group) == 0
- assert len(group_with_empty_publisher_subgroup) == 0
- assert len(group_with_valid_publisher_subgroup) == 1
- assert len(group_with_two_valid_publisher_subgroups) == 2
-
- def test_iter_empty_group(self, empty_publisher_group):
- assert list(empty_publisher_group) == []
-
- def test_iter_group_with_empty_publisher_subgroup(self, group_with_empty_publisher_subgroup):
- assert list(group_with_empty_publisher_subgroup) == []
-
- def test_iter_group_with_publisher_subgroup(self, group_with_valid_publisher_subgroup):
- assert list(group_with_valid_publisher_subgroup) == [group_with_valid_publisher_subgroup.pub.value]
-
- @pytest.mark.filterwarnings("ignore::UserWarning")
- def test_supports(self, publisher_group_with_news_map, publisher_group_with_languages):
- assert publisher_group_with_news_map.value.supports(source_types=[NewsMap])
- assert not publisher_group_with_news_map.value.supports(source_types=[Sitemap])
- assert not publisher_group_with_news_map.value.supports(source_types=[RSSFeed])
- assert publisher_group_with_languages.eng.supports(languages=["en"])
- assert not publisher_group_with_languages.eng.supports(languages=["es"])
-
- @pytest.mark.filterwarnings("ignore::UserWarning")
- def test_search(
- self, publisher_group_with_news_map, proxy_with_two_versions_and_different_attrs, publisher_group_with_languages
- ):
- parser_proxy = proxy_with_two_versions_and_different_attrs()
-
- # monkey pathing publisher enums parser
- publisher_group_with_news_map.value.parser = parser_proxy
-
- later, earlier = parser_proxy.attribute_mapping.values()
-
- assert len(publisher_group_with_news_map.search(later.names, source_types=[NewsMap])) == 1
- assert len(publisher_group_with_news_map.search(later.names, source_types=[RSSFeed, Sitemap])) == 0
- assert len(publisher_group_with_news_map.search(later.names, source_types=[NewsMap, RSSFeed])) == 1
-
- # check that only latest version is supported with search
- assert len(publisher_group_with_news_map.search(later.names)) == 1
- assert len(publisher_group_with_news_map.search(earlier.names)) == 0
-
- assert len(publisher_group_with_languages.search(languages=["en", "en"])) == 1
- assert len(publisher_group_with_languages.search(languages=["de"])) == 1
- assert len(publisher_group_with_languages.search(languages=["en", "de"])) == 2
-
- assert len(publisher_group_with_languages.search(languages=["en"], source_types=[NewsMap])) == 1
- assert len(publisher_group_with_languages.search(languages=["en"], source_types=[Sitemap])) == 0
-
- assert len(publishers := publisher_group_with_languages.search(languages=["es"])) == 1
- assert len(publishers[0].source_mapping) == 2
-
- assert len(publishers := publisher_group_with_languages.search(languages=["ind"])) == 1
- assert len(publishers[0].source_mapping) == 1
-
- assert len(publishers := publisher_group_with_languages.search(languages=["pl"])) == 1
- assert len(publishers[0].source_mapping) == 1
-
- assert len(publishers := publisher_group_with_languages.search(languages=["es", "ind"])) == 1
- assert len(publishers[0].source_mapping) == 3
-
- assert len(publishers := publisher_group_with_languages.search(languages=["pl", "ind"])) == 1
- assert len(publishers[0].source_mapping) == 2
-
- assert len(publishers := publisher_group_with_languages.search(languages=["es"], source_types=[RSSFeed])) == 1
- assert len(publishers[0].source_mapping) == 1
-
- assert len(publishers := publisher_group_with_languages.search(languages=["es"], source_types=[NewsMap])) == 1
- assert len(publishers[0].source_mapping) == 1
-
- assert len(publisher_group_with_languages.search(languages=["es"], source_types=[Sitemap])) == 0
-
- assert (
- len(publishers := publisher_group_with_languages.search(languages=["es", "ind"], source_types=[Sitemap]))
- == 1
- )
- assert len(publishers[0].source_mapping) == 1
-
- with pytest.raises(ValueError):
- publisher_group_with_news_map.search()
-
- with pytest.raises(ValueError):
- publisher_group_with_news_map.search([])
-
- with pytest.raises(ValueError):
- publisher_group_with_news_map.search([], [])
-
- def test_publisher_group_with_subgroup_string_representation(self, group_with_two_valid_publisher_subgroups):
- representation = str(group_with_two_valid_publisher_subgroups)
- assertion = (
- ""
- "\n\t"
- "\n\t\ttest_pub"
- "\n\t"
- "\n\t\ttest_pub"
- )
- assert representation == assertion
-
- def test_publisher_group_with_publisher_string_representation(self, publisher_group_with_news_map):
- representation = str(publisher_group_with_news_map)
- assertion = "\n\ttest_pub"
- assert representation == assertion
diff --git a/tests/test_filter.py b/tests/test_filter.py
deleted file mode 100644
index a9e727993..000000000
--- a/tests/test_filter.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from fundus import Requires
-from fundus.scraping.filter import RequiresAll
-
-
-class TestExtractionFilter:
- def test_requires(self):
- extraction = {"a": "Some Stuff", "b": [], "c": True}
-
- assert not Requires("a")(extraction)
-
- assert (result := Requires("a", "b")(extraction))
-
- assert result.missing_attributes == ("b",)
-
- assert not Requires("c")(extraction)
-
- extraction = {"a": "Some Stuff", "b": [], "c": False}
-
- assert (result := Requires("a", "b", "c")(extraction))
-
- assert sorted(result.missing_attributes) == sorted(("b", "c"))
-
- assert not Requires("c", eval_booleans=False)(extraction)
-
- def test_requires_all(self):
- extraction = {"a": "Some Stuff", "b": [], "c": False}
-
- assert (result := RequiresAll()(extraction))
- assert result.missing_attributes == ("b",)
-
- extraction = {"a": "Some Stuff", "c": False}
- assert not RequiresAll()(extraction)
-
- # test skip_boolean=False
- extraction = {"a": "Some Stuff", "b": [], "c": False}
-
- assert (result := RequiresAll(eval_booleans=True)(extraction))
- assert sorted(result.missing_attributes) == sorted(("b", "c"))
-
- extraction = {"a": "Some Stuff", "c": True}
- assert not RequiresAll(eval_booleans=True)(extraction)
diff --git a/tests/test_publisher_collection.py b/tests/test_publisher_collection.py
deleted file mode 100644
index fadcf4262..000000000
--- a/tests/test_publisher_collection.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from typing import List
-
-import lxml.html
-import more_itertools
-import pytest
-import requests
-from lxml.etree import XPath
-
-from fundus import PublisherCollection
-from fundus.publishers import Publisher, PublisherGroup
-from fundus.scraping.session import _default_header
-
-_language_code_selector = XPath("//table[contains(@class, 'wikitable') and @id='Table'] //td[@id] / @id")
-
-
-def get_two_letter_code() -> List[str]:
- wiki_page = requests.get("https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes", headers=_default_header)
- two_letter_codes: List[str] = _language_code_selector(lxml.html.document_fromstring(wiki_page.content))
- return two_letter_codes
-
-
-language_codes = get_two_letter_code()
-
-
-class TestPublisherCollection:
- @pytest.mark.parametrize(
- "region",
- [pytest.param(group, id=group.__name__) for group in PublisherCollection.get_subgroup_mapping().values()],
- )
- def test_default_language(self, region: PublisherGroup):
- assert hasattr(region, "default_language"), f"Region {region.__name__!r} has no default language set"
-
- default_language = getattr(region, "default_language")
-
- assert default_language in language_codes, (
- f"Default language {default_language!r} isn't a ISO 639 language code"
- )
-
- @pytest.mark.parametrize(
- "publisher", [pytest.param(publisher, id=publisher.__name__) for publisher in PublisherCollection]
- )
- def test_source_languages(self, publisher: Publisher):
- for source in more_itertools.flatten(publisher.source_mapping.values()):
- assert source.languages.issubset(language_codes)
diff --git a/tests/utility.py b/tests/utility.py
index d30128021..15e178bad 100644
--- a/tests/utility.py
+++ b/tests/utility.py
@@ -11,10 +11,11 @@
from fundus import PublisherCollection
from fundus.parser import ArticleBody, BaseParser
-from fundus.parser.data import Image, TextSequenceTree
+from fundus.parser.data import Image
from fundus.publishers.base_objects import Publisher, PublisherGroup
from fundus.scraping.article import Article
from fundus.scraping.html import HTML, SourceInfo
+from fundus.utils.serialization import Serializable
from scripts.generate_tables import supported_publishers_markdown_path
from tests.resources.parser.test_data import __module_path__ as test_resource_path
@@ -111,13 +112,10 @@ def write(self, content: _T, **kwargs) -> None:
class ExtractionEncoder(json.JSONEncoder):
def default(self, obj: object):
if isinstance(obj, datetime.datetime):
- return str(obj)
- elif isinstance(obj, TextSequenceTree):
+ return obj.isoformat()
+ if isinstance(obj, Serializable):
return obj.serialize()
- elif isinstance(obj, Image):
- return obj.serialize()
- else:
- return json.JSONEncoder.default(self, obj)
+ return json.JSONEncoder.default(self, obj)
class ExtractionDecoder(json.JSONDecoder):
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_events.py b/tests/utils/test_events.py
similarity index 100%
rename from tests/test_events.py
rename to tests/utils/test_events.py
diff --git a/tests/utils/test_serialization.py b/tests/utils/test_serialization.py
new file mode 100644
index 000000000..1652194c2
--- /dev/null
+++ b/tests/utils/test_serialization.py
@@ -0,0 +1,55 @@
+import datetime
+from typing import Any
+
+import pytest
+
+from fundus.utils.serialization import Serializable, serialize_value
+
+
+class _SerializableStub:
+ def __init__(self, value: Any) -> None:
+ self._value = value
+
+ def serialize(self) -> Any:
+ return self._value
+
+
+class TestSerializableProtocol:
+ def test_detects_objects_with_serialize_method(self):
+ assert isinstance(_SerializableStub("x"), Serializable)
+
+ def test_rejects_objects_without_serialize_method(self):
+ assert not isinstance(object(), Serializable)
+
+
+class TestSerializeValue:
+ def test_passes_primitives_through(self):
+ assert serialize_value("s") == "s"
+ assert serialize_value(42) == 42
+ assert serialize_value(3.14) == 3.14
+ assert serialize_value(True) is True
+ assert serialize_value(None) is None
+
+ def test_serializes_datetime_as_isoformat(self):
+ when = datetime.datetime(2024, 1, 2, 3, 4, 5)
+ assert serialize_value(when) == "2024-01-02T03:04:05"
+
+ def test_serializes_object_with_serialize_method(self):
+ assert serialize_value(_SerializableStub({"k": "v"})) == {"k": "v"}
+
+ def test_walks_lists_recursively(self):
+ assert serialize_value([_SerializableStub("a"), _SerializableStub("b")]) == ["a", "b"]
+
+ def test_walks_dicts_recursively(self):
+ assert serialize_value({"k": _SerializableStub("v")}) == {"k": "v"}
+
+ def test_normalizes_tuples_to_lists(self):
+ assert serialize_value((1, 2, 3)) == [1, 2, 3]
+
+ def test_raises_type_error_on_unserializable(self):
+ with pytest.raises(TypeError):
+ serialize_value(object())
+
+ def test_error_message_includes_field_name_when_given(self):
+ with pytest.raises(TypeError, match="field 'foo'"):
+ serialize_value(object(), field_name="foo")
diff --git a/tests/utils/test_timeout.py b/tests/utils/test_timeout.py
new file mode 100644
index 000000000..917cc6813
--- /dev/null
+++ b/tests/utils/test_timeout.py
@@ -0,0 +1,103 @@
+import threading
+import time
+
+import pytest
+
+from fundus.utils.timeout import ResettableTimer, Timeout
+
+
+class TestStopwatch:
+ def test_elapsed_starts_near_zero(self):
+ sw = ResettableTimer._Stopwatch()
+ assert sw.elapsed < 0.05
+
+ def test_elapsed_increases_over_time(self):
+ sw = ResettableTimer._Stopwatch()
+ time.sleep(0.1)
+ assert sw.elapsed >= 0.09
+
+ def test_elapsed_is_never_negative(self):
+ sw = ResettableTimer._Stopwatch()
+ assert sw.elapsed >= 0.0
+
+ def test_reset_restarts_elapsed(self):
+ sw = ResettableTimer._Stopwatch()
+ time.sleep(0.1)
+ sw.reset()
+ assert sw.elapsed < 0.05
+
+
+class TestResettableTimer:
+ def test_fires_callback_after_timeout(self):
+ fired = threading.Event()
+ timer = ResettableTimer(0.2, fired.set)
+ timer.start()
+ assert fired.wait(timeout=1.0)
+
+ def test_does_not_fire_before_timeout(self):
+ fired = threading.Event()
+ timer = ResettableTimer(0.5, fired.set)
+ timer.start()
+ assert not fired.wait(timeout=0.2)
+ timer.cancel()
+
+ def test_does_not_fire_when_canceled(self):
+ fired = threading.Event()
+ timer = ResettableTimer(0.3, fired.set)
+ timer.start()
+ timer.cancel()
+ assert not fired.wait(timeout=0.6)
+
+ def test_reset_postpones_firing(self):
+ fired = threading.Event()
+ timer = ResettableTimer(0.3, fired.set)
+ timer.start()
+ time.sleep(0.15)
+ timer.reset()
+ assert not fired.wait(timeout=0.15) # 0.15s since reset, not yet
+ assert fired.wait(timeout=0.5) # fires ~0.3s after reset
+ timer.cancel()
+
+ def test_thread_is_daemon(self):
+ timer = ResettableTimer(10, lambda: None)
+ assert timer._thread.daemon
+
+
+class TestTimeout:
+ def test_disabled_when_seconds_is_none(self):
+ fired = threading.Event()
+ with Timeout(seconds=None, callback=fired.set):
+ time.sleep(0.1)
+ assert not fired.is_set()
+
+ def test_fires_callback_on_timeout(self):
+ fired = threading.Event()
+ with Timeout(seconds=0.1, callback=fired.set):
+ fired.wait(timeout=1.0)
+ assert fired.is_set()
+
+ def test_silent_suppresses_timeout_error(self):
+ with Timeout(seconds=0.1, silent=True):
+ time.sleep(0.5)
+
+ def test_not_silent_raises_timeout_error(self):
+ with pytest.raises(TimeoutError):
+ with Timeout(seconds=0.1, silent=False):
+ time.sleep(0.5)
+
+ def test_yields_resettable_timer(self):
+ with Timeout(seconds=None) as timer:
+ assert isinstance(timer, ResettableTimer)
+
+ def test_reset_delays_timeout(self):
+ fired = threading.Event()
+ with Timeout(seconds=0.3, callback=fired.set) as timer:
+ time.sleep(0.15)
+ timer.reset()
+ assert not fired.wait(timeout=0.15) # 0.15s since reset, not yet
+
+ def test_timer_canceled_on_context_exit(self):
+ fired = threading.Event()
+ with Timeout(seconds=0.3, callback=fired.set):
+ pass
+ assert not fired.wait(timeout=0.5)