diff --git a/CHANGELOG.md b/CHANGELOG.md index 5bff0035244..82f10f5cd50 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added a way to remix datasets flexibly - Added `from_pretrained_transformer_and_instances` constructor to `Vocabulary` - `TransformerTextField` now supports `__len__`. +- Tango steps for keeping only the strings or text fields from a dataset ### Fixed diff --git a/allennlp/tango/__init__.py b/allennlp/tango/__init__.py index 46f795729c2..1932e7f7012 100644 --- a/allennlp/tango/__init__.py +++ b/allennlp/tango/__init__.py @@ -11,6 +11,8 @@ from allennlp.tango.training import TrainingStep from allennlp.tango.evaluation import EvaluationStep +from allennlp.tango.text_fields_only import TextFieldsOnlyDataset +from allennlp.tango.strings_only import StringsOnlyDataset import warnings diff --git a/allennlp/tango/text_only.py b/allennlp/tango/strings_only.py similarity index 64% rename from allennlp/tango/text_only.py rename to allennlp/tango/strings_only.py index 4bd3aed4806..53586578e38 100644 --- a/allennlp/tango/text_only.py +++ b/allennlp/tango/strings_only.py @@ -3,23 +3,27 @@ every time we release a new version.* """ -import dataclasses from typing import Set, Optional, Iterable, Any from allennlp.tango.dataset import DatasetDict from allennlp.tango.step import Step +from allennlp.common.sqlite_sparse_sequence import SqliteSparseSequence +from allennlp.tango.sqlite_format import SqliteDictFormat +from tqdm import tqdm -@Step.register("text_only") -class TextOnlyDataset(Step): +@Step.register("strings_only") +class StringsOnlyDataset(Step): """ - This step converts a dataset into another dataset that contains only the strings from the original dataset. + This step converts a dataset into another dataset that contains only strings from the original dataset. You can specify exactly which fields to keep from the original dataset (default is all of them). You can specify a minimum length of string to keep, to filter out strings that are too short. """ DETERMINISTIC = True + VERSION = "001" + FORMAT = SqliteDictFormat() def run( # type: ignore self, @@ -39,7 +43,11 @@ def run( # type: ignore """ def find_nested_strings(o: Any, prefix: str = "") -> Iterable[str]: - if isinstance(o, list) or isinstance(o, tuple): + if isinstance(o, str): + if fields_to_keep is None or prefix in fields_to_keep: + if min_length is None or len(o) >= min_length: + yield o + elif isinstance(o, list) or isinstance(o, tuple): for i, item in enumerate(o): new_prefix = f"{prefix}.{i}" yield from find_nested_strings(item, new_prefix) @@ -47,17 +55,17 @@ def find_nested_strings(o: Any, prefix: str = "") -> Iterable[str]: for name, item in o.items(): new_prefix = f"{prefix}.{name}" yield from find_nested_strings(item, new_prefix) - elif isinstance(o, str): - if fields_to_keep is None or prefix in fields_to_keep: - if min_length is None or len(o) >= min_length: - yield o - return dataclasses.replace( - input, - splits={ - split_name: [ - {"text": text} for instance in split for text in find_nested_strings(instance) - ] - for split_name, split in input.splits.items() - }, - ) + splits = {} + for split_name, split in input.splits.items(): + sequence_file = self.work_dir() / f"{split_name}.sqlite" + sequence_file.unlink(missing_ok=True) + sequence = SqliteSparseSequence(sequence_file) + sequence.extend( + {"text": string} + for instance in tqdm(split, desc=f"Processing split '{split_name}'") + for string in find_nested_strings(instance) + ) + splits[split_name] = sequence + + return DatasetDict(splits=splits, vocab=input.vocab, metadata=input.metadata) diff --git a/allennlp/tango/text_fields_only.py b/allennlp/tango/text_fields_only.py new file mode 100644 index 00000000000..e85a93debde --- /dev/null +++ b/allennlp/tango/text_fields_only.py @@ -0,0 +1,76 @@ +""" +*AllenNLP Tango is an experimental API and parts of it might change or disappear +every time we release a new version.* +""" + +from typing import Set, Optional, Iterable, Any + +from allennlp.tango.dataset import DatasetDict +from allennlp.tango.step import Step +from allennlp.common.sqlite_sparse_sequence import SqliteSparseSequence +from allennlp.data.fields import TextField, TransformerTextField +from allennlp.data.instance import Instance +from allennlp.tango.sqlite_format import SqliteDictFormat +from allennlp.data import Field +from tqdm import tqdm + + +@Step.register("text_fields_only") +class TextFieldsOnlyDataset(Step): + """ + This step converts a dataset into another dataset that contains only text fields from the original dataset. + + You can specify exactly which fields to keep from the original dataset (default is all of them). + You can specify a minimum length of string to keep, to filter out strings that are too short. + """ + + DETERMINISTIC = True + VERSION = "003" + FORMAT = SqliteDictFormat() + + def run( # type: ignore + self, + input: DatasetDict, + *, + fields_to_keep: Optional[Set[str]] = None, + min_length: Optional[int] = None, + ) -> DatasetDict: + """ + Turns the `input` dataset into another dataset that contains only the text fields from the + original dataset. + + * `fields_to_keep` is an optional list of field names that you want to keep in the result. + If this is `None`, all fields are kept. + * `min_length` specifies the minimum length that a string must have to be part of the + result. If this is `None`, all strings are considered. + """ + + def find_nested_fields(o: Any, prefix: str = "") -> Iterable[Field]: + if isinstance(o, list) or isinstance(o, tuple): + for i, item in enumerate(o): + new_prefix = f"{prefix}.{i}" + yield from find_nested_fields(item, new_prefix) + elif isinstance(o, dict): + for name, item in o.items(): + new_prefix = f"{prefix}.{name}" + yield from find_nested_fields(item, new_prefix) + elif isinstance(o, Instance): + yield from find_nested_fields(o.fields, prefix) + elif isinstance(o, TextField) or isinstance(o, TransformerTextField): + if fields_to_keep is None or prefix in fields_to_keep: + if min_length is None or len(o) >= min_length: + yield o + + splits = {} + for split_name, split in input.splits.items(): + sequence_file = self.work_dir() / f"{split_name}.sqlite" + sequence_file.unlink(missing_ok=True) + sequence = SqliteSparseSequence(sequence_file) + sequence.extend( + Instance({"text": text_field}) + for instance in tqdm(split, desc=f"Processing split '{split_name}'") + for text_field in find_nested_fields(instance) + ) + splits[split_name] = sequence + + return DatasetDict(splits=splits, vocab=input.vocab, metadata=input.metadata) diff --git a/tests/commands/tango_test.py b/tests/commands/tango_test.py index 8f34cd4efff..169a2397c8d 100644 --- a/tests/commands/tango_test.py +++ b/tests/commands/tango_test.py @@ -47,7 +47,7 @@ def test_dry_run(self): params_as_dict_because_mypy_is_lame = { "dataset": {"type": "hf_dataset", "dataset_name": "squad"}, "dataset_text_only": { - "type": "text_only", + "type": "text_fields_only", "input": "dataset", "fields_to_keep": ["context", "question"], }, diff --git a/tests/tango/steps_test.py b/tests/tango/steps_test.py index 3630c7b8a45..015ec3cd1ed 100644 --- a/tests/tango/steps_test.py +++ b/tests/tango/steps_test.py @@ -17,7 +17,7 @@ import logging -from allennlp.tango.text_only import TextOnlyDataset +from allennlp.tango.text_fields_only import TextFieldsOnlyDataset logging.basicConfig(level=logging.INFO) @@ -63,7 +63,7 @@ def test_make_step_graph(ordered_ascending: bool): params_as_dict_because_mypy_is_lame = { "dataset": {"type": "hf_dataset", "dataset_name": "squad"}, "dataset_text_only": { - "type": "text_only", + "type": "text_fields_only", "input": {"type": "ref", "ref": "dataset"}, "fields_to_keep": ["context", "question"], }, @@ -75,7 +75,7 @@ def test_make_step_graph(ordered_ascending: bool): step_graph = step_graph_from_params(params.pop("steps")) assert len(step_graph) == 2 assert isinstance(step_graph["dataset"], HuggingfaceDataset) - assert isinstance(step_graph["dataset_text_only"], TextOnlyDataset) + assert isinstance(step_graph["dataset_text_only"], TextFieldsOnlyDataset) assert step_graph["dataset_text_only"].kwargs["input"] == step_graph["dataset"]