diff --git a/nion/data/Core.py b/nion/data/Core.py index 41fa9c3..7857edc 100755 --- a/nion/data/Core.py +++ b/nion/data/Core.py @@ -1801,30 +1801,48 @@ def calculate_data() -> _ImageDataType: def function_warp(data_and_metadata_in: _DataAndMetadataLike, coordinates_in: typing.Sequence[_DataAndMetadataLike], order: int = 1) -> DataAndMetadata.DataAndMetadata: + """Warp or unwarp input data using an N-dimensional warp map. + + The warp map is applied along N axes and broadcast over any additional + dimensions in the input, allowing a single warp map to be used for + higher-dimensional data (e.g., image sequences). For multichannel data + such as RGB/RGBA, the warp is applied uniformly to all channels. + """ data_and_metadata = DataAndMetadata.promote_ndarray(data_and_metadata_in) coordinates = [DataAndMetadata.promote_ndarray(c) for c in coordinates_in] - coords = numpy.moveaxis(numpy.dstack([coordinate.data for coordinate in coordinates]), -1, 0) + coords = numpy.stack([c.data.astype(float) for c in coordinates], axis=0) data = data_and_metadata._data_ex - if data_and_metadata.is_data_rgb: - rgb: numpy.typing.NDArray[numpy.uint8] = numpy.zeros(tuple(data_and_metadata.dimensional_shape) + (3,), numpy.uint8) - rgb[..., 0] = scipy.ndimage.map_coordinates(data[..., 0], coords, order=order) - rgb[..., 1] = scipy.ndimage.map_coordinates(data[..., 1], coords, order=order) - rgb[..., 2] = scipy.ndimage.map_coordinates(data[..., 2], coords, order=order) - return DataAndMetadata.new_data_and_metadata(data=rgb, - dimensional_calibrations=data_and_metadata.dimensional_calibrations, - intensity_calibration=data_and_metadata.intensity_calibration) - elif data_and_metadata.is_data_rgba: - rgba: numpy.typing.NDArray[numpy.uint8] = numpy.zeros(tuple(data_and_metadata.dimensional_shape) + (4,), numpy.uint8) - rgba[..., 0] = scipy.ndimage.map_coordinates(data[..., 0], coords, order=order) - rgba[..., 1] = scipy.ndimage.map_coordinates(data[..., 1], coords, order=order) - rgba[..., 2] = scipy.ndimage.map_coordinates(data[..., 2], coords, order=order) - rgba[..., 3] = scipy.ndimage.map_coordinates(data[..., 3], coords, order=order) - return DataAndMetadata.new_data_and_metadata(data=rgba, + num_frame_dims = coords.shape[0] + + if data_and_metadata.is_data_rgb_type: + # Last dimension is channels + leading_shape = data.shape[:-num_frame_dims - 1] + output_shape = leading_shape + coords.shape[1:] + channels = 3 if data_and_metadata.is_data_rgb else 4 + output = numpy.zeros(tuple(output_shape) + (channels,), numpy.uint8) + + # scipy map_coordinates does not broadcast by default, so need to loop + for index in numpy.ndindex(leading_shape): + for chan in range(channels): + output[index + (..., chan)] = scipy.ndimage.map_coordinates( + data[index + (..., chan)], + coords, + order=order) + + return DataAndMetadata.new_data_and_metadata(data=output, dimensional_calibrations=data_and_metadata.dimensional_calibrations, intensity_calibration=data_and_metadata.intensity_calibration) else: + leading_shape = data.shape[:-num_frame_dims] + output_shape = leading_shape + coords.shape[1:] + output = numpy.zeros(output_shape, dtype=data.dtype) + + # scipy map_coordinates does not broadcast by default, so need to loop + for index in numpy.ndindex(leading_shape): + output[index] = scipy.ndimage.map_coordinates(data[index], coords, order=order) + return DataAndMetadata.new_data_and_metadata( - data=scipy.ndimage.map_coordinates(data, coords, order=order), + data=output, dimensional_calibrations=data_and_metadata.dimensional_calibrations, intensity_calibration=data_and_metadata.intensity_calibration) diff --git a/nion/data/test/Core_test.py b/nion/data/test/Core_test.py index 5106973..0bce735 100755 --- a/nion/data/test/Core_test.py +++ b/nion/data/test/Core_test.py @@ -1336,6 +1336,101 @@ def test_fft_zero_component_calibration(self) -> None: result4 = Core.function_fft(xdata4) self.assertAlmostEqual(0.0, result4.dimensional_calibrations[0].convert_to_calibrated_value(7.5)) + ## WARP TESTS + # Helper func + def _create_warp_test_data(self, + input_shape: tuple[int,...], + output_shape: tuple[int, ...] | None = None, + identity: bool = False, + mode: str = "greyscale") -> tuple[DataAndMetadata.DataAndMetadata, list[numpy.ndarray]]: + # Determine data type and channels based on mode + dtype: numpy.typing.DTypeLike + if mode == "greyscale": + dtype = float + channels = None + elif mode == "rgb": + dtype = numpy.uint8 + channels = 3 + elif mode == "rgba": + dtype = numpy.uint8 + channels = 4 + else: + raise ValueError(f"Invalid mode: {mode}. Choose 'greyscale', 'rgb', or 'rgba'.") + + # Prepare input shape for data array + if channels is None: + full_shape = input_shape + else: + full_shape = input_shape + (channels,) + + # Input data: sequential numbers for easy validation + data = numpy.arange(numpy.prod(full_shape), dtype=dtype).reshape(full_shape) + src = DataAndMetadata.new_data_and_metadata(data=data) + + # Determine output grid shape + if output_shape is None: + height, width = input_shape[-2:] + else: + height, width = output_shape[-2:] + + # Create warp coordinates + if identity: + # Identity warp: map output coordinates to same as input indices + warp_y, warp_x = numpy.meshgrid( + numpy.arange(input_shape[-2]), + numpy.arange(input_shape[-1]), + indexing="ij" + ) + else: + # Resampling / scaling: map output grid into input index space + input_height, input_width = input_shape[-2:] + y = numpy.arange(0, input_height, input_height / height) + x = numpy.arange(0, input_width, input_width / width) + warp_y, warp_x = numpy.meshgrid(y, x, indexing="ij") + + return src, [warp_y, warp_x] + + def _validate_warp_shape(self, src: DataAndMetadata.DataAndMetadata, dst: DataAndMetadata.DataAndMetadata, coords: list[numpy.ndarray], is_channel_data: bool = False) -> None: + n_dims = len(coords) # number of warped dimensions + output_shape = coords[0].shape # shape of warp grid + expected_shape = src.data_shape[:-n_dims] + output_shape + + if is_channel_data: + expected_shape = src.data_shape[:-n_dims-1] + output_shape + (src.data_shape[-1],) + + assert dst.data_shape == expected_shape, f"Output shape mismatch: {dst.data_shape} != {expected_shape}" + + def test_warp_identity(self) -> None: + src, coords = self._create_warp_test_data(input_shape=(4, 4), identity=True) + dst = Core.function_warp(src, coords) + self._validate_warp_shape(src, dst, coords) + + def test_warp_sequence(self) -> None: + src, coords = self._create_warp_test_data(input_shape=(6, 4, 4), output_shape=(4, 4)) + dst = Core.function_warp(src, coords) + self._validate_warp_shape(src, dst, coords) + + def test_warp_upscale(self) -> None: + # Input 4x4, warp to 8x8 + src, coords = self._create_warp_test_data(input_shape=(4, 4), output_shape=(8, 8)) + dst = Core.function_warp(src, coords) + self._validate_warp_shape(src, dst, coords) + + def test_warp_sequence_upscale(self) -> None: + src, coords = self._create_warp_test_data(input_shape=(6, 4, 4), output_shape=(6, 8, 8)) + dst = Core.function_warp(src, coords) + self._validate_warp_shape(src, dst, coords) + + def test_warp_rgb(self) -> None: + src, coords = self._create_warp_test_data(input_shape=(6, 4, 4), output_shape=(4, 4), mode="rgb") + dst = Core.function_warp(src, coords) + self._validate_warp_shape(src, dst, coords, is_channel_data=True) + + def test_warp_rgba(self) -> None: + src, coords = self._create_warp_test_data(input_shape=(6, 4, 4), output_shape=(4, 4), mode="rgba") + dst = Core.function_warp(src, coords) + self._validate_warp_shape(src, dst, coords, is_channel_data=True) + if __name__ == '__main__': logging.getLogger().setLevel(logging.DEBUG)