diff --git a/tiler/merger.py b/tiler/merger.py index c8e24ea..01f8677 100644 --- a/tiler/merger.py +++ b/tiler/merger.py @@ -58,65 +58,85 @@ class Merger: def __init__( self, tiler: Tiler, - window: Union[None, str, np.ndarray] = None, - logits: int = 0, - save_visits: bool = True, + window: Union[str, np.ndarray] = "boxcar", + ignore_channels: bool = False, + logits_n: Optional[int] = None, + logits_dim: int = 0, + visits_buffer: bool = True, data_dtype: npt.DTypeLike = np.float32, weights_dtype: npt.DTypeLike = np.float32, + visits_dtype: npt.DTypeLike = np.uint32, ): """Merger holds cumulative result buffers for merging tiles created by a given Tiler - and the window function that is applied to added tiles. + and the window function that is applied to the added tiles. - There are two required np.float64 buffers: `self.data` and `self.weights_sum` - and one optional np.uint32 `self.data_visits` (see below `save_visits` argument). - - TODO: - - generate window depending on tile border type - # some reference for the future borders generation - # 1d = 3 types of tiles: 2 corners and middle - # 2d = 9 types of tiles: 4 corners, 4 tiles with 1 edge and middle - # 3d = 25 types of tiles: 8 corners, 12 tiles with 2 edges, 6 tiles with one edge and middle - # corners: 2^ndim - # tiles: 2*ndim*nedges + There are two required buffers: `self.data` and `self.weights` + and one optional `self.visits` (enabled by the keyword `visits_buffer`). Args: tiler (Tiler): Tiler with which the tiles were originally created. - window (None, str or np.ndarray): Specifies which window function to use for tile merging. + window (str or np.ndarray): Specifies which window function to use for tile merging. Must be one of `Merger.SUPPORTED_WINDOWS` or a numpy array with the same size as the tile. - Default is None which creates a boxcar window (constant 1s). + Default is "boxcar" which creates a boxcar window (constant 1s). 
+ + ignore_channels (bool): If True, ignores the channel dimension set in Tiler and makes Merger expect tiles + without a channel dimension. Default is `False`. + + logits_n (int, optional): Specifies the number of classes Merger is expected to hold, i.e. the size of + the logits dimension in the buffers and in every added tile. Set `ignore_channels` to also drop Tiler's channel dimension. + Useful for merging multi-class segmentation predictions. Works in combination with `logits_dim`. + Default is `None`. - logits (int): Specify whether to add logits dimensions in front of the data array. Default is `0`. + logits_dim (int): If `logits_n` is set, specifies in which dimension logits should be expected. + Supports negative indexing. Default is `0`. - save_visits (bool): Specify whether to save which elements has been modified and how many times in - `self.data_visits`. Can be disabled to save some memory. Default is `True`. + visits_buffer (bool): Specifies whether to enable the visits buffer, which counts how many times + each element has been modified. Can be disabled to save memory. Default is `True`. - data_dtype (np.dtype): Specify data type for data buffer that stores cumulative result. + data_dtype (np.dtype): Data type for data buffer that stores cumulative result. Default is `np.float32`. - weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights and window array. - If you don't need precision but would rather save memory you can use `np.float16`. - Likewise, on the opposite, you can use `np.float64`. + weights_dtype (np.dtype): Data type for window array and weights buffer that stores cumulative weights. + Can be used for precision-memory tradeoff, e.g. `np.float16` saves some memory but is less precise. + Conversely, you can use `np.float64` if you require high floating-point precision. Default is `np.float32`. + visits_dtype (np.dtype): Data type for visits buffer. Used only if `visits_buffer` is True. + Since visits are discrete, uint data types are recommended. 
+ Default is `np.uint32`, "ought to be enough for anybody". """ self.tiler = tiler + self.ignore_channels = ignore_channels # Logits support - if not isinstance(logits, int) or logits < 0: - raise ValueError( - f"Logits must be an integer 0 or a positive number ({logits})." - ) - self.logits = int(logits) - - # Generate data and normalization arrays - self.data = self.data_visits = self.weights_sum = None + if logits_n is not None: + if not isinstance(logits_n, int) or logits_n <= 0: + raise ValueError(f"Number of logits must be a positive integer") + if not isinstance(logits_dim, int): + raise ValueError(f"Logits dimension must be an integer") + + # support negative indexing for logits dimensions + n_dim = self.tiler._n_dim + if logits_dim >= n_dim or logits_dim < -n_dim: + raise ValueError(f"Logits dimension must be from {-n_dim} to {n_dim-1}") + if logits_dim < 0: + logits_dim = n_dim + logits_dim + + self.logits_n = logits_n + self.logits_dim = logits_dim + + # Data, visits and weights buffers + self.data = self.weights = self.visits = None self.data_dtype = data_dtype self.weights_dtype = weights_dtype - self.reset(save_visits) + self.visits_dtype = visits_dtype + self.visits_buffer = visits_buffer + self._expected_tile_shape = None + self.reset() - # Generate window function + # Window function self.window = None self.set_window(window) @@ -136,16 +156,12 @@ def _generate_window(self, window: str, shape: Union[Tuple, List]) -> np.ndarray w = np.ones(shape, dtype=self.weights_dtype) overlap = self.tiler._tile_overlap for axis, length in enumerate(shape): - if axis == self.tiler.channel_dimension: - # channel dimension should have weight of 1 everywhere - win = get_window("boxcar", length) + if window == "overlap-tile": + axis_overlap = overlap[axis] // 2 + win = np.zeros(length) + win[axis_overlap:-axis_overlap] = 1 else: - if window == "overlap-tile": - axis_overlap = overlap[axis] // 2 - win = np.zeros(length) - win[axis_overlap:-axis_overlap] = 1 - else: - 
win = get_window(window, length) + win = get_window(window, length) for i in range(len(shape)): if i == axis: @@ -157,33 +173,32 @@ def _generate_window(self, window: str, shape: Union[Tuple, List]) -> np.ndarray return w - def set_window(self, window: Union[None, str, np.ndarray] = None) -> None: + def set_window(self, window: Union[str, np.ndarray] = "boxcar") -> None: """Sets window function depending on the given window function. Args: - window (None, str or np.ndarray): Specifies which window function to use for tile merging. + window (str or np.ndarray): Specifies which window function to use for tile merging. Must be one of `Merger.SUPPORTED_WINDOWS` or a numpy array with the same size as the tile. - If passed None sets a boxcar window (constant 1s). + Default is "boxcar" (constant 1s). Returns: None """ - # Warn user that changing window type after some elements were already visited is a bad idea. - if np.count_nonzero(self.data_visits): + # Warn user that changing window after some elements were already visited is a bad idea. + if np.count_nonzero(self.visits): warnings.warn( "You are setting window type after some elements were already added." ) - # Default window is boxcar - if window is None: - window = "boxcar" - # Generate or set a window function if isinstance(window, str): if window not in self.SUPPORTED_WINDOWS: - raise ValueError("Unsupported window, please check docs") - self.window = self._generate_window(window, self.tiler.tile_shape) + raise ValueError(f"Unsupported window {window}, please check docs.") + + self.window = self._generate_window( + window, self.tiler.tile_shape_wo_channel + ) elif isinstance(window, np.ndarray): if not np.array_equal(window.shape, self.tiler.tile_shape): raise ValueError( @@ -195,45 +210,77 @@ def set_window(self, window: Union[None, str, np.ndarray] = None) -> None: f"Unsupported type for window function ({type(window)}), expected str or np.ndarray." 
) - def reset(self, save_visits: bool = True) -> None: - """Reset data, weights and optional data_visits buffers. - - Should be done after finishing merging full tile set and before starting processing the next tile set. - - Args: - save_visits (bool): Specify whether to save which elements has been modified and how many times in - `self.data_visits`. Can be disabled to save some memory. Default is `True`. + def reset(self) -> None: + """Resets data, weights and optional data_visits buffers, and recalculates expected tile shape. + Should be called if you want to reuse the same Merger for another image. Returns: None """ - padded_data_shape = self.tiler._new_shape + # Data buffer holds sum of all processed tiles multiplied by the window + # Weights buffer holds sum of window weights coefficients per element + # Optional visits buffer holds number of times each element was added to Merger + # Also, calculate expected tile shape + data_shape = self.tiler._new_shape + if self.logits_n and self.ignore_channels: + ds_wo_channels = data_shape[ + np.arange(self.tiler._n_dim) != self.tiler.channel_dimension + ] + + ds_wo_channels_w_logits = np.insert( + ds_wo_channels, + self.logits_dim, + self.logits_n, + ) + + self._ds_shape = ds_wo_channels_w_logits + + self.data = np.zeros(self._ds_shape, dtype=self.data_dtype) + self.weights = np.zeros(self._ds_shape, dtype=self.weights_dtype) - # Image holds sum of all processed tiles multiplied by the window - if self.logits: - self.data = np.zeros( - (self.logits, *padded_data_shape), dtype=self.data_dtype + self._expected_tile_shape = np.insert( + self.tiler.tile_shape_wo_channel, + self.logits_dim, + self.logits_n, ) + + if self.visits_buffer: + self.visits = np.zeros(self._ds_shape, dtype=self.weights_dtype) + + elif self.logits_n and not self.ignore_channels: + ds_w_logits = np.insert(data_shape, self.logits_dim, self.logits_n) + self._ds_shape = ds_w_logits + + self.data = np.zeros(self._ds_shape, dtype=self.data_dtype) + 
self.weights = np.zeros(self._ds_shape, dtype=self.weights_dtype) + + self._expected_tile_shape = np.insert( + self.tiler.tile_shape, + self.logits_dim, + self.logits_n, + ) + + if self.visits_buffer: + self.visits = np.zeros(self._ds_shape, dtype=self.weights_dtype) + else: - self.data = np.zeros(padded_data_shape, dtype=self.data_dtype) + self._ds_shape = data_shape + + self.data = np.zeros(self._ds_shape, dtype=self.data_dtype) + self.weights = np.zeros(self._ds_shape, dtype=self.weights_dtype) - # Data visits holds the number of times each element was assigned - if save_visits: - self.data_visits = np.zeros( - padded_data_shape, dtype=np.uint32 - ) # uint32 ought to be enough for anyone :) + self._expected_tile_shape = self.tiler.tile_shape - # Total data window (weight) coefficients - self.weights_sum = np.zeros(padded_data_shape, dtype=self.weights_dtype) + if self.visits_buffer: + self.visits = np.zeros(self._ds_shape, dtype=self.weights_dtype) def add(self, tile_id: int, data: np.ndarray) -> None: - """Adds `tile_id`-th tile into Merger. + """Adds `tile_id`-th tile into Merger buffers. Args: - tile_id (int): Specifies which tile it is. - - data (np.ndarray): Specifies tile data. + tile_id (int): Tile id + data (np.ndarray): Tile data Returns: None @@ -245,46 +292,54 @@ def add(self, tile_id: int, data: np.ndarray) -> None: ) data_shape = np.array(data.shape) - expected_tile_shape = ( - ((self.logits,) + tuple(self.tiler.tile_shape)) - if self.logits > 0 - else tuple(self.tiler.tile_shape) - ) if self.tiler.mode != "irregular": - if not np.all(np.equal(data_shape, expected_tile_shape)): + if not np.all(np.equal(data_shape, self._expected_tile_shape)): raise ValueError( f"Passed data shape ({data_shape}) " - f"does not fit expected tile shape ({expected_tile_shape})." + f"does not fit expected tile shape ({self._expected_tile_shape})." 
) else: - if not np.all(np.less_equal(data_shape, expected_tile_shape)): + if not np.all(np.less_equal(data_shape, self._expected_tile_shape)): raise ValueError( f"Passed data shape ({data_shape}) " - f"must be less or equal than tile shape ({expected_tile_shape})." + f"must be less or equal than tile shape ({self._expected_tile_shape})." ) - # Select coordinates for data - shape_diff = expected_tile_shape - data_shape - a, b = self.tiler.get_tile_bbox(tile_id, with_channel_dim=True) + # Find difference between expected tile shape and provided tile shape + shape_diff = self._expected_tile_shape - data_shape + # Get tile bbox data buffer coordinates + a, b = self.tiler.get_tile_bbox( + tile_id, with_channel_dim=not self.ignore_channels + ) + + # Generate slicing that puts the provided tile into the data buffer sl = [slice(x, y - shape_diff[i]) for i, (x, y) in enumerate(zip(a, b))] + if self.logits_n: + # add whole axis for logits dimension + sl.insert(self.logits_dim, slice(None, None, None)) + + # Generate window slicing win_sl = [ slice(None, -diff) if (diff > 0) else slice(None, None) - for diff in shape_diff + for i, diff in enumerate(shape_diff) + if i != self.tiler.channel_dimension ] - if self.logits > 0: - self.data[tuple([slice(None, None, None)] + sl)] += ( - data * self.window[tuple(win_sl[1:])] - ) - self.weights_sum[tuple(sl)] += self.window[tuple(win_sl[1:])] - else: - self.data[tuple(sl)] += data * self.window[tuple(win_sl)] - self.weights_sum[tuple(sl)] += self.window[tuple(win_sl)] + # expand dimensions for correct broadcasting + if self.logits_n: + win_sl.insert(self.logits_dim, np.newaxis) + if not self.ignore_channels: + win_sl.insert(self.tiler.channel_dimension, np.newaxis) + + # Add to data and weights buffers + self.data[tuple(sl)] += data * self.window[tuple(win_sl)] + self.weights[tuple(sl)] += self.window[tuple(win_sl)] - if self.data_visits is not None: - self.data_visits[tuple(sl)] += 1 + # Add to visits buffer + if self.visits_buffer: 
+ self.visits[tuple(sl)] += 1 def add_batch(self, batch_id: int, batch_size: int, data: np.ndarray) -> None: """Adds `batch_id`-th batch of `batch_size` tiles into Merger. @@ -318,7 +373,9 @@ def add_batch(self, batch_id: int, batch_size: int, data: np.ndarray) -> None: self.add(tile_i, data[data_i]) def _unpad( - self, data: np.ndarray, extra_padding: Optional[List[Tuple[int, int]]] = None + self, + data: np.ndarray, + extra_padding: Optional[List[Tuple[int, int]]] = None, ): """Slices/unpads data according to merger and tiler settings, as well as additional padding. @@ -330,22 +387,14 @@ def _unpad( ((before_1, after_1), … (before_N, after_N)) unique pad widths for each axis. Default is None. """ + if extra_padding: sl = [ slice(pad_from, shape - pad_to) - for shape, (pad_from, pad_to) in zip( - self.tiler.data_shape, extra_padding - ) + for shape, (pad_from, pad_to) in zip(self._ds_shape, extra_padding) ] else: - sl = [ - slice(None, self.tiler.data_shape[i]) - for i in range(len(self.tiler.data_shape)) - ] - - # if merger has logits dimension, add another slicing in front - if self.logits: - sl = [slice(None, None, None)] + sl + sl = [slice(None, i) for i in self._ds_shape] return data[tuple(sl)] @@ -353,7 +402,7 @@ def merge( self, unpad: bool = True, extra_padding: Optional[List[Tuple[int, int]]] = None, - argmax: bool = False, + argmax: Optional[int] = None, normalize_by_weights: bool = True, dtype: Optional[npt.DTypeLike] = None, ) -> np.ndarray: @@ -367,9 +416,9 @@ def merge( ((before_1, after_1), … (before_N, after_N)) unique pad widths for each axis. Default is None. - argmax (bool): If argmax is True, the first dimension will be argmaxed. - Useful when merger is initialized with `logits=True`. - Default is False. + argmax (int, optional): If set, specifies dimension to be argmaxed. + Useful in combination with `Merger.logits_n` and `Merger.logits_dim`. 
+ Default is None normalize_by_weights (bool): If normalize is True, the accumulated data will be divided by weights. Default is True. @@ -392,13 +441,13 @@ def merge( # ignoring should be more precise without atol # but can hide other errors with np.errstate(divide="ignore", invalid="ignore"): - data = np.nan_to_num(data / self.weights_sum) + data = np.nan_to_num(data / self.weights) if unpad: data = self._unpad(data, extra_padding) if argmax: - data = np.argmax(data, 0) + data = np.argmax(data, argmax) if dtype is not None: return data.astype(dtype) diff --git a/tiler/tiler.py b/tiler/tiler.py index c1d1272..1b5f5be 100644 --- a/tiler/tiler.py +++ b/tiler/tiler.py @@ -58,7 +58,7 @@ def __init__( channel_dimension (int, optional): Specifies which axis is the channel dimension that will not be tiled. Usually it is the last or the first dimension of the array. - Negative indexing (`-len(data_shape)` to `-1` inclusive) is allowed. + Negative indexing (i.e., `-len(data_shape)` to `-1` inclusive) is allowed. Default is `None`, no channel dimension in the data. mode (str): Defines how the data will be tiled. 
@@ -147,6 +147,10 @@ def recalculate( # negative indexing self.channel_dimension = self._n_dim + self.channel_dimension + self.tile_shape_wo_channel = self.tile_shape[ + np.arange(self._n_dim) != self.channel_dimension + ] + # Overlap and step if overlap is not None: self.overlap = overlap @@ -163,12 +167,9 @@ def recalculate( self._tile_overlap[self.channel_dimension] = 0 elif isinstance(self.overlap, int): - tile_shape_without_channel = self.tile_shape[ - np.arange(self._n_dim) != self.channel_dimension - ] - if self.overlap < 0 or np.any(self.overlap >= tile_shape_without_channel): + if self.overlap < 0 or np.any(self.overlap >= self.tile_shape_wo_channel): raise ValueError( - f"Integer overlap must be in range of 0 to {np.max(tile_shape_without_channel)}" + f"Integer overlap must be in range of 0 to {np.max(self.tile_shape_wo_channel)}" ) self._tile_overlap: np.ndarray = np.array(