From e7a73b9dd0dd370a0d1512d6886f6c4783d9ef07 Mon Sep 17 00:00:00 2001 From: appel_c Date: Mon, 22 Sep 2025 16:52:10 -0500 Subject: [PATCH 1/4] feat(status): add CompareStatus, TransitionStatus, AllAndStatus and OrAnyStatus --- ophyd/status.py | 288 +++++++++++++++++++++++++++++++++++++ ophyd/tests/test_status.py | 233 +++++++++++++++++++++++++++++- 2 files changed, 520 insertions(+), 1 deletion(-) diff --git a/ophyd/status.py b/ophyd/status.py index 2e7c65fa5..500a76ecf 100644 --- a/ophyd/status.py +++ b/ophyd/status.py @@ -1,14 +1,21 @@ +from __future__ import annotations + import json +import operator import threading import time from collections import deque from functools import partial from logging import LoggerAdapter +from typing import Literal from warnings import warn +from typing import TYPE_CHECKING + import numpy as np from opentelemetry import trace + from .log import logger from .utils import ( InvalidState, @@ -18,6 +25,11 @@ adapt_old_callback_signature, ) +if TYPE_CHECKING: + from ophyd.device import Device + from ophyd.signal import Signal + + tracer = trace.get_tracer(__name__) _TRACE_PREFIX = "Ophyd Status" @@ -1138,3 +1150,279 @@ def wait(status, timeout=None, *, poll_rate="DEPRECATED"): from ``WaitTimeoutError`` above. """ return status.wait(timeout) + + +class CompareStatus(SubscriptionStatus): + """ + Status class to compare a signal value against a given value. + The comparison is done using the specified operation, which can be one of + '==', '!=', '<', '<=', '>', '>='. If the value is a string, only '==' and '!=' are allowed. + One may also define a value or list of values that will result in an exception if encountered. + The status is finished when the comparison is either true or an exception is raised. + + Parameters + ---------- + signal: The device signal to compare. + value: The value to compare against. + raise_exc_value: A value or list of values that will raise an exception if encountered. Defaults to None. + operation: The operation to use for comparison. Defaults to '=='. + event_type: The type of event to trigger on comparison. Defaults to None (default sub). + timeout: The timeout for the status. Defaults to None (indefinite). + settle_time: The time to wait for the signal to settle before comparison. Defaults to 0. + run: Whether to run the status callback on creation or not. Defaults to True. + """ + + OP_MAP = { + "==": operator.eq, + "!=": operator.ne, + "<": operator.lt, + "<=": operator.le, + ">": operator.gt, + ">=": operator.ge, + } + + def __init__( + self, + signal: Signal, + value: float | int | str, + *, + operation: Literal["==", "!=", "<", "<=", ">", ">="] = "==", + raise_exc_value: float | int | str | list[float | int | str] | None = None, + event_type=None, + timeout: float = None, + settle_time: float = 0, + run: bool = True, + ): + if isinstance(value, str): + if operation not in ("==", "!="): + raise ValueError(f"Invalid operation: {operation} for string comparison. Must be '==' or '!='.") + if operation not in ("==", "!=", "<", "<=", ">", ">="): + raise ValueError(f"Invalid operation: {operation}. Must be one of '==', '!=', '<', '<=', '>', '>='.") + self._signal = signal + self._value = value + self._operation = operation + if raise_exc_value is None: + self._raise_exc_values = [] + elif isinstance(raise_exc_value, (float, int, str)): + self._raise_exc_values = [raise_exc_value] + elif isinstance(raise_exc_value, list): + self._raise_exc_values = raise_exc_value + else: + raise ValueError( + f"raise_exc_value must be a float, int, str, list or None. Received: {raise_exc_value}" + ) + super().__init__( + device=signal, + callback=self._compare_callback, + timeout=timeout, + settle_time=settle_time, + event_type=event_type, + run=run, + ) + + def _compare_callback(self, value, **kwargs) -> bool: + """Callback for subscription status""" + if value in self._raise_exc_values: + self.set_exception( + ValueError( + f"CompareStatus for signal {self._signal.name} " + f"did not reach the desired state {self._operation} {self._value}. " + f"But instead reached {value} in {self._raise_exc_values}, which is set to raise an exception." + ) + ) + return False + return self.OP_MAP[self._operation](value, self._value) + + +class TransitionStatus(SubscriptionStatus): + """ + Status class to monitor transitions of a signal value through a list of specified transitions. + The status is finished when all transitions have been observed in order. The keyword argument + `strict` determines whether the transitions must occur in strict order or not. + If `raise_states` is provided, the status will raise an exception if the signal value matches + any of the values in `raise_states`. + + Parameters + ---------- + signal: The device signal to monitor. + transitions: A list of values to transition through. + strict: Whether the transitions must occur in strict order. Defaults to True. + raise_states: A list of values that will raise an exception if encountered. Defaults to None. + run: Whether to run the status callback on creation or not. Defaults to True. + event_type: The type of event to trigger on transition. Defaults to None (default sub). + timeout: The timeout for the status. Defaults to None (indefinite). + settle_time: The time to wait for the signal to settle before comparison. Defaults to 0. + """ + + def __init__( + self, + signal: Signal, + transitions: list[float | int | str], + *, + strict: bool = True, + raise_states: list[float | int | str] | None = None, + run: bool = True, + event_type=None, + timeout: float = None, + settle_time: float = 0, + ): + self._signal = signal + if not isinstance(transitions, list): + raise ValueError(f"Transitions must be a list of values. Received: {transitions}") + self._transitions = transitions + self._index = 0 + self._strict = strict + self._raise_states = raise_states if raise_states else [] + super().__init__( + device=signal, + callback=self._compare_callback, + timeout=timeout, + settle_time=settle_time, + event_type=event_type, + run=run, + ) + + def _compare_callback(self, old_value, value, **kwargs) -> bool: + """Callback for subscription Status""" + if value in self._raise_states: + self.set_exception( + ValueError( + f"Transition Status for {self._signal.name} resulted in a value: {value}. " + f"marked to raise {self._raise_states}. Expected transitions: {self._transitions}." + ) + ) + return False + if self._index == 0: + if value == self._transitions[0]: + self._index += 1 + else: + if self._strict: + if old_value == self._transitions[self._index - 1] and value == self._transitions[self._index]: + self._index += 1 + else: + if value == self._transitions[self._index]: + self._index += 1 + return self._is_finished() + + def _is_finished(self) -> bool: + """Check if the status is finished""" + return self._index >= len(self._transitions) + + +class AndAllStatus(DeviceStatus): + """ + A status that combines mutiple status objects in a list using logical and. + The status is finished when all status objects in the list are finished. + If any status object fails, the combined status will also fail and + set the exception from the first failed status on all sub-statuses. + + Parameters + ---------- + device: Device + status_list: A list of StatusBase or DeviceStatus objects to combine. + """ + + def __init__(self, device: Device, status_list: list[StatusBase | DeviceStatus], **kwargs): + self.status_list = status_list + super().__init__(device=device, **kwargs) + self._trace_attributes["all"] = [st._trace_attributes for st in self.status_list] + + def inner(status): + with self._lock: + if self._externally_initiated_completion: + return + if self.done: # Return if status is already done.. It must be resolved already + return + + for st in self.status_list: + with st._lock: + if st.done and not st.success: + self.set_exception(st.exception()) # st._exception + return + + if all(st.done for st in self.status_list) and all(st.success for st in self.status_list): + self.set_finished() + + for st in self.status_list: + with st._lock: + st.add_callback(inner) + + def set_exception(self, exc): + with self._lock: + if self._externally_initiated_completion or self.done: + return + super().set_exception(exc) + + def __repr__(self): + status_reprs = ", ".join(repr(s) for s in self.status_list) + return f"{self.__class__.__name__}([{status_reprs}])" + + def __str__(self): + status_strs = ", ".join(str(s) for s in self.status_list) + return f"AndAllStatus with {len(self.status_list)} statuses: [{status_strs}]" + + def __contains__(self, item): + return item in self.status_list + + +class OrAnyStatus(DeviceStatus): + """ + A status that combines multiple status objects in a list using logical OR. + The status is finished when any status object in the list finishes successfully. + If all status objects finish and none succeed, the combined status will fail + with the exception of the first failure. + + Parameters + ---------- + device: Device + status_list: A list of StatusBase or DeviceStatus objects to combine. + """ + + def __init__(self, device: Device, status_list: list[StatusBase | DeviceStatus], **kwargs): + self.status_list = status_list + super().__init__(device=device, **kwargs) + self._trace_attributes["all"] = [st._trace_attributes for st in self.status_list] + + def inner(status): + with self._lock: + if self._externally_initiated_completion: + return + if self.done: + return + + if status.done and status.success: + self.set_finished() + return + + if all(st.done for st in self.status_list): + exceptions = [ + st.exception() + for st in self.status_list + if st.done and not st.success and st.exception() is not None + ] + combined_exceptions = RuntimeError( + "; ".join(f"{type(exc).__name__}: {exc}" for exc in exceptions) + ) + self.set_exception(combined_exceptions) + + for st in self.status_list: + with st._lock: + st.add_callback(inner) + + def set_exception(self, exc): + with self._lock: + if self._externally_initiated_completion or self.done: + return + super().set_exception(exc) + + def __repr__(self): + status_reprs = ", ".join(repr(s) for s in self.status_list) + return f"{self.__class__.__name__}([{status_reprs}])" + + def __str__(self): + status_strs = ", ".join(str(s) for s in self.status_list) + return f"OrAnyStatus with {len(self.status_list)} statuses: [{status_strs}]" + + def __contains__(self, item): + return item in self.status_list + diff --git a/ophyd/tests/test_status.py b/ophyd/tests/test_status.py index ec0ba54d3..c38fd836b 100644 --- a/ophyd/tests/test_status.py +++ b/ophyd/tests/test_status.py @@ -4,13 +4,16 @@ import pytest from ophyd import Device -from ophyd.signal import EpicsSignalRO +from ophyd.signal import EpicsSignalRO, Signal +from ophyd.sim import FakeEpicsSignalRO from ophyd.status import ( DeviceStatus, MoveStatus, + OrAnyStatus, StableSubscriptionStatus, StatusBase, SubscriptionStatus, + TransitionStatus, UseNewProperty, ) from ophyd.utils import ( @@ -607,3 +610,231 @@ def _handle_failure(self): st.wait(1) time.sleep(0.1) # Wait for callbacks to run. assert state + +def test_compare_status_number(): + """Test CompareStatus with different operations.""" + sig = Signal(name="test_signal", value=0) + status = CompareStatus(signal=sig, value=5, operation="==") + assert status.done is False + sig.put(1) + assert status.done is False + sig.put(5) + assert status.done is True + + sig.put(5) + # Test with different operations + status = CompareStatus(signal=sig, value=5, operation="!=") + assert status.done is False + sig.put(5) + assert status.done is False + sig.put(6) + assert status.done is True + assert status.success is True + assert status.exception() is None + + sig.put(0) + status = CompareStatus(signal=sig, value=5, operation=">") + assert status.done is False + sig.put(5) + assert status.done is False + sig.put(10) + assert status.done is True + assert status.success is True + assert status.exception() is None + + +def test_compare_status_string(): + """Test CompareStatus with string values""" + sig = Signal(name="test_signal", value="test") + status = CompareStatus(signal=sig, value="test", operation="==") + assert status.done is False + sig.put("test1") + assert status.done is False + sig.put("test") + assert status.done is True + + sig.put("test") + # Test with different operations + status = CompareStatus(signal=sig, value="test", operation="!=") + assert status.done is False + sig.put("test") + assert status.done is False + sig.put("test1") + assert status.done is True + assert status.success is True + assert status.exception() is None + + # Test with greater than operation + # Raises ValueError for strings + sig.put("a") + with pytest.raises(ValueError): + status = CompareStatus(signal=sig, value="b", operation=">") + + +def test_transition_status(): + """Test TransitionStatus""" + sig = Signal(name="test_signal", value=0) + + # Test strict=True, without intermediate transitions + sig.put(0) + status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=True) + + assert status.done is False + sig.put(1) + assert status.done is False + sig.put(2) + assert status.done is False + sig.put(3) + assert status.done is True + assert status.success is True + assert status.exception() is None + + # Test strict=True, ra + sig.put(1) + status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=True, raise_states=[4]) + assert status.done is False + sig.put(4) + with pytest.raises(ValueError): + status.wait() + + assert status.done is True + assert status.success is False + assert isinstance(status.exception(), ValueError) + + # Test strict=False, with intermediate transitions + sig.put(0) + status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=False) + + assert status.done is False + sig.put(1) # entering first transition + sig.put(3) + sig.put(2) # transision + assert status.done is False + sig.put(4) + sig.put(2) + sig.put(3) # last transition + assert status.done is True + assert status.success is True + assert status.exception() is None + + +def test_transition_status_strings(): + """Test TransitionStatus with string values""" + sig = Signal(name="test_signal", value="a") + + # Test strict=True, without intermediate transitions + sig.put("a") + status = TransitionStatus(signal=sig, transitions=["b", "c", "d"], strict=True) + + assert status.done is False + sig.put("b") + assert status.done is False + sig.put("c") + assert status.done is False + sig.put("d") + assert status.done is True + assert status.success is True + assert status.exception() is None + + # Test strict=True with additional intermediate transition + + sig.put("a") + status = TransitionStatus(signal=sig, transitions=["b", "c", "d"], strict=True) + + assert status.done is False + sig.put("b") # first transition + sig.put("e") + sig.put("b") + sig.put("c") # transision + assert status.done is False + sig.put("f") + sig.put("b") + sig.put("c") + sig.put("d") # transision + assert status.done is True + assert status.success is True + assert status.exception() is None + + # Test strict=False, with intermediate transitions + sig.put("a") + status = TransitionStatus(signal=sig, transitions=["b", "c", "d"], strict=False) + + assert status.done is False + sig.put("b") # entering first transition + sig.put("d") + sig.put("c") # transision + assert status.done is False + sig.put("e") + sig.put("c") + sig.put("d") # last transition + assert status.done is True + assert status.success is True + +def test_and_all_status(): + """ Test AndAllStatus """ + dev = Device("Tst:Prefix", name="test") + st1 = StatusBase() + st2 = StatusBase() + st3 = DeviceStatus(dev) + and_status = AndAllStatus(dev, [st1, st2, st3]) + + # Finish in success + assert and_status.done is False + st1.set_finished() + assert and_status.done is False + st2.set_finished() + assert and_status.done is False + st3.set_finished() + assert and_status.done is True + assert and_status.success is True + + # Failure + st1 = StatusBase() + st2 = StatusBase() + st3 = DeviceStatus(dev) + and_status = AndAllStatus(dev, [st1, st2, st3]) + + assert and_status.done is False + st1.set_finished() + assert and_status.done is False + st2.set_exception(Exception("Test exception")) + assert and_status.done is True + assert and_status.success is False + assert st2.success is False + assert st3.success is False + assert st3.done is False # Not resolved before failure + assert st1.success is True # Already resolved before failure + # Exception is propagated to all unresolved statuses + +def test_or_any_status(): + """ Test OrAnyStatus """ + dev = Device("Tst:Prefix", name="test") + st1 = StatusBase() + st2 = StatusBase() + st3 = DeviceStatus(dev) + or_status = OrAnyStatus(dev, [st1, st2, st3]) + + # Finish in success + assert or_status.done is False + st1.set_finished() + assert or_status.done is True + assert or_status.success is True + + st1 = StatusBase() + or_status = OrAnyStatus(dev, [st1, st2, st3]) + assert or_status.done is False + assert or_status.success is False + st1.set_exception(Exception("Test exception")) + assert or_status.done is False + assert or_status.success is False + st2.set_exception(RuntimeError("Test exception 2")) + assert or_status.done is False + assert or_status.success is False + st3.set_exception(ValueError("Test exception 3")) + assert or_status.done is True + assert or_status.success is False + assert isinstance(or_status.exception(), RuntimeError) + assert str(or_status.exception()) == "Exception: Test exception; RuntimeError: Test exception 2; ValueError: Test exception 3" + + + From 7b8f94e53dd563bc8ed08a9b4826c895d7277bf9 Mon Sep 17 00:00:00 2001 From: appel_c Date: Thu, 25 Sep 2025 09:52:19 -0500 Subject: [PATCH 2/4] refactor: cleanup and fix linter comments --- ophyd/status.py | 49 +++++++++++++++++++++++++------------- ophyd/tests/test_status.py | 31 ++++++++++++++---------- 2 files changed, 52 insertions(+), 28 deletions(-) diff --git a/ophyd/status.py b/ophyd/status.py index 500a76ecf..1e2d9100d 100644 --- a/ophyd/status.py +++ b/ophyd/status.py @@ -7,15 +7,12 @@ from collections import deque from functools import partial from logging import LoggerAdapter -from typing import Literal +from typing import TYPE_CHECKING, Literal from warnings import warn -from typing import TYPE_CHECKING - import numpy as np from opentelemetry import trace - from .log import logger from .utils import ( InvalidState, @@ -1195,9 +1192,13 @@ def __init__( ): if isinstance(value, str): if operation not in ("==", "!="): - raise ValueError(f"Invalid operation: {operation} for string comparison. Must be '==' or '!='.") + raise ValueError( + f"Invalid operation: {operation} for string comparison. Must be '==' or '!='." + ) if operation not in ("==", "!=", "<", "<=", ">", ">="): - raise ValueError(f"Invalid operation: {operation}. Must be one of '==', '!=', '<', '<=', '>', '>='.") + raise ValueError( + f"Invalid operation: {operation}. Must be one of '==', '!=', '<', '<=', '>', '>='." + ) self._signal = signal self._value = value self._operation = operation @@ -1268,7 +1269,9 @@ def __init__( ): self._signal = signal if not isinstance(transitions, list): - raise ValueError(f"Transitions must be a list of values. Received: {transitions}") + raise ValueError( + f"Transitions must be a list of values. Received: {transitions}" + ) self._transitions = transitions self._index = 0 self._strict = strict @@ -1297,7 +1300,10 @@ def _compare_callback(self, old_value, value, **kwargs) -> bool: self._index += 1 else: if self._strict: - if old_value == self._transitions[self._index - 1] and value == self._transitions[self._index]: + if ( + old_value == self._transitions[self._index - 1] + and value == self._transitions[self._index] + ): self._index += 1 else: if value == self._transitions[self._index]: @@ -1322,16 +1328,22 @@ class AndAllStatus(DeviceStatus): status_list: A list of StatusBase or DeviceStatus objects to combine. """ - def __init__(self, device: Device, status_list: list[StatusBase | DeviceStatus], **kwargs): + def __init__( + self, device: Device, status_list: list[StatusBase | DeviceStatus], **kwargs + ): self.status_list = status_list super().__init__(device=device, **kwargs) - self._trace_attributes["all"] = [st._trace_attributes for st in self.status_list] + self._trace_attributes["all"] = [ + st._trace_attributes for st in self.status_list + ] def inner(status): with self._lock: if self._externally_initiated_completion: return - if self.done: # Return if status is already done.. It must be resolved already + + # Return if status is already done.. + if self.done: return for st in self.status_list: @@ -1340,7 +1352,9 @@ def inner(status): self.set_exception(st.exception()) # st._exception return - if all(st.done for st in self.status_list) and all(st.success for st in self.status_list): + if all(st.done for st in self.status_list) and all( + st.success for st in self.status_list + ): self.set_finished() for st in self.status_list: @@ -1363,7 +1377,7 @@ def __str__(self): def __contains__(self, item): return item in self.status_list - + class OrAnyStatus(DeviceStatus): """ @@ -1378,10 +1392,14 @@ class OrAnyStatus(DeviceStatus): status_list: A list of StatusBase or DeviceStatus objects to combine. """ - def __init__(self, device: Device, status_list: list[StatusBase | DeviceStatus], **kwargs): + def __init__( + self, device: Device, status_list: list[StatusBase | DeviceStatus], **kwargs + ): self.status_list = status_list super().__init__(device=device, **kwargs) - self._trace_attributes["all"] = [st._trace_attributes for st in self.status_list] + self._trace_attributes["all"] = [ + st._trace_attributes for st in self.status_list + ] def inner(status): with self._lock: @@ -1425,4 +1443,3 @@ def __str__(self): def __contains__(self, item): return item in self.status_list - diff --git a/ophyd/tests/test_status.py b/ophyd/tests/test_status.py index c38fd836b..71a1eab9a 100644 --- a/ophyd/tests/test_status.py +++ b/ophyd/tests/test_status.py @@ -5,7 +5,6 @@ from ophyd import Device from ophyd.signal import EpicsSignalRO, Signal -from ophyd.sim import FakeEpicsSignalRO from ophyd.status import ( DeviceStatus, MoveStatus, @@ -611,6 +610,7 @@ def _handle_failure(self): time.sleep(0.1) # Wait for callbacks to run. assert state + def test_compare_status_number(): """Test CompareStatus with different operations.""" sig = Signal(name="test_signal", value=0) @@ -689,9 +689,11 @@ def test_transition_status(): assert status.success is True assert status.exception() is None - # Test strict=True, ra + # Test strict=True, raise_states sig.put(1) - status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=True, raise_states=[4]) + status = TransitionStatus( + signal=sig, transitions=[1, 2, 3], strict=True, raise_states=[4] + ) assert status.done is False sig.put(4) with pytest.raises(ValueError): @@ -770,8 +772,9 @@ def test_transition_status_strings(): assert status.done is True assert status.success is True + def test_and_all_status(): - """ Test AndAllStatus """ + """Test AndAllStatus""" dev = Device("Tst:Prefix", name="test") st1 = StatusBase() st2 = StatusBase() @@ -802,12 +805,16 @@ def test_and_all_status(): assert and_status.success is False assert st2.success is False assert st3.success is False - assert st3.done is False # Not resolved before failure - assert st1.success is True # Already resolved before failure - # Exception is propagated to all unresolved statuses + + # Not resolved before failure + assert st3.done is False + + # Already resolved before failure + assert st1.success is True + def test_or_any_status(): - """ Test OrAnyStatus """ + """Test OrAnyStatus""" dev = Device("Tst:Prefix", name="test") st1 = StatusBase() st2 = StatusBase() @@ -834,7 +841,7 @@ def test_or_any_status(): assert or_status.done is True assert or_status.success is False assert isinstance(or_status.exception(), RuntimeError) - assert str(or_status.exception()) == "Exception: Test exception; RuntimeError: Test exception 2; ValueError: Test exception 3" - - - + assert ( + str(or_status.exception()) + == "Exception: Test exception; RuntimeError: Test exception 2; ValueError: Test exception 3" + ) From 61cb5d0a700ede09da6a4f4b7a90f0d39103a077 Mon Sep 17 00:00:00 2001 From: appel_c Date: Fri, 26 Sep 2025 11:58:23 -0500 Subject: [PATCH 3/4] refactor: cleanup, make fixes following PR review --- ophyd/status.py | 183 +++++++++++++++++++++---------------- ophyd/tests/test_status.py | 60 +++++++++--- 2 files changed, 153 insertions(+), 90 deletions(-) diff --git a/ophyd/status.py b/ophyd/status.py index 1e2d9100d..215b0149e 100644 --- a/ophyd/status.py +++ b/ophyd/status.py @@ -1159,14 +1159,22 @@ class CompareStatus(SubscriptionStatus): Parameters ---------- - signal: The device signal to compare. - value: The value to compare against. - raise_exc_value: A value or list of values that will raise an exception if encountered. Defaults to None. - operation: The operation to use for comparison. Defaults to '=='. - event_type: The type of event to trigger on comparison. Defaults to None (default sub). - timeout: The timeout for the status. Defaults to None (indefinite). - settle_time: The time to wait for the signal to settle before comparison. Defaults to 0. - run: Whether to run the status callback on creation or not. Defaults to True. + signal: Signal + The device signal to compare. + value: float | int | str + The value to compare against. + failure_value: float | int | str | list[float | int | str] | None, optional + A value or list of values that will raise an exception if encountered. Defaults to None. + operation_success: Literal["==", "!=", "<", "<=", ">", ">="], optional + The operation_success to use for comparison. Defaults to '=='. + event_type: Optional[Type[Event]] + The type of event to trigger on comparison. Defaults to None (default sub). + timeout: float | None, optional + The timeout for the status. Defaults to None (indefinite). + settle_time: float, optional + The time to wait for the signal to settle before comparison. Defaults to 0. + run: bool, optional + Whether to run the status callback on creation or not. Defaults to True. """ OP_MAP = { @@ -1183,34 +1191,39 @@ def __init__( signal: Signal, value: float | int | str, *, - operation: Literal["==", "!=", "<", "<=", ">", ">="] = "==", - raise_exc_value: float | int | str | list[float | int | str] | None = None, + operation_success: Literal["==", "!=", "<", "<=", ">", ">="] = "==", + failure_value: float | int | str | list[float | int | str] | None = None, + operation_failure: Literal["==", "!=", "<", "<=", ">", ">="] = "==", event_type=None, timeout: float = None, settle_time: float = 0, run: bool = True, ): if isinstance(value, str): - if operation not in ("==", "!="): + if operation_success not in ("==", "!=") and operation_failure not in ( + "==", + "!=", + ): raise ValueError( - f"Invalid operation: {operation} for string comparison. Must be '==' or '!='." + f"Invalid operation_success: {operation_success} for string comparison. Must be '==' or '!='." ) - if operation not in ("==", "!=", "<", "<=", ">", ">="): + if operation_success not in ("==", "!=", "<", "<=", ">", ">="): raise ValueError( - f"Invalid operation: {operation}. Must be one of '==', '!=', '<', '<=', '>', '>='." + f"Invalid operation_success: {operation_success}. Must be one of '==', '!=', '<', '<=', '>', '>='." ) self._signal = signal self._value = value - self._operation = operation - if raise_exc_value is None: - self._raise_exc_values = [] - elif isinstance(raise_exc_value, (float, int, str)): - self._raise_exc_values = [raise_exc_value] - elif isinstance(raise_exc_value, list): - self._raise_exc_values = raise_exc_value + self._operation_success = operation_success + self._operation_failure = operation_failure + if failure_value is None: + self._failure_values = [] + elif isinstance(failure_value, (float, int, str)): + self._failure_values = [failure_value] + elif isinstance(failure_value, list): + self._failure_values = failure_value else: raise ValueError( - f"raise_exc_value must be a float, int, str, list or None. Received: {raise_exc_value}" + f"failure_value must be a float, int, str, list or None. Received: {failure_value}" ) super().__init__( device=signal, @@ -1223,36 +1236,75 @@ def __init__( def _compare_callback(self, value, **kwargs) -> bool: """Callback for subscription status""" - if value in self._raise_exc_values: - self.set_exception( - ValueError( - f"CompareStatus for signal {self._signal.name} " - f"did not reach the desired state {self._operation} {self._value}. " - f"But instead reached {value} in {self._raise_exc_values}, which is set to raise an exception." + try: + if isinstance(value, list): + # List values are not supported + self.set_exception( + ValueError( + f"List values are not supported. Received value: {value}" + ) ) - ) + return False + if any( + self.OP_MAP[self._operation_failure](value, failure_value) + for failure_value in self._failure_values + ): + self.set_exception( + ValueError( + f"CompareStatus for signal {self._signal.name} " + f"did not reach the desired state {self._operation_success} {self._value}. " + f"But instead reached {value}, which is in list of failure values: {self._failure_values}" + ) + ) + return False + return self.OP_MAP[self._operation_success](value, self._value) + except Exception as e: + # Catch any exception if the value comparison fails + # This can be the case if value is None or of an unexpected type + # For example a numpy array + self.log.error(f"Error in CompareStatus callback: {e}") + self.set_exception(e) return False - return self.OP_MAP[self._operation](value, self._value) class TransitionStatus(SubscriptionStatus): """ Status class to monitor transitions of a signal value through a list of specified transitions. The status is finished when all transitions have been observed in order. The keyword argument - `strict` determines whether the transitions must occur in strict order or not. - If `raise_states` is provided, the status will raise an exception if the signal value matches - any of the values in `raise_states`. + `strict` determines whether the transitions must occur in strict order or not. The strict option + only becomes relevant once the first transition has been observed. + If `failure_states` is provided, the status will raise an exception if the signal value matches + any of the values in `failure_states`. Parameters ---------- - signal: The device signal to monitor. - transitions: A list of values to transition through. - strict: Whether the transitions must occur in strict order. Defaults to True. - raise_states: A list of values that will raise an exception if encountered. Defaults to None. - run: Whether to run the status callback on creation or not. Defaults to True. - event_type: The type of event to trigger on transition. Defaults to None (default sub). - timeout: The timeout for the status. Defaults to None (indefinite). - settle_time: The time to wait for the signal to settle before comparison. Defaults to 0. + signal: Signal + The device signal to monitor. + transitions: list + A list of values to transition through. + strict: bool, optional + Whether the transitions must occur in strict order. Defaults to True. + failure_states: list, optional + A list of values that will raise an exception if encountered. Defaults to None. + run: bool, optional + Whether to run the status callback on creation or not. Defaults to True. + event_type: optional + The type of event to trigger on transition. Defaults to None (default sub). + timeout: float | None, optional + The timeout for the status. Defaults to None (indefinite). + settle_time: float, optional + The time to wait for the signal to settle before checking transitions. Defaults to 0. + Notes + ----- + The 'strict' option does not raise if transitions are observed which are out of order. + It only determines whether a transition is accepted if it is observed from the + previous value in the list of transitions to the next value. + For example, with strict=True and transitions=[1, 2, 3], the sequence + 0 -> 1 -> 2 -> 3 is accepted, but 0 -> 2 -> 1 -> 3 is not and the status will not complete. + With strict=False, both sequences are accepted. + However, with strict=True, the sequence 0 -> 1 -> 3 -> 1 -> 2 -> 3 is accepted. + To raise an exception if an out-of-order transition is observed, use the + `failure_states` keyword argument. """ def __init__( @@ -1261,21 +1313,17 @@ def __init__( transitions: list[float | int | str], *, strict: bool = True, - raise_states: list[float | int | str] | None = None, + failure_states: list[float | int | str] | None = None, run: bool = True, event_type=None, timeout: float = None, settle_time: float = 0, ): self._signal = signal - if not isinstance(transitions, list): - raise ValueError( - f"Transitions must be a list of values. Received: {transitions}" - ) - self._transitions = transitions + self._transitions = tuple(transitions) self._index = 0 self._strict = strict - self._raise_states = raise_states if raise_states else [] + self._failure_states = failure_states if failure_states else [] super().__init__( device=signal, callback=self._compare_callback, @@ -1287,11 +1335,11 @@ def __init__( def _compare_callback(self, old_value, value, **kwargs) -> bool: """Callback for subscription Status""" - if value in self._raise_states: + if value in self._failure_states: self.set_exception( ValueError( f"Transition Status for {self._signal.name} resulted in a value: {value}. " - f"marked to raise {self._raise_states}. Expected transitions: {self._transitions}." + f"marked to raise {self._failure_states}. Expected transitions: {self._transitions}." ) ) return False @@ -1308,10 +1356,6 @@ def _compare_callback(self, old_value, value, **kwargs) -> bool: else: if value == self._transitions[self._index]: self._index += 1 - return self._is_finished() - - def _is_finished(self) -> bool: - """Check if the status is finished""" return self._index >= len(self._transitions) @@ -1319,18 +1363,16 @@ class AndAllStatus(DeviceStatus): """ A status that combines mutiple status objects in a list using logical and. The status is finished when all status objects in the list are finished. - If any status object fails, the combined status will also fail and - set the exception from the first failed status on all sub-statuses. Parameters ---------- device: Device - status_list: A list of StatusBase or DeviceStatus objects to combine. + The parent device for this status + status_list: list[StatusBase] + A list of StatusBase objects to combine. """ - def __init__( - self, device: Device, status_list: list[StatusBase | DeviceStatus], **kwargs - ): + def __init__(self, device: Device, status_list: list[StatusBase], **kwargs): self.status_list = status_list super().__init__(device=device, **kwargs) self._trace_attributes["all"] = [ @@ -1346,11 +1388,10 @@ def inner(status): if self.done: return - for st in self.status_list: - with st._lock: - if st.done and not st.success: - self.set_exception(st.exception()) # st._exception - return + with status._lock: + if status.done and not status.success: + self.set_exception(status.exception()) # st._exception + return if all(st.done for st in self.status_list) and all( st.success for st in self.status_list @@ -1361,12 +1402,6 @@ def inner(status): with st._lock: st.add_callback(inner) - def set_exception(self, exc): - with self._lock: - if self._externally_initiated_completion or self.done: - return - super().set_exception(exc) - def __repr__(self): status_reprs = ", ".join(repr(s) for s in self.status_list) return f"{self.__class__.__name__}([{status_reprs}])" @@ -1427,12 +1462,6 @@ def inner(status): with st._lock: st.add_callback(inner) - def set_exception(self, exc): - with self._lock: - if self._externally_initiated_completion or self.done: - return - super().set_exception(exc) - def __repr__(self): status_reprs = ", ".join(repr(s) for s in self.status_list) return f"{self.__class__.__name__}([{status_reprs}])" diff --git a/ophyd/tests/test_status.py b/ophyd/tests/test_status.py index 71a1eab9a..7e97b0cb9 100644 --- a/ophyd/tests/test_status.py +++ b/ophyd/tests/test_status.py @@ -614,16 +614,17 @@ def _handle_failure(self): def test_compare_status_number(): """Test CompareStatus with different operations.""" sig = Signal(name="test_signal", value=0) - status = CompareStatus(signal=sig, value=5, operation="==") + status = CompareStatus(signal=sig, value=5, operation_success="==") assert status.done is False sig.put(1) assert status.done is False sig.put(5) + status.wait(timeout=5) assert status.done is True sig.put(5) # Test with different operations - status = CompareStatus(signal=sig, value=5, operation="!=") + status = CompareStatus(signal=sig, value=5, operation_success="!=") assert status.done is False sig.put(5) assert status.done is False @@ -633,7 +634,7 @@ def test_compare_status_number(): assert status.exception() is None sig.put(0) - status = CompareStatus(signal=sig, value=5, operation=">") + status = CompareStatus(signal=sig, value=5, operation_success=">") assert status.done is False sig.put(5) assert status.done is False @@ -642,11 +643,50 @@ def test_compare_status_number(): assert status.success is True assert status.exception() is None + # Should raise + sig.put(0) + status = CompareStatus( + signal=sig, value=5, operation_success="==", failure_value=[10] + ) + with pytest.raises(ValueError): + sig.put(10) + status.wait() + assert status.done is True + assert status.success is False + assert isinstance(status.exception(), ValueError) + + # failure_operation + sig.put(0) + status = CompareStatus( + signal=sig, + value=5, + operation_success="==", + failure_value=10, + operation_failure=">", + ) + sig.put(10) + assert status.done is False + assert status.success is False + sig.put(11) + with pytest.raises(ValueError): + status.wait() + assert status.done is True + assert status.success is False + + # raise if array is returned + sig.put(0) + status = CompareStatus(signal=sig, value=5, operation_success="==") + with pytest.raises(ValueError): + sig.put([1, 2, 3]) + status.wait(timeout=2) + assert status.done is True + assert status.success is False + def test_compare_status_string(): """Test CompareStatus with string values""" sig = Signal(name="test_signal", value="test") - status = CompareStatus(signal=sig, value="test", operation="==") + status = CompareStatus(signal=sig, value="test", operation_success="==") assert status.done is False sig.put("test1") assert status.done is False @@ -655,7 +695,7 @@ def test_compare_status_string(): sig.put("test") # Test with different operations - status = CompareStatus(signal=sig, value="test", operation="!=") + status = CompareStatus(signal=sig, value="test", operation_success="!=") assert status.done is False sig.put("test") assert status.done is False @@ -664,12 +704,6 @@ def test_compare_status_string(): assert status.success is True assert status.exception() is None - # Test with greater than operation - # Raises ValueError for strings - sig.put("a") - with pytest.raises(ValueError): - status = CompareStatus(signal=sig, value="b", operation=">") - def test_transition_status(): """Test TransitionStatus""" @@ -689,10 +723,10 @@ def test_transition_status(): assert status.success is True assert status.exception() is None - # Test strict=True, raise_states + # Test strict=True, failure_states sig.put(1) status = TransitionStatus( - signal=sig, transitions=[1, 2, 3], strict=True, raise_states=[4] + signal=sig, transitions=[1, 2, 3], strict=True, failure_states=[4] ) assert status.done is False sig.put(4) From 4348dcd9f2a64e69f55f70f7127e75c91cd29172 Mon Sep 17 00:00:00 2001 From: appel_c Date: Fri, 26 Sep 2025 15:34:20 -0500 Subject: [PATCH 4/4] refactor: remove AndAllStatus and OrAnyStatus --- ophyd/status.py | 120 +------------------------------------ ophyd/tests/test_status.py | 76 +---------------------- 2 files changed, 3 insertions(+), 193 deletions(-) diff --git a/ophyd/status.py b/ophyd/status.py index 215b0149e..e9664f210 100644 --- a/ophyd/status.py +++ b/ophyd/status.py @@ -23,7 +23,6 @@ ) if TYPE_CHECKING: - from ophyd.device import Device from ophyd.signal import Signal @@ -1300,8 +1299,8 @@ class TransitionStatus(SubscriptionStatus): It only determines whether a transition is accepted if it is observed from the previous value in the list of transitions to the next value. For example, with strict=True and transitions=[1, 2, 3], the sequence - 0 -> 1 -> 2 -> 3 is accepted, but 0 -> 2 -> 1 -> 3 is not and the status will not complete. - With strict=False, both sequences are accepted. + 0 -> 1 -> 2 -> 3 is accepted, but 0 -> 1 -> 3 -> 2 -> 3 is not and the status + will not complete. With strict=False, both sequences are accepted. However, with strict=True, the sequence 0 -> 1 -> 3 -> 1 -> 2 -> 3 is accepted. To raise an exception if an out-of-order transition is observed, use the `failure_states` keyword argument. @@ -1357,118 +1356,3 @@ def _compare_callback(self, old_value, value, **kwargs) -> bool: if value == self._transitions[self._index]: self._index += 1 return self._index >= len(self._transitions) - - -class AndAllStatus(DeviceStatus): - """ - A status that combines mutiple status objects in a list using logical and. - The status is finished when all status objects in the list are finished. - - Parameters - ---------- - device: Device - The parent device for this status - status_list: list[StatusBase] - A list of StatusBase objects to combine. - """ - - def __init__(self, device: Device, status_list: list[StatusBase], **kwargs): - self.status_list = status_list - super().__init__(device=device, **kwargs) - self._trace_attributes["all"] = [ - st._trace_attributes for st in self.status_list - ] - - def inner(status): - with self._lock: - if self._externally_initiated_completion: - return - - # Return if status is already done.. - if self.done: - return - - with status._lock: - if status.done and not status.success: - self.set_exception(status.exception()) # st._exception - return - - if all(st.done for st in self.status_list) and all( - st.success for st in self.status_list - ): - self.set_finished() - - for st in self.status_list: - with st._lock: - st.add_callback(inner) - - def __repr__(self): - status_reprs = ", ".join(repr(s) for s in self.status_list) - return f"{self.__class__.__name__}([{status_reprs}])" - - def __str__(self): - status_strs = ", ".join(str(s) for s in self.status_list) - return f"AndAllStatus with {len(self.status_list)} statuses: [{status_strs}]" - - def __contains__(self, item): - return item in self.status_list - - -class OrAnyStatus(DeviceStatus): - """ - A status that combines multiple status objects in a list using logical OR. - The status is finished when any status object in the list finishes successfully. - If all status objects finish and none succeed, the combined status will fail - with the exception of the first failure. - - Parameters - ---------- - device: Device - status_list: A list of StatusBase or DeviceStatus objects to combine. - """ - - def __init__( - self, device: Device, status_list: list[StatusBase | DeviceStatus], **kwargs - ): - self.status_list = status_list - super().__init__(device=device, **kwargs) - self._trace_attributes["all"] = [ - st._trace_attributes for st in self.status_list - ] - - def inner(status): - with self._lock: - if self._externally_initiated_completion: - return - if self.done: - return - - if status.done and status.success: - self.set_finished() - return - - if all(st.done for st in self.status_list): - exceptions = [ - st.exception() - for st in self.status_list - if st.done and not st.success and st.exception() is not None - ] - combined_exceptions = RuntimeError( - "; ".join(f"{type(exc).__name__}: {exc}" for exc in exceptions) - ) - self.set_exception(combined_exceptions) - - for st in self.status_list: - with st._lock: - st.add_callback(inner) - - def __repr__(self): - status_reprs = ", ".join(repr(s) for s in self.status_list) - return f"{self.__class__.__name__}([{status_reprs}])" - - def __str__(self): - status_strs = ", ".join(str(s) for s in self.status_list) - return f"OrAnyStatus with {len(self.status_list)} statuses: [{status_strs}]" - - def __contains__(self, item): - return item in self.status_list diff --git a/ophyd/tests/test_status.py b/ophyd/tests/test_status.py index 7e97b0cb9..9de7fcf51 100644 --- a/ophyd/tests/test_status.py +++ b/ophyd/tests/test_status.py @@ -6,9 +6,9 @@ from ophyd import Device from ophyd.signal import EpicsSignalRO, Signal from ophyd.status import ( + CompareStatus, DeviceStatus, MoveStatus, - OrAnyStatus, StableSubscriptionStatus, StatusBase, SubscriptionStatus, @@ -805,77 +805,3 @@ def test_transition_status_strings(): sig.put("d") # last transition assert status.done is True assert status.success is True - - -def test_and_all_status(): - """Test AndAllStatus""" - dev = Device("Tst:Prefix", name="test") - st1 = StatusBase() - st2 = StatusBase() - st3 = DeviceStatus(dev) - and_status = AndAllStatus(dev, [st1, st2, st3]) - - # Finish in success - assert and_status.done is False - st1.set_finished() - assert and_status.done is False - st2.set_finished() - assert and_status.done is False - st3.set_finished() - assert and_status.done is True - assert and_status.success is True - - # Failure - st1 = StatusBase() - st2 = StatusBase() - st3 = DeviceStatus(dev) - and_status = AndAllStatus(dev, [st1, st2, st3]) - - assert and_status.done is False - st1.set_finished() - assert and_status.done is False - st2.set_exception(Exception("Test exception")) - assert and_status.done is True - assert and_status.success is False - assert st2.success is False - assert st3.success is False - - # Not resolved before failure - assert st3.done is False - - # Already resolved before failure - assert st1.success is True - - -def test_or_any_status(): - """Test OrAnyStatus""" - dev = Device("Tst:Prefix", name="test") - st1 = StatusBase() - st2 = StatusBase() - st3 = DeviceStatus(dev) - or_status = OrAnyStatus(dev, [st1, st2, st3]) - - # Finish in success - assert or_status.done is False - st1.set_finished() - assert or_status.done is True - assert or_status.success is True - - st1 = StatusBase() - or_status = OrAnyStatus(dev, [st1, st2, st3]) - assert or_status.done is False - assert or_status.success is False - st1.set_exception(Exception("Test exception")) - assert or_status.done is False - assert or_status.success is False - st2.set_exception(RuntimeError("Test exception 2")) - assert or_status.done is False - assert or_status.success is False - st3.set_exception(ValueError("Test exception 3")) - assert or_status.done is True - assert or_status.success is False - assert isinstance(or_status.exception(), RuntimeError) - assert ( - str(or_status.exception()) - == "Exception: Test exception; RuntimeError: Test exception 2; ValueError: Test exception 3" - )