From a6d964bc7a5edd423d98e81f87bb85b92aa4d955 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ic=C3=ADar=20Llor=C3=A9ns=20Jover?= Date: Thu, 4 Jun 2026 18:03:34 +0200 Subject: [PATCH 1/2] add nested runs functionality to MLflow logger --- obsweatherscale/logger.py | 115 ++++++++++++++++++++++++-- test/test_loggers.py | 167 +++++++++++++++++++++++++++++++------- 2 files changed, 243 insertions(+), 39 deletions(-) diff --git a/obsweatherscale/logger.py b/obsweatherscale/logger.py index e9b5aff..cf0aa18 100644 --- a/obsweatherscale/logger.py +++ b/obsweatherscale/logger.py @@ -139,37 +139,126 @@ class MLflowLogger(Logger): exists when the logger is constructed, a new run is started automatically and ended on :meth:`close`. + Modes + ----- + Standard mode + parent_run_name=None + + Behaves in a non-nested way: + - uses the active run if one exists + - otherwise creates a run named run_name + + Nested mode + parent_run_name= + + - creates/reuses a parent run named parent_run_name + - starts a nested child run named run_name + - logs everything to the child run + Parameters ---------- experiment_name : str, optional - MLflow experiment name. If provided, + MLflow experiment name. If provided, :func:`mlflow.set_experiment` is called. run_name : str, optional Name for the MLflow run (used only when a new run is started). + parent_run_name : str, optional + Name for the parent MLflow run (used only in nested mode). """ def __init__( self, experiment_name: str | None = None, run_name: str | None = None, + parent_run_name: str | None = None, + run_tags: dict[str, str] | None = None, + parent_tags: dict[str, str] | None = None, ) -> None: try: import mlflow # pylint: disable=import-outside-toplevel except ImportError as exc: raise ImportError( "mlflow is required for MLflowLogger. " - "Install it with: pip install mlflow" + "Install it with: pip install mlflow" ) from exc self._mlflow = mlflow - self._managed_run = False + self._managed_parent = False + self._managed_child = False + # Set and get active experiment if experiment_name is not None: self._mlflow.set_experiment(experiment_name) - if self._mlflow.active_run() is None: - self._mlflow.start_run(run_name=run_name) - self._managed_run = True + active_experiment = self._mlflow.get_experiment_by_name( + experiment_name or "Default" + ) + experiment_id = ( + active_experiment.experiment_id + if active_experiment is not None else None + ) + + # Set run kwargs + run_kwargs: dict[str, Any] = {"run_name": run_name} + if run_tags is not None: + run_kwargs["tags"] = run_tags + + parent_run_kwargs: dict[str, Any] = {"run_name": parent_run_name} + if parent_tags is not None: + parent_run_kwargs["tags"] = parent_tags + + # ---- Standard mode ---- + if parent_run_name is None: + if self._mlflow.active_run() is None: + self._mlflow.start_run(**run_kwargs) + self._managed_child = True + + # ---- Nested mode ---- + else: + active = self._mlflow.active_run() + + if active is not None: + # Validate that the active run is the expected parent + active_name = active.data.tags.get("mlflow.runName") + if active_name != parent_run_name: + raise RuntimeError( + f"Active MLflow run '{active_name}' does not match " + f"requested parent run '{parent_run_name}'." + ) + parent_run_id = active.info.run_id + + else: + # Search for an existing RUNNING parent run with this name + parent_run_id = self._find_run_by_name( + parent_run_name, experiment_id + ) + + if parent_run_id is not None: + # Re-activate the parent so the child can nest under it + self._mlflow.start_run(run_id=parent_run_id) + else: + # Create a fresh parent + self._mlflow.start_run(**parent_run_kwargs) + self._managed_parent = True + + self._mlflow.start_run(nested=True, **run_kwargs) + self._managed_child = True + + def _find_run_by_name( + self, + run_name: str, + experiment_id: str | None + ) -> str | None: + client = self._mlflow.MlflowClient() + search_kwargs: dict = { + "filter_string": f"attributes.run_name = '{run_name}'", + "max_results": 1, + } + if experiment_id is not None: + search_kwargs["experiment_ids"] = [experiment_id] + + results = client.search_runs(**search_kwargs) + return results[0].info.run_id if results else None def log_params(self, params: dict[str, Any]) -> None: """Log hyperparameters to the active MLflow run.""" @@ -180,6 +269,16 @@ def log_metrics(self, metrics: dict[str, float], step: int) -> None: self._mlflow.log_metrics(metrics, step=step) def close(self) -> None: - """End the MLflow run if it was started by this logger.""" - if self._managed_run: + """End MLflow run if it was started by this logger. + + If used in standard mode, this will end the run provided it was + started by this logger. + + If used in nested mode, this will end the child run that was + started, and the parent run if it was started by this logger. + """ + if self._managed_child: # end child run first + self._mlflow.end_run() + + if self._managed_parent: self._mlflow.end_run() diff --git a/test/test_loggers.py b/test/test_loggers.py index c0d7f3d..d4d4a93 100644 --- a/test/test_loggers.py +++ b/test/test_loggers.py @@ -218,10 +218,19 @@ def test_accepts_str_path(self, tmp_path: Path) -> None: class TestMLflowLogger: """Tests for MLflowLogger using a fully mocked mlflow module.""" - def _make_mlflow_mock(self, active_run: bool = False) -> MagicMock: + def _make_mlflow_mock( + self, + active_run: bool = False, + existing_runs: list | None = None, + ) -> MagicMock: """Return a mock mlflow module.""" mock = MagicMock() mock.active_run.return_value = MagicMock() if active_run else None + + if existing_runs is None: + existing_runs = [] + mock.MlflowClient.return_value.search_runs.return_value = existing_runs + return mock def _make_logger( @@ -229,63 +238,159 @@ def _make_logger( mock_mlflow: MagicMock, experiment_name: str | None = None, run_name: str | None = None, + parent_run_name: str | None = None, ) -> MLflowLogger: - """Instantiate MLflowLogger with mlflow patched.""" + """Instantiate MLflowNestedLogger with mlflow patched.""" with patch.dict("sys.modules", {"mlflow": mock_mlflow}): - # Re-import to pick up the patched module inside __init__ - logger = MLflowLogger.__new__(MLflowLogger) - logger._mlflow = mock_mlflow # type: ignore[assignment] # pyright: ignore[reportPrivateUsage] - logger._managed_run = False # pyright: ignore[reportPrivateUsage] - if experiment_name is not None: - mock_mlflow.set_experiment(experiment_name) - if mock_mlflow.active_run() is None: - mock_mlflow.start_run(run_name=run_name) - logger._managed_run = True # pyright: ignore[reportPrivateUsage] - return logger + return MLflowLogger( + experiment_name=experiment_name, + run_name=run_name, + parent_run_name=parent_run_name, + ) + + # mock test + def test_mlflow_is_mocked(self): + mock = MagicMock() + + with patch.dict("sys.modules", {"mlflow": mock}): + logger = MLflowLogger(experiment_name=None, run_name="x") + assert logger._mlflow is mock + assert logger._mlflow is mock # should still be the mock + + # mlflow import test def test_import_error_without_mlflow(self) -> None: with patch.dict("sys.modules", {"mlflow": None}): # type: ignore[dict-item] with pytest.raises(ImportError, match="mlflow is required"): MLflowLogger() - def test_starts_run_when_no_active_run(self) -> None: - mock = self._make_mlflow_mock(active_run=False) - self._make_logger(mock) - mock.start_run.assert_called() - - def test_no_new_run_when_active_run_exists(self) -> None: - mock = self._make_mlflow_mock(active_run=True) + # log_params() and log_metrics() tests + def test_standard_mode_log_params_delegates_to_mlflow(self) -> None: + mock = self._make_mlflow_mock() logger = self._make_logger(mock) - # close() must not call end_run if the run was started externally - logger.close() - mock.end_run.assert_not_called() + logger.log_params(SAMPLE_PARAMS) + mock.log_params.assert_called_once_with(SAMPLE_PARAMS) - def test_log_params_delegates_to_mlflow(self) -> None: + def test_nested_mode_log_params_delegates_to_mlflow(self) -> None: mock = self._make_mlflow_mock() - logger = self._make_logger(mock) + logger = self._make_logger(mock, run_name="child", parent_run_name="parent") logger.log_params(SAMPLE_PARAMS) mock.log_params.assert_called_once_with(SAMPLE_PARAMS) - def test_log_metrics_delegates_to_mlflow(self) -> None: + def test_standard_mode_log_metrics_delegates_to_mlflow(self) -> None: mock = self._make_mlflow_mock() logger = self._make_logger(mock) logger.log_metrics(SAMPLE_METRICS, step=5) mock.log_metrics.assert_called_once_with(SAMPLE_METRICS, step=5) - def test_close_ends_managed_run(self) -> None: + def test_nested_mode_log_metrics_delegates_to_mlflow(self) -> None: + mock = self._make_mlflow_mock() + logger = self._make_logger(mock, run_name="child", parent_run_name="parent") + logger.log_metrics(SAMPLE_METRICS, step=5) + mock.log_metrics.assert_called_once_with(SAMPLE_METRICS, step=5) + + # experiment creation tests + def test_standard_mode_set_experiment_called_when_provided(self) -> None: + mock = self._make_mlflow_mock() + self._make_logger(mock, experiment_name="my_experiment") + mock.set_experiment.assert_called_with("my_experiment") + + def test_nested_mode_set_experiment_called_when_provided(self) -> None: + mock = self._make_mlflow_mock() + self._make_logger( + mock, + experiment_name="my_experiment", + run_name="child", + parent_run_name="parent" + ) + mock.set_experiment.assert_called_with("my_experiment") + + # Run creation tests + def test_standard_mode_starts_run_when_no_active_run(self) -> None: + mock = self._make_mlflow_mock(active_run=False) + self._make_logger(mock) + mock.start_run.assert_called() + + def test_standard_mode_starts_named_run_when_no_active_run(self) -> None: + mock = self._make_mlflow_mock(active_run=False) + self._make_logger(mock, run_name="run", parent_run_name=None) + mock.start_run.assert_called_once_with(run_name="run") + + def test_standard_mode_no_new_run_when_active_run_exists(self) -> None: + mock = self._make_mlflow_mock(active_run=True) + self._make_logger(mock) + mock.start_run.assert_not_called() + + def test_nested_mode_starts_parent_and_child_runs(self) -> None: + mock = self._make_mlflow_mock(active_run=False) + + self._make_logger(mock, run_name="child", parent_run_name="parent") + + assert mock.start_run.call_count == 2 + mock.start_run.assert_any_call(run_name="parent") + mock.start_run.assert_any_call(run_name="child", nested=True) + + def test_nested_mode_reuses_matching_parent(self) -> None: + active_run = MagicMock() + active_run.data.tags.get.return_value = "parent" + + mock = MagicMock() + mock.active_run.return_value = active_run + + self._make_logger(mock, run_name="child", parent_run_name="parent") + + mock.start_run.assert_called_once_with(run_name="child", nested=True) + + def test_nested_mode_raises_for_wrong_parent(self) -> None: + mock = self._make_mlflow_mock(active_run=True) + + with pytest.raises(RuntimeError, match="does not match"): + self._make_logger( + mock, + run_name="child", + parent_run_name="expected_parent", + ) + + # close() tests + def test_standard_mode_close_ends_managed_run(self) -> None: mock = self._make_mlflow_mock(active_run=False) logger = self._make_logger(mock) logger.close() mock.end_run.assert_called_once() - def test_close_does_not_end_external_run(self) -> None: + def test_nested_mode_close_ends_child_and_parent_runs(self) -> None: + mock = self._make_mlflow_mock(active_run=False) + + logger = self._make_logger( + mock, + parent_run_name="parent", + run_name="child", + ) + logger.close() + + assert mock.end_run.call_count == 2 + + def test_standard_mode_close_does_not_end_external_run(self) -> None: + """ close() must not call end_run if the run was started externally""" mock = self._make_mlflow_mock(active_run=True) + logger = self._make_logger(mock) logger.close() mock.end_run.assert_not_called() - def test_set_experiment_called_when_provided(self) -> None: - mock = self._make_mlflow_mock() - self._make_logger(mock, experiment_name="my_experiment") - mock.set_experiment.assert_called_with("my_experiment") + def test_nested_mode_close_ends_only_child_if_external_parent(self) -> None: + active_run = MagicMock() + active_run.data.tags.get.return_value = "parent" + + mock = MagicMock() + mock.active_run.return_value = active_run + + logger = self._make_logger( + mock, + parent_run_name="parent", + run_name="child", + ) + logger.close() + + assert mock.end_run.call_count == 1 From b44c7981d325efa73979f5e79c4c1d6be6d5e95a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ic=C3=ADar=20Llor=C3=A9ns=20Jover?= Date: Thu, 4 Jun 2026 18:05:16 +0200 Subject: [PATCH 2/2] typing, linting and renaming --- obsweatherscale/logger.py | 9 ++++++--- obsweatherscale/training/trainer.py | 20 +++++++++++--------- test/test_loggers.py | 18 +++++++++--------- 3 files changed, 26 insertions(+), 21 deletions(-) diff --git a/obsweatherscale/logger.py b/obsweatherscale/logger.py index cf0aa18..4df9ff5 100644 --- a/obsweatherscale/logger.py +++ b/obsweatherscale/logger.py @@ -129,13 +129,16 @@ def log_metrics(self, metrics: dict[str, float], step: int) -> None: writer.writerow([step, *metrics.values()]) def close(self) -> None: - """No-op: the CSV file is opened and closed within each :meth:`log_metrics` call.""" + """No-op: the CSV file is opened and closed within each + :meth:`log_metrics` call. + """ class MLflowLogger(Logger): - """Logger that records parameters and metrics to MLflow. + """Logger that records parameters and metrics to MLflow in + optionally nested runs. - Requires the optional ``mlflow`` package. If no active MLflow run + Requires the optional ``mlflow`` package. If no active MLflow run exists when the logger is constructed, a new run is started automatically and ended on :meth:`close`. diff --git a/obsweatherscale/training/trainer.py b/obsweatherscale/training/trainer.py index 6bb3a3c..07d6123 100644 --- a/obsweatherscale/training/trainer.py +++ b/obsweatherscale/training/trainer.py @@ -75,8 +75,8 @@ def __init__( self, model: ExactGP, likelihood: _GaussianLikelihoodBase, - train_loss_fct: Callable, - val_loss_fct: Callable, + train_loss_fn: Callable, + val_loss_fn: Callable, device: torch.device, optimizer: Optimizer, ) -> None: @@ -88,9 +88,9 @@ def __init__( The Gaussian Process prior model. likelihood : _GaussianLikelihoodBase The likelihood function for the model. - train_loss_fct : Callable + train_loss_fn : Callable The loss function to use for training. - val_loss_fct : Callable + val_loss_fn : Callable The loss function to use for validation. device : torch.device The device to use for training (CPU or GPU). @@ -100,8 +100,8 @@ def __init__( self.model = model self.best_model = model self.likelihood = likelihood - self.train_loss_fct = train_loss_fct - self.val_loss_fct = val_loss_fct + self.train_loss_fn = train_loss_fn + self.val_loss_fn = val_loss_fn self.device = device self.optimizer = optimizer @@ -175,7 +175,9 @@ def fit( length = len(train) val_length = len(val_context) train_progression : dict[str, list] = { - "iter": [], "train loss": [], "val loss": [], "iter time": [], "train time": [] + "iter": [], + "train loss": [], "val loss": [], + "iter time": [], "train time": [], } torch.manual_seed(seed) @@ -302,7 +304,7 @@ def _train_step( inputs=batch_x, targets=batch_y, strict=False ) distribution = self.model(batch_x) - loss = self.train_loss_fct(distribution, batch_y) + loss = self.train_loss_fn(distribution, batch_y) loss.backward() @@ -344,7 +346,7 @@ def _val_step( batch_x_context, batch_y_context, strict=False ) distribution_val = self.model(batch_x_target) - loss = self.val_loss_fct(distribution_val, batch_y_target) + loss = self.val_loss_fn(distribution_val, batch_y_target) return loss.item() diff --git a/test/test_loggers.py b/test/test_loggers.py index d4d4a93..6d8e38d 100644 --- a/test/test_loggers.py +++ b/test/test_loggers.py @@ -16,9 +16,9 @@ MLflowLogger, ) -# --------------------------------------------------------------------------- +# ---------------------------------------------------------------------- # Fixtures -# --------------------------------------------------------------------------- +# ---------------------------------------------------------------------- SAMPLE_PARAMS: dict[str, Any] = { "batch_size": 32, @@ -37,9 +37,9 @@ } -# --------------------------------------------------------------------------- +# ---------------------------------------------------------------------- # TerminalLogger -# --------------------------------------------------------------------------- +# ---------------------------------------------------------------------- class TestTerminalLogger: @@ -104,9 +104,9 @@ def test_log_level_respected(self) -> None: assert logger._logger.level == logging.WARNING -# --------------------------------------------------------------------------- +# ---------------------------------------------------------------------- # CSVLogger -# --------------------------------------------------------------------------- +# ---------------------------------------------------------------------- class TestCSVLogger: @@ -210,9 +210,9 @@ def test_accepts_str_path(self, tmp_path: Path) -> None: assert (tmp_path / "log.csv").exists() -# --------------------------------------------------------------------------- +# ---------------------------------------------------------------------- # MLflowLogger -# --------------------------------------------------------------------------- +# ---------------------------------------------------------------------- class TestMLflowLogger: @@ -250,7 +250,7 @@ def _make_logger( ) # mock test - def test_mlflow_is_mocked(self): + def test_mlflow_is_mocked(self) -> None: mock = MagicMock() with patch.dict("sys.modules", {"mlflow": mock}):