Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 69 additions & 16 deletions src/quantem/core/datastructures/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,21 +452,50 @@ def crop(
axes: tuple | None = None,
modify_in_place: bool = False,
) -> Self | None:
"""
Crops Dataset
"""Select a sub-region of the dataset along specified axes

Each ``crop_widths`` entry is a ``(start, stop)`` pair defining
which elements to keep. A ``stop`` of ``0`` keeps everything from
``start`` to the end.

Parameters
----------
crop_widths:tuple
Min and max for cropping each axis specified as a tuple
axes:
Axes over which to crop. If None specified, all are cropped.
modify_in_place: bool
If True, modifies dataset
crop_widths : tuple[tuple[int, int], ...]
``(start, stop)`` indices for each axis specified in ``axes``.
axes : tuple | None
Axes to crop. If None, all axes are cropped.
modify_in_place : bool
If True, modifies this dataset in-place and frees the original
array. If False, returns a new dataset.

Returns
-------
Dataset | None
Cropped dataset if ``modify_in_place`` is False, otherwise None.

Examples
--------
Dataset (cropped) only if modify_in_place is False
Crop real-space to a 128x128 region:

>>> dset_cropped = dset.crop(
... crop_widths=((64, 192), (64, 192)),
... axes=(0, 1),
... )

Crop k-space to keep the first 180 pixels:

>>> dset_preview = dset.crop(
... crop_widths=((0, 180), (0, 180)),
... axes=(2, 3),
... )

Crop k-space in-place to free memory:

>>> dset.crop(
... crop_widths=((4, 92), (4, 92)),
... axes=(2, 3),
... modify_in_place=True,
... )
"""
if axes is None:
if len(crop_widths) != self.ndim:
Expand Down Expand Up @@ -526,27 +555,51 @@ def bin(
modify_in_place: bool = False,
reducer: str = "sum",
) -> Self | None:
"""
Bin the Dataset by integer factors along selected axes using block reduction.
"""Reduce the dataset resolution by grouping pixels into blocks

Useful for reducing diffraction pattern size to speed up
reconstruction or lower memory usage. Sampling metadata is
updated automatically.

Parameters
----------
bin_factors : int | tuple[int, ...]
Bin factors per specified axis (positive integers).
A single integer bins all axes by the same factor. A tuple
specifies a different factor per axis, e.g. ``(1, 1, 2, 2)``
to bin only the last two axes by 2x.
axes : int | tuple[int, ...] | None
Axes to bin. If None, all axes are binned.
modify_in_place : bool
If True, modifies this dataset; otherwise returns a new Dataset.
reducer : {"sum","mean"}
Reduction applied within each block. "sum" (default) preserves counts;
"mean" averages over each block (block volume = product of factors).
If True, modifies this dataset in-place. If False, returns
a new dataset.
reducer : {"sum", "mean"}
Reduction applied within each block. "sum" (default) preserves
counts; "mean" averages over each block.

Returns
-------
Dataset | None
Binned dataset if ``modify_in_place`` is False, otherwise None.

Notes
-----
- Any remainder (shape % factor) is dropped on each binned axis.
- Sampling is multiplied by the factor on each binned axis.
- Origin is shifted to the center of the first block:
origin_new = origin_old + 0.5 * (factor - 1) * sampling_old

Examples
--------
Bin diffraction space by 2x to reduce memory:

>>> dset.bin(
... bin_factors=(1, 1, 2, 2),
... modify_in_place=True,
... )

Bin all axes by 2x and return a new dataset:

>>> dset_binned = dset.bin(bin_factors=2)
"""
reducer_norm = str(reducer).lower()
if reducer_norm not in ("sum", "mean"):
Expand Down
25 changes: 19 additions & 6 deletions tests/datastructures/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,28 +241,41 @@ def test_crop(self, sample_dataset_2d):
"""Test crop method."""
# Crop 1 pixel from each side
cropped_dataset = sample_dataset_2d.crop(crop_widths=((1, 9), (1, 9)))

# Check shape
assert cropped_dataset.shape == (8, 8) # Original (10, 10) - 1 from each side

# Check that the original dataset is unchanged
assert sample_dataset_2d.shape == (10, 10)

# Test modify_in_place
sample_dataset_2d.crop(crop_widths=((1, 9), (1, 9)), modify_in_place=True)
assert sample_dataset_2d.shape == (8, 8)

def test_crop_4dstem_kspace(self):
"""Test cropping k-space axes of a 4D-STEM dataset."""
dset = Dataset.from_array(np.random.rand(8, 8, 96, 96))
cropped = dset.crop(crop_widths=((4, 92), (4, 92)), axes=(2, 3))
assert cropped.shape == (8, 8, 88, 88)
assert dset.shape == (8, 8, 96, 96)

def test_crop_4dstem_realspace_in_place(self):
"""Test in-place real-space crop of a 4D-STEM dataset."""
dset = Dataset.from_array(np.random.rand(16, 16, 32, 32))
dset.crop(crop_widths=((4, 12), (4, 12)), axes=(0, 1), modify_in_place=True)
assert dset.shape == (8, 8, 32, 32)

def test_crop_4dstem_stop_zero(self):
"""Test that stop=0 keeps all remaining elements."""
dset = Dataset.from_array(np.random.rand(8, 8, 96, 96))
cropped = dset.crop(crop_widths=((10, 0), (10, 0)), axes=(2, 3))
assert cropped.shape == (8, 8, 86, 86)

def test_bin(self, sample_dataset_2d):
"""Test bin method."""
# Bin by factor of 2
binned_dataset = sample_dataset_2d.bin(bin_factors=2)

# Check shape
assert binned_dataset.shape == (5, 5) # Original (10, 10) / 2

# Check that the original dataset is unchanged
assert sample_dataset_2d.shape == (10, 10)

# Test modify_in_place
sample_dataset_2d.bin(bin_factors=2, modify_in_place=True)
assert sample_dataset_2d.shape == (5, 5)
Expand Down