diff --git a/src/quantem/core/datastructures/dataset.py b/src/quantem/core/datastructures/dataset.py index f01e68f6..3dc3b925 100644 --- a/src/quantem/core/datastructures/dataset.py +++ b/src/quantem/core/datastructures/dataset.py @@ -452,21 +452,50 @@ def crop( axes: tuple | None = None, modify_in_place: bool = False, ) -> Self | None: - """ - Crops Dataset + """Select a sub-region of the dataset along specified axes + + Each ``crop_widths`` entry is a ``(start, stop)`` pair defining + which elements to keep. A ``stop`` of ``0`` keeps everything from + ``start`` to the end. Parameters ---------- - crop_widths:tuple - Min and max for cropping each axis specified as a tuple - axes: - Axes over which to crop. If None specified, all are cropped. - modify_in_place: bool - If True, modifies dataset + crop_widths : tuple[tuple[int, int], ...] + ``(start, stop)`` indices for each axis specified in ``axes``. + axes : tuple | None + Axes to crop. If None, all axes are cropped. + modify_in_place : bool + If True, modifies this dataset in-place and frees the original + array. If False, returns a new dataset. Returns + ------- + Dataset | None + Cropped dataset if ``modify_in_place`` is False, otherwise None. + + Examples -------- - Dataset (cropped) only if modify_in_place is False + Crop real-space to a 128x128 region: + + >>> dset_cropped = dset.crop( + ... crop_widths=((64, 192), (64, 192)), + ... axes=(0, 1), + ... ) + + Crop k-space to keep the first 180 pixels: + + >>> dset_preview = dset.crop( + ... crop_widths=((0, 180), (0, 180)), + ... axes=(2, 3), + ... ) + + Crop k-space in-place to free memory: + + >>> dset.crop( + ... crop_widths=((4, 92), (4, 92)), + ... axes=(2, 3), + ... modify_in_place=True, + ... ) """ if axes is None: if len(crop_widths) != self.ndim: @@ -526,20 +555,31 @@ def bin( modify_in_place: bool = False, reducer: str = "sum", ) -> Self | None: - """ - Bin the Dataset by integer factors along selected axes using block reduction. 
+ """Reduce the dataset resolution by grouping pixels into blocks + + Useful for reducing diffraction pattern size to speed up + reconstruction or lower memory usage. Sampling metadata is + updated automatically. Parameters ---------- bin_factors : int | tuple[int, ...] - Bin factors per specified axis (positive integers). + A single integer applies the same factor to every binned axis. A + tuple gives one positive integer factor per binned axis, e.g. + ``(1, 1, 2, 2)`` with ``axes=None`` bins only the last two axes by 2x. axes : int | tuple[int, ...] | None Axes to bin. If None, all axes are binned. modify_in_place : bool - If True, modifies this dataset; otherwise returns a new Dataset. - reducer : {"sum","mean"} - Reduction applied within each block. "sum" (default) preserves counts; - "mean" averages over each block (block volume = product of factors). + If True, modifies this dataset in-place. If False, returns + a new dataset. + reducer : {"sum", "mean"} + Reduction applied within each block. "sum" (default) preserves + counts; "mean" averages over each block. + + Returns + ------- + Dataset | None + Binned dataset if ``modify_in_place`` is False, otherwise None. Notes ----- @@ -547,6 +587,19 @@ def bin( - Sampling is multiplied by the factor on each binned axis. - Origin is shifted to the center of the first block: origin_new = origin_old + 0.5 * (factor - 1) * sampling_old + + Examples + -------- + Bin diffraction space by 2x to reduce memory: + + >>> dset.bin( + ... bin_factors=(1, 1, 2, 2), + ... modify_in_place=True, + ...
) + + Bin all axes by 2x and return a new dataset: + + >>> dset_binned = dset.bin(bin_factors=2) """ reducer_norm = str(reducer).lower() if reducer_norm not in ("sum", "mean"): diff --git a/tests/datastructures/test_dataset.py b/tests/datastructures/test_dataset.py index 93449aa2..fc5c108d 100644 --- a/tests/datastructures/test_dataset.py +++ b/tests/datastructures/test_dataset.py @@ -241,28 +241,41 @@ def test_crop(self, sample_dataset_2d): """Test crop method.""" # Crop 1 pixel from each side cropped_dataset = sample_dataset_2d.crop(crop_widths=((1, 9), (1, 9))) - # Check shape assert cropped_dataset.shape == (8, 8) # Original (10, 10) - 1 from each side - # Check that the original dataset is unchanged assert sample_dataset_2d.shape == (10, 10) - # Test modify_in_place sample_dataset_2d.crop(crop_widths=((1, 9), (1, 9)), modify_in_place=True) assert sample_dataset_2d.shape == (8, 8) + def test_crop_4dstem_kspace(self): + """Test cropping k-space axes of a 4D-STEM dataset.""" + dset = Dataset.from_array(np.random.rand(8, 8, 96, 96)) + cropped = dset.crop(crop_widths=((4, 92), (4, 92)), axes=(2, 3)) + assert cropped.shape == (8, 8, 88, 88) + assert dset.shape == (8, 8, 96, 96) + + def test_crop_4dstem_realspace_in_place(self): + """Test in-place real-space crop of a 4D-STEM dataset.""" + dset = Dataset.from_array(np.random.rand(16, 16, 32, 32)) + dset.crop(crop_widths=((4, 12), (4, 12)), axes=(0, 1), modify_in_place=True) + assert dset.shape == (8, 8, 32, 32) + + def test_crop_4dstem_stop_zero(self): + """Test that stop=0 keeps all remaining elements.""" + dset = Dataset.from_array(np.random.rand(8, 8, 96, 96)) + cropped = dset.crop(crop_widths=((10, 0), (10, 0)), axes=(2, 3)) + assert cropped.shape == (8, 8, 86, 86) + def test_bin(self, sample_dataset_2d): """Test bin method.""" # Bin by factor of 2 binned_dataset = sample_dataset_2d.bin(bin_factors=2) - # Check shape assert binned_dataset.shape == (5, 5) # Original (10, 10) / 2 - # Check that the original dataset 
is unchanged assert sample_dataset_2d.shape == (10, 10) - # Test modify_in_place sample_dataset_2d.bin(bin_factors=2, modify_in_place=True) assert sample_dataset_2d.shape == (5, 5)