diff --git a/src/lighteval/utils/cache_management.py b/src/lighteval/utils/cache_management.py
index 354a72f52..178154613 100644
--- a/src/lighteval/utils/cache_management.py
+++ b/src/lighteval/utils/cache_management.py
@@ -414,13 +414,21 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs):  # noqa C901
                 )
                 new_results = func(self, docs_not_cached, *args, **kwargs)
 
-                # Store new results in file cache
-                cache.cache_samples(
-                    docs=docs_not_cached,
-                    results=new_results,
-                    task_ids=task_ids,
-                    sampling_method=sampling_method,
-                )
+                # Store new results in file cache. Under a data-parallel launch (e.g. accelerate with
+                # several processes), every rank holds the full, gathered results, so only the main
+                # process writes the cache file. Letting every rank write the same parquet concurrently
+                # corrupts it and makes subsequent loads fail. Other ranks wait at the barrier below
+                # before reading. See https://github.com/huggingface/lighteval/issues/1102.
+                accelerator = getattr(self, "accelerator", None)
+                if accelerator is None or accelerator.is_main_process:
+                    cache.cache_samples(
+                        docs=docs_not_cached,
+                        results=new_results,
+                        task_ids=task_ids,
+                        sampling_method=sampling_method,
+                    )
+                if accelerator is not None:
+                    accelerator.wait_for_everyone()
 
             # 3) Create final results by pulling from newly saved file cache
             final_cached_results = cache.get_samples_from_cache(docs, task_ids, sampling_method)
diff --git a/tests/unit/utils/test_caching.py b/tests/unit/utils/test_caching.py
index 794feb399..f5106c601 100644
--- a/tests/unit/utils/test_caching.py
+++ b/tests/unit/utils/test_caching.py
@@ -214,6 +214,47 @@ def test_cache_transformers(self, mock_create_model, mock_accelerator, mock_gree
                 ],
             )
 
+    @patch("lighteval.models.transformers.transformers_model.TransformersModel._padded_greedy_until")
+    @patch("lighteval.models.transformers.transformers_model.Accelerator")
+    @patch("lighteval.models.transformers.transformers_model.TransformersModel._create_auto_model")
+    def test_cache_only_main_process_writes(self, mock_create_model, mock_accelerator, mock_greedy_until):
+        """Regression test for #1102. Under a data-parallel (accelerate) launch every rank holds the same
+        gathered results and would write the same parquet concurrently, corrupting the cache. Only the main
+        process must write; every rank, the main one included, must then reach the barrier so the file exists
+        before the other ranks read it."""
+        from lighteval.models.transformers.transformers_model import TransformersModel, TransformersModelConfig
+
+        mock_accelerator_instance = Mock()
+        mock_accelerator_instance.device = torch.device("cpu")
+        mock_accelerator.return_value = mock_accelerator_instance
+        mock_greedy_until.return_value = self.model_responses
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            config = TransformersModelConfig(model_name="Qwen/Qwen3-0.6B", cache_dir=temp_dir)
+            model = TransformersModel(config)
+            cache: SampleCache = model._cache
+            task_id = cache.get_task_id(self.task_name, SamplingMethod.GENERATIVE)
+            cache_file = cache.get_cache_path(task_id)
+
+            # Non-main process: must NOT write the cache file, but must hit the barrier and still return
+            # results (in a real run the main process has written the file by the time the barrier clears,
+            # which we emulate by patching the cache read).
+            mock_accelerator_instance.is_main_process = False
+            mock_accelerator_instance.wait_for_everyone.reset_mock()
+            with patch.object(cache, "get_samples_from_cache", return_value=self.model_responses):
+                results = model.greedy_until(self.docs)
+            self.assertFalse(cache_file.exists(), "Non-main process must not write the cache file (#1102)")
+            mock_accelerator_instance.wait_for_everyone.assert_called()
+            self.assertEqual(len(results), len(self.docs))
+
+            # Main process: must write the cache file AND reach the same barrier. If it skipped the
+            # barrier, the waiting ranks would never be released in a real multi-process run.
+            mock_accelerator_instance.is_main_process = True
+            mock_accelerator_instance.wait_for_everyone.reset_mock()
+            model.greedy_until(self.docs)
+            self.assertTrue(cache_file.exists(), "Main process must write the cache file")
+            mock_accelerator_instance.wait_for_everyone.assert_called()
+
     @patch("lighteval.models.vllm.vllm_model.VLLMModel._loglikelihood_tokens")
     @patch("lighteval.models.vllm.vllm_model.VLLMModel._greedy_until")
     @patch("lighteval.models.vllm.vllm_model.VLLMModel._create_auto_model")