Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ classifiers = [
]
requires-python = '>=3.11'
dependencies = [
'easyscience @ git+https://github.com/easyscience/corelib.git@bayesian',
'easyscience @ git+https://github.com/easyscience/corelib.git@bayesian_mp',
# 'easyscience',
'scipp',
'refnx',
Expand Down
32 changes: 32 additions & 0 deletions src/easyreflectometry/calculators/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,38 @@
"""Init function."""
super().__init__(interface_list=CalculatorBase._calculators)

def __reduce__(self):

Check warning on line 17 in src/easyreflectometry/calculators/factory.py

View check run for this annotation

Codecov / codecov/patch

src/easyreflectometry/calculators/factory.py#L17

Added line #L17 was not covered by tests
"""Serialize the active calculator state for worker processes."""
wrapper = getattr(self(), '_wrapper', None)
wrapper_state = None
if wrapper is not None:
wrapper_state = {

Check warning on line 22 in src/easyreflectometry/calculators/factory.py

View check run for this annotation

Codecov / codecov/patch

src/easyreflectometry/calculators/factory.py#L19-L22

Added lines #L19 - L22 were not covered by tests
'storage': wrapper.storage,
'resolution_function': wrapper._resolution_function,
'magnetism': wrapper._magnetism,
}
return (

Check warning on line 27 in src/easyreflectometry/calculators/factory.py

View check run for this annotation

Codecov / codecov/patch

src/easyreflectometry/calculators/factory.py#L27

Added line #L27 was not covered by tests
self.__state_restore__,
(
self.__class__,
self.current_interface_name,
wrapper_state,
),
)

@staticmethod
def __state_restore__(cls, interface_str, wrapper_state):

Check warning on line 37 in src/easyreflectometry/calculators/factory.py

View check run for this annotation

Codecov / codecov/patch

src/easyreflectometry/calculators/factory.py#L36-L37

Added lines #L36 - L37 were not covered by tests
"""Restore a calculator factory with its active wrapper state."""
obj = cls()
if interface_str in obj.available_interfaces:
obj.switch(interface_str)
wrapper = getattr(obj(), '_wrapper', None)
if wrapper is not None and wrapper_state is not None:
wrapper.storage = wrapper_state['storage']
wrapper._resolution_function = wrapper_state['resolution_function']
wrapper._magnetism = wrapper_state['magnetism']
return obj

Check warning on line 47 in src/easyreflectometry/calculators/factory.py

View check run for this annotation

Codecov / codecov/patch

src/easyreflectometry/calculators/factory.py#L39-L47

Added lines #L39 - L47 were not covered by tests

def reset_storage(self) -> None:
"""Reset storage."""
return self().reset_storage()
Expand Down
9 changes: 9 additions & 0 deletions src/easyreflectometry/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ def sample(
seed: int | None = None,
objective: str | None = None,
initializer: str | None = None,
n_workers: int | None = None,
progress_callback=None,
abort_test=None,
) -> dict:
Expand All @@ -383,8 +384,15 @@ def sample(
:param initializer: DREAM population initializer. One of ``'eps'``,
``'cov'``, ``'lhs'``, or ``'random'``. By default, None (BUMPS
uses ``'eps'``).
:param n_workers: Number of worker processes for parallel DREAM
population evaluation. ``None`` (default) and ``1`` use
sequential evaluation. Values greater than ``1`` enable
multiprocessing; the effective pool size is capped at
``min(n_workers, population)``.
:param progress_callback: Optional callback for progress updates during
sampling. Forwarded to the core MultiFitter.
:param abort_test: Optional callback that returns ``True`` to signal
that sampling should be aborted.
:return: Dictionary with keys ``'draws'``, ``'param_names'``, ``'state'``,
and ``'logp'``.
:raises RuntimeError: If the current minimizer is not a BUMPS instance.
Expand Down Expand Up @@ -428,6 +436,7 @@ def sample(
population=population,
seed=seed,
sampler_kwargs=sampler_kwargs or None,
n_workers=n_workers,
progress_callback=progress_callback,
abort_test=abort_test,
)
Expand Down
48 changes: 48 additions & 0 deletions tests/calculators/test_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# SPDX-FileCopyrightText: 2026 EasyScience contributors <https://github.com/easyscience>
# SPDX-License-Identifier: BSD-3-Clause

"""Tests for CalculatorFactory serialization."""

import pickle # noqa: S403

import numpy as np
from numpy.testing import assert_allclose

from easyreflectometry.calculators import CalculatorFactory
from easyreflectometry.model import Model
from easyreflectometry.model import PercentageFwhm
from easyreflectometry.sample import Layer
from easyreflectometry.sample import Material
from easyreflectometry.sample import Multilayer
from easyreflectometry.sample import Sample


def test_calculator_factory_pickle_preserves_active_wrapper_storage():
"""Pickled calculator factories retain model storage for worker processes."""
si = Material(sld=2.07, isld=0.0, name='Si')
film = Material(sld=2.0, isld=0.0, name='Film')
d2o = Material(sld=6.36, isld=0.0, name='D2O')

sample = Sample(
Multilayer(Layer(material=si, thickness=0.0, roughness=3.0, name='Si')),
Multilayer(Layer(material=film, thickness=250.0, roughness=3.0, name='Film')),
Multilayer(Layer(material=d2o, thickness=0.0, roughness=3.0, name='D2O')),
)
model = Model(
sample=sample,
scale=1.0,
background=1e-6,
resolution_function=PercentageFwhm(0.02),
)
interface = CalculatorFactory()
interface.switch('refnx')
model.interface = interface

restored = pickle.loads(pickle.dumps(interface)) # noqa: S301

assert model.unique_name in restored()._wrapper.storage['model']
q = np.linspace(0.01, 0.3, 10)
assert_allclose(
restored.fit_func(q, model.unique_name),
interface.fit_func(q, model.unique_name),
)
141 changes: 141 additions & 0 deletions tests/test_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,3 +1034,144 @@ def _fake_sample(*, x, y, weights, **kwargs):
fitter.sample(data, samples=100, burn=20, thin=2, objective='hybrid')

assert len(captured['x'][0]) == 10 # all points kept (Mighell-substituted)


class TestSampleWorkers:
"""n_workers parameter forwarding in sample()."""

def test_default_is_none(self):
"""When n_workers is not passed, it defaults to None (sequential)."""
model = Model()
model.interface = CalculatorFactory()
fitter = MultiFitter(model)

captured = {}

def _fake_sample(*, n_workers, **kwargs):
captured['n_workers'] = n_workers
return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None}

fitter.easy_science_multi_fitter = MagicMock()
fitter.easy_science_multi_fitter.sample = MagicMock(side_effect=_fake_sample)

data = sc.DataGroup({
'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))},
'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)},
})

fitter.sample(data, samples=100, burn=20, thin=2)
assert captured['n_workers'] is None

def test_explicit_none(self):
"""Explicit n_workers=None is forwarded as None."""
model = Model()
model.interface = CalculatorFactory()
fitter = MultiFitter(model)

captured = {}

def _fake_sample(*, n_workers, **kwargs):
captured['n_workers'] = n_workers
return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None}

