Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions src/lighteval/utils/cache_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,13 +414,21 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901
)
new_results = func(self, docs_not_cached, *args, **kwargs)

# Store new results in file cache
cache.cache_samples(
docs=docs_not_cached,
results=new_results,
task_ids=task_ids,
sampling_method=sampling_method,
)
# Store new results in file cache. Under a data-parallel launch (e.g. accelerate with
# several processes), every rank holds the full, gathered results, so only the main
# process writes the cache file. Letting every rank write the same parquet concurrently
# corrupts it and makes subsequent loads fail. Other ranks wait at the barrier below
# before reading. See https://github.com/huggingface/lighteval/issues/1102.
accelerator = getattr(self, "accelerator", None)
if accelerator is None or accelerator.is_main_process:
cache.cache_samples(
docs=docs_not_cached,
results=new_results,
task_ids=task_ids,
sampling_method=sampling_method,
)
if accelerator is not None:
accelerator.wait_for_everyone()

# 3) Create final results by pulling from newly saved file cache
final_cached_results = cache.get_samples_from_cache(docs, task_ids, sampling_method)
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/utils/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,44 @@ def test_cache_transformers(self, mock_create_model, mock_accelerator, mock_gree
],
)

@patch("lighteval.models.transformers.transformers_model.TransformersModel._padded_greedy_until")
@patch("lighteval.models.transformers.transformers_model.Accelerator")
@patch("lighteval.models.transformers.transformers_model.TransformersModel._create_auto_model")
def test_cache_only_main_process_writes(self, mock_create_model, mock_accelerator, mock_greedy_until):
"""Regression test for #1102. Under a data-parallel (accelerate) launch every rank holds the same
gathered results and would write the same parquet concurrently, corrupting the cache. Only the main
process must write; other ranks must wait at a barrier (so the file exists before they read it)."""
from lighteval.models.transformers.transformers_model import TransformersModel, TransformersModelConfig

mock_create_model = Mock() # noqa F841
mock_accelerator_instance = Mock()
mock_accelerator_instance.device = torch.device("cpu")
mock_accelerator.return_value = mock_accelerator_instance
mock_greedy_until.return_value = self.model_responses

with tempfile.TemporaryDirectory() as temp_dir:
config = TransformersModelConfig(model_name="Qwen/Qwen3-0.6B", cache_dir=temp_dir)
model = TransformersModel(config)
cache: SampleCache = model._cache
task_id = cache.get_task_id(self.task_name, SamplingMethod.GENERATIVE)
cache_file = cache.get_cache_path(task_id)

# Non-main process: must NOT write the cache file, but must hit the barrier and still return
# results (in a real run the main process has written the file by the time the barrier clears,
# which we emulate by patching the cache read).
mock_accelerator_instance.is_main_process = False
mock_accelerator_instance.wait_for_everyone.reset_mock()
with patch.object(cache, "get_samples_from_cache", return_value=self.model_responses):
results = model.greedy_until(self.docs)
self.assertFalse(cache_file.exists(), "Non-main process must not write the cache file (#1102)")
mock_accelerator_instance.wait_for_everyone.assert_called()
self.assertEqual(len(results), len(self.docs))

# Main process: must write the cache file.
mock_accelerator_instance.is_main_process = True
model.greedy_until(self.docs)
self.assertTrue(cache_file.exists(), "Main process must write the cache file")

@patch("lighteval.models.vllm.vllm_model.VLLMModel._loglikelihood_tokens")
@patch("lighteval.models.vllm.vllm_model.VLLMModel._greedy_until")
@patch("lighteval.models.vllm.vllm_model.VLLMModel._create_auto_model")
Expand Down
Loading