def read_abtem(url: str | PathLike) -> "Dataset | list[Dataset]":
    """
    Read canonical abTEM Zarr file(s) into quantem Dataset(s).

    Parameters
    ----------
    url : str or PathLike
        Path to an abTEM ``.zarr`` store, or a zipped store ending in
        ``.zip``. PathLike inputs (e.g. ``pathlib.Path``) are accepted
        and normalized to ``str`` internally.

    Returns
    -------
    Dataset or list[Dataset]
        A single Dataset when the store holds one measurement, otherwise a
        list with one Dataset per ``metadata{i}`` / ``array{i}`` pair.

    Raises
    ------
    ValueError
        If the store is in the legacy abTEM format (``kwargs0`` attribute
        present) or is not a recognized abTEM Zarr layout.
    """
    # Normalize PathLike up front: the ".zip" suffix check in _open_zarr
    # needs a plain str (PathLike has no .endswith).
    url = str(url)

    # Unit spellings emitted by abTEM mapped to quantem's canonical forms.
    # Hoisted to closure scope so the dict is built once, not per axis.
    UNIT_MAP = {
        "Å": "A",
        "Ångström": "A",
        "Angstrom": "A",
        "1/Å": "A^-1",
        "Å^-1": "A^-1",
        "1/A": "A^-1",
    }

    def _open_zarr(path: str):
        # Local import keeps zarr an optional dependency of this module.
        import zarr

        if path.endswith(".zip"):
            store = zarr.storage.ZipStore(path, mode="r")  # type: ignore
            return zarr.open(store=store, mode="r")
        return zarr.open(path, mode="r")

    def _validate_canonical_format(root) -> None:
        # Canonical format marks each measurement with a "metadata{i}" attr.
        if "metadata0" in root.attrs:
            return

        # Legacy abTEM stores used "kwargs{i}" attrs instead.
        if "kwargs0" in root.attrs:
            raise ValueError(
                "Legacy abTEM Zarr format detected.\n\n"
                "quantem supports only canonical abTEM Zarr format.\n"
                "Re-save using abtem>=1.1.0:\n\n"
                " measurement = abtem.from_zarr()\n"
                " measurement.to_zarr()"
            )

        raise ValueError("Unrecognized Zarr format.")

    def _iter_metadata_indices(root):
        # Yield 0, 1, 2, ... while a matching "metadata{i}" attribute exists.
        i = 0
        while f"metadata{i}" in root.attrs:
            yield i
            i += 1

    def _decode_types(obj) -> Any:
        # Recursively undo abTEM's JSON encoding of tuples
        # ({"_type": "tuple", "_value": [...]}) back into Python tuples.
        if isinstance(obj, dict):
            if obj.get("_type") == "tuple":
                return tuple(_decode_types(v) for v in obj["_value"])
            return {k: _decode_types(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [_decode_types(v) for v in obj]
        return obj

    def _normalize_unit(unit):
        # Missing unit means an unscaled axis.
        if unit is None:
            return "pixels"
        return UNIT_MAP.get(unit.strip(), unit.strip())

    def _convert_axes(axes_dict):
        # Axes are keyed like "axis_0", "axis_1", ...; sort numerically on
        # the trailing index so axis order matches array dimension order.
        sampling: list[float] = []
        origin: list[float] = []
        units: list[str] = []

        for key in sorted(axes_dict, key=lambda x: int(x.split("_")[1])):
            axis = axes_dict[key]

            sampling.append(axis.get("sampling", 1.0))
            units.append(_normalize_unit(axis.get("units", None)))
            origin.append(0.0)  # deliberate design choice

        return tuple(origin), tuple(sampling), tuple(units)

    def _read_single_dataset(root, index: int):
        # Copy so the pops below don't mutate the store's attrs mapping.
        metadata = _decode_types(root.attrs[f"metadata{index}"]).copy()

        axes_dict = metadata.pop("axes")
        dataset_type = metadata.pop("type")
        metadata.pop("data_origin", None)

        origin, sampling, units = _convert_axes(axes_dict)

        array = root[f"array{index}"]
        # Remaining "units" entry (if any) describes the signal values.
        signal_units = metadata.get("units", "arb. units")

        dataset = Dataset.from_array(
            array=array,
            name=dataset_type,
            origin=origin,
            sampling=sampling,
            units=units,
            signal_units=signal_units,
        )

        # Stash the leftover metadata for downstream consumers.
        dataset._metadata = metadata
        return dataset

    root = _open_zarr(url)
    _validate_canonical_format(root)

    datasets = [_read_single_dataset(root, i) for i in _iter_metadata_indices(root)]

    return datasets[0] if len(datasets) == 1 else datasets