fitter.easy_science_multi_fitter = MagicMock()
fitter.easy_science_multi_fitter.sample = MagicMock(side_effect=_fake_sample)

data = sc.DataGroup({
'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))},
'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)},
})

fitter.sample(data, samples=100, burn=20, thin=2, n_workers=None)
assert captured['n_workers'] is None

def test_explicit_one(self):
"""n_workers=1 is forwarded (sequential, same as None)."""
model = Model()
model.interface = CalculatorFactory()
fitter = MultiFitter(model)

captured = {}

def _fake_sample(*, n_workers, **kwargs):
captured['n_workers'] = n_workers
return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None}

fitter.easy_science_multi_fitter = MagicMock()
fitter.easy_science_multi_fitter.sample = MagicMock(side_effect=_fake_sample)

data = sc.DataGroup({
'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))},
'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)},
})

fitter.sample(data, samples=100, burn=20, thin=2, n_workers=1)
assert captured['n_workers'] == 1

@pytest.mark.parametrize('workers', [2, 4, 8])
def test_multiple_workers_forwarded(self, workers):
"""n_workers values greater than 1 are forwarded to core."""
model = Model()
model.interface = CalculatorFactory()
fitter = MultiFitter(model)

captured = {}

def _fake_sample(*, n_workers, **kwargs):
captured['n_workers'] = n_workers
return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None}

fitter.easy_science_multi_fitter = MagicMock()
fitter.easy_science_multi_fitter.sample = MagicMock(side_effect=_fake_sample)

data = sc.DataGroup({
'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))},
'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)},
})

fitter.sample(data, samples=100, burn=20, thin=2, n_workers=workers)
assert captured['n_workers'] == workers

def test_with_other_params_combined(self):
"""n_workers can be combined with all other sample() parameters."""
model = Model()
model.interface = CalculatorFactory()
fitter = MultiFitter(model)

captured = {}

def _fake_sample(*, samples, burn, thin, population, seed, n_workers, sampler_kwargs, **kwargs):
captured['samples'] = samples
captured['burn'] = burn
captured['thin'] = thin
captured['population'] = population
captured['seed'] = seed
captured['n_workers'] = n_workers
captured['sampler_kwargs'] = sampler_kwargs
return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None}

fitter.easy_science_multi_fitter = MagicMock()
fitter.easy_science_multi_fitter.sample = MagicMock(side_effect=_fake_sample)

data = sc.DataGroup({
'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))},
'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)},
})

fitter.sample(
data,
samples=500,
burn=100,
thin=5,
population=8,
seed=42,
initializer='cov',
n_workers=4,
)
assert captured['samples'] == 500
assert captured['burn'] == 100
assert captured['thin'] == 5
assert captured['population'] == 8
assert captured['seed'] == 42
assert captured['n_workers'] == 4
assert captured['sampler_kwargs'] == {'init': 'cov'}
Loading