diff --git a/ophyd/status.py b/ophyd/status.py index 2e7c65fa5..e9664f210 100644 --- a/ophyd/status.py +++ b/ophyd/status.py @@ -1,9 +1,13 @@ +from __future__ import annotations + import json +import operator import threading import time from collections import deque from functools import partial from logging import LoggerAdapter +from typing import TYPE_CHECKING, Literal from warnings import warn import numpy as np @@ -18,6 +22,10 @@ adapt_old_callback_signature, ) +if TYPE_CHECKING: + from ophyd.signal import Signal + + tracer = trace.get_tracer(__name__) _TRACE_PREFIX = "Ophyd Status" @@ -1138,3 +1146,213 @@ def wait(status, timeout=None, *, poll_rate="DEPRECATED"): from ``WaitTimeoutError`` above. """ return status.wait(timeout) + + +class CompareStatus(SubscriptionStatus): + """ + Status class to compare a signal value against a given value. + The comparison is done using the specified operation, which can be one of + '==', '!=', '<', '<=', '>', '>='. If the value is a string, only '==' and '!=' are allowed. + One may also define a value or list of values that will result in an exception if encountered. + The status is finished when the comparison is either true or an exception is raised. + + Parameters + ---------- + signal: Signal + The device signal to compare. + value: float | int | str + The value to compare against. + failure_value: float | int | str | list[float | int | str] | None, optional + A value or list of values that will raise an exception if encountered. Defaults to None. + operation_success: Literal["==", "!=", "<", "<=", ">", ">="], optional + The operation_success to use for comparison. Defaults to '=='. + event_type: Optional[Type[Event]] + The type of event to trigger on comparison. Defaults to None (default sub). + timeout: float | None, optional + The timeout for the status. Defaults to None (indefinite). + settle_time: float, optional + The time to wait for the signal to settle before comparison. Defaults to 0. + run: bool, optional + Whether to run the status callback on creation or not. Defaults to True. + """ + + OP_MAP = { + "==": operator.eq, + "!=": operator.ne, + "<": operator.lt, + "<=": operator.le, + ">": operator.gt, + ">=": operator.ge, + } + + def __init__( + self, + signal: Signal, + value: float | int | str, + *, + operation_success: Literal["==", "!=", "<", "<=", ">", ">="] = "==", + failure_value: float | int | str | list[float | int | str] | None = None, + operation_failure: Literal["==", "!=", "<", "<=", ">", ">="] = "==", + event_type=None, + timeout: float = None, + settle_time: float = 0, + run: bool = True, + ): + if isinstance(value, str): + if operation_success not in ("==", "!=") and operation_failure not in ( + "==", + "!=", + ): + raise ValueError( + f"Invalid operation_success: {operation_success} for string comparison. Must be '==' or '!='." + ) + if operation_success not in ("==", "!=", "<", "<=", ">", ">="): + raise ValueError( + f"Invalid operation_success: {operation_success}. Must be one of '==', '!=', '<', '<=', '>', '>='." + ) + self._signal = signal + self._value = value + self._operation_success = operation_success + self._operation_failure = operation_failure + if failure_value is None: + self._failure_values = [] + elif isinstance(failure_value, (float, int, str)): + self._failure_values = [failure_value] + elif isinstance(failure_value, list): + self._failure_values = failure_value + else: + raise ValueError( + f"failure_value must be a float, int, str, list or None. Received: {failure_value}" + ) + super().__init__( + device=signal, + callback=self._compare_callback, + timeout=timeout, + settle_time=settle_time, + event_type=event_type, + run=run, + ) + + def _compare_callback(self, value, **kwargs) -> bool: + """Callback for subscription status""" + try: + if isinstance(value, list): + # List values are not supported + self.set_exception( + ValueError( + f"List values are not supported. Received value: {value}" + ) + ) + return False + if any( + self.OP_MAP[self._operation_failure](value, failure_value) + for failure_value in self._failure_values + ): + self.set_exception( + ValueError( + f"CompareStatus for signal {self._signal.name} " + f"did not reach the desired state {self._operation_success} {self._value}. " + f"But instead reached {value}, which is in list of failure values: {self._failure_values}" + ) + ) + return False + return self.OP_MAP[self._operation_success](value, self._value) + except Exception as e: + # Catch any exception if the value comparison fails + # This can be the case if value is None or of an unexpected type + # For example a numpy array + self.log.error(f"Error in CompareStatus callback: {e}") + self.set_exception(e) + return False + + +class TransitionStatus(SubscriptionStatus): + """ + Status class to monitor transitions of a signal value through a list of specified transitions. + The status is finished when all transitions have been observed in order. The keyword argument + `strict` determines whether the transitions must occur in strict order or not. The strict option + only becomes relevant once the first transition has been observed. + If `failure_states` is provided, the status will raise an exception if the signal value matches + any of the values in `failure_states`. + + Parameters + ---------- + signal: Signal + The device signal to monitor. + transitions: list + A list of values to transition through. + strict: bool, optional + Whether the transitions must occur in strict order. Defaults to True. + failure_states: list, optional + A list of values that will raise an exception if encountered. Defaults to None. + run: bool, optional + Whether to run the status callback on creation or not. Defaults to True. + event_type: optional + The type of event to trigger on transition. Defaults to None (default sub). + timeout: float | None, optional + The timeout for the status. Defaults to None (indefinite). + settle_time: float, optional + The time to wait for the signal to settle before checking transitions. Defaults to 0. + Notes + ----- + The 'strict' option does not raise if transitions are observed which are out of order. + It only determines whether a transition is accepted if it is observed from the + previous value in the list of transitions to the next value. + For example, with strict=True and transitions=[1, 2, 3], the sequence + 0 -> 1 -> 2 -> 3 is accepted, but 0 -> 1 -> 3 -> 2 -> 3 is not and the status + will not complete. With strict=False, both sequences are accepted. + However, with strict=True, the sequence 0 -> 1 -> 3 -> 1 -> 2 -> 3 is accepted. + To raise an exception if an out-of-order transition is observed, use the + `failure_states` keyword argument. + """ + + def __init__( + self, + signal: Signal, + transitions: list[float | int | str], + *, + strict: bool = True, + failure_states: list[float | int | str] | None = None, + run: bool = True, + event_type=None, + timeout: float = None, + settle_time: float = 0, + ): + self._signal = signal + self._transitions = tuple(transitions) + self._index = 0 + self._strict = strict + self._failure_states = failure_states if failure_states else [] + super().__init__( + device=signal, + callback=self._compare_callback, + timeout=timeout, + settle_time=settle_time, + event_type=event_type, + run=run, + ) + + def _compare_callback(self, old_value, value, **kwargs) -> bool: + """Callback for subscription Status""" + if value in self._failure_states: + self.set_exception( + ValueError( + f"Transition Status for {self._signal.name} resulted in a value: {value}. " + f"marked to raise {self._failure_states}. Expected transitions: {self._transitions}." + ) + ) + return False + if self._index == 0: + if value == self._transitions[0]: + self._index += 1 + else: + if self._strict: + if ( + old_value == self._transitions[self._index - 1] + and value == self._transitions[self._index] + ): + self._index += 1 + else: + if value == self._transitions[self._index]: + self._index += 1 + return self._index >= len(self._transitions) diff --git a/ophyd/tests/test_status.py b/ophyd/tests/test_status.py index ec0ba54d3..9de7fcf51 100644 --- a/ophyd/tests/test_status.py +++ b/ophyd/tests/test_status.py @@ -4,13 +4,15 @@ import pytest from ophyd import Device -from ophyd.signal import EpicsSignalRO +from ophyd.signal import EpicsSignalRO, Signal from ophyd.status import ( + CompareStatus, DeviceStatus, MoveStatus, StableSubscriptionStatus, StatusBase, SubscriptionStatus, + TransitionStatus, UseNewProperty, ) from ophyd.utils import ( @@ -607,3 +609,199 @@ def _handle_failure(self): st.wait(1) time.sleep(0.1) # Wait for callbacks to run. assert state + + +def test_compare_status_number(): + """Test CompareStatus with different operations.""" + sig = Signal(name="test_signal", value=0) + status = CompareStatus(signal=sig, value=5, operation_success="==") + assert status.done is False + sig.put(1) + assert status.done is False + sig.put(5) + status.wait(timeout=5) + assert status.done is True + + sig.put(5) + # Test with different operations + status = CompareStatus(signal=sig, value=5, operation_success="!=") + assert status.done is False + sig.put(5) + assert status.done is False + sig.put(6) + assert status.done is True + assert status.success is True + assert status.exception() is None + + sig.put(0) + status = CompareStatus(signal=sig, value=5, operation_success=">") + assert status.done is False + sig.put(5) + assert status.done is False + sig.put(10) + assert status.done is True + assert status.success is True + assert status.exception() is None + + # Should raise + sig.put(0) + status = CompareStatus( + signal=sig, value=5, operation_success="==", failure_value=[10] + ) + with pytest.raises(ValueError): + sig.put(10) + status.wait() + assert status.done is True + assert status.success is False + assert isinstance(status.exception(), ValueError) + + # failure_operation + sig.put(0) + status = CompareStatus( + signal=sig, + value=5, + operation_success="==", + failure_value=10, + operation_failure=">", + ) + sig.put(10) + assert status.done is False + assert status.success is False + sig.put(11) + with pytest.raises(ValueError): + status.wait() + assert status.done is True + assert status.success is False + + # raise if array is returned + sig.put(0) + status = CompareStatus(signal=sig, value=5, operation_success="==") + with pytest.raises(ValueError): + sig.put([1, 2, 3]) + status.wait(timeout=2) + assert status.done is True + assert status.success is False + + +def test_compare_status_string(): + """Test CompareStatus with string values""" + sig = Signal(name="test_signal", value="test") + status = CompareStatus(signal=sig, value="test", operation_success="==") + assert status.done is False + sig.put("test1") + assert status.done is False + sig.put("test") + assert status.done is True + + sig.put("test") + # Test with different operations + status = CompareStatus(signal=sig, value="test", operation_success="!=") + assert status.done is False + sig.put("test") + assert status.done is False + sig.put("test1") + assert status.done is True + assert status.success is True + assert status.exception() is None + + +def test_transition_status(): + """Test TransitionStatus""" + sig = Signal(name="test_signal", value=0) + + # Test strict=True, without intermediate transitions + sig.put(0) + status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=True) + + assert status.done is False + sig.put(1) + assert status.done is False + sig.put(2) + assert status.done is False + sig.put(3) + assert status.done is True + assert status.success is True + assert status.exception() is None + + # Test strict=True, failure_states + sig.put(1) + status = TransitionStatus( + signal=sig, transitions=[1, 2, 3], strict=True, failure_states=[4] + ) + assert status.done is False + sig.put(4) + with pytest.raises(ValueError): + status.wait() + + assert status.done is True + assert status.success is False + assert isinstance(status.exception(), ValueError) + + # Test strict=False, with intermediate transitions + sig.put(0) + status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=False) + + assert status.done is False + sig.put(1) # entering first transition + sig.put(3) + sig.put(2) # transision + assert status.done is False + sig.put(4) + sig.put(2) + sig.put(3) # last transition + assert status.done is True + assert status.success is True + assert status.exception() is None + + +def test_transition_status_strings(): + """Test TransitionStatus with string values""" + sig = Signal(name="test_signal", value="a") + + # Test strict=True, without intermediate transitions + sig.put("a") + status = TransitionStatus(signal=sig, transitions=["b", "c", "d"], strict=True) + + assert status.done is False + sig.put("b") + assert status.done is False + sig.put("c") + assert status.done is False + sig.put("d") + assert status.done is True + assert status.success is True + assert status.exception() is None + + # Test strict=True with additional intermediate transition + + sig.put("a") + status = TransitionStatus(signal=sig, transitions=["b", "c", "d"], strict=True) + + assert status.done is False + sig.put("b") # first transition + sig.put("e") + sig.put("b") + sig.put("c") # transision + assert status.done is False + sig.put("f") + sig.put("b") + sig.put("c") + sig.put("d") # transision + assert status.done is True + assert status.success is True + assert status.exception() is None + + # Test strict=False, with intermediate transitions + sig.put("a") + status = TransitionStatus(signal=sig, transitions=["b", "c", "d"], strict=False) + + assert status.done is False + sig.put("b") # entering first transition + sig.put("d") + sig.put("c") # transision + assert status.done is False + sig.put("e") + sig.put("c") + sig.put("d") # last transition + assert status.done is True + assert status.success is True