Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 119 additions & 8 deletions src/quantem/core/io/file_readers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import importlib
from os import PathLike
from pathlib import Path
from typing import Any

import h5py

Expand Down Expand Up @@ -115,7 +116,7 @@ def read_2d(
if file_type is None:
file_type = Path(file_path).suffix.lower().lstrip(".")

file_reader = importlib.import_module(f"rsciio.{file_type}").file_reader # type: ignore
file_reader = importlib.import_module(f"rsciio.{file_type}").file_reader
imported_data = file_reader(file_path)[0]

dataset = Dataset2d.from_array(
Expand Down Expand Up @@ -160,9 +161,9 @@ def read_emdfile_to_4dstem(
data_keys = ["datacube_root", "datacube", "data"] if data_keys is None else data_keys
print("keys: ", data_keys)
try:
data = file
data: Any = file
for key in data_keys:
data = data[key] # type: ignore
data = data[key]
except KeyError:
raise KeyError(f"Could not find key {data_keys} in {file_path}")

Expand All @@ -175,13 +176,13 @@ def read_emdfile_to_4dstem(
try:
calibration = file
for key in calibration_keys:
calibration = calibration[key] # type: ignore
calibration = calibration[key]
except KeyError:
raise KeyError(f"Could not find calibration key {calibration_keys} in {file_path}")
r_pixel_size = calibration["R_pixel_size"][()] # type: ignore
q_pixel_size = calibration["Q_pixel_size"][()] # type: ignore
r_pixel_units = calibration["R_pixel_units"][()] # type: ignore
q_pixel_units = calibration["Q_pixel_units"][()] # type: ignore
r_pixel_size = calibration["R_pixel_size"][()]
q_pixel_size = calibration["Q_pixel_size"][()]
r_pixel_units = calibration["R_pixel_units"][()]
q_pixel_units = calibration["Q_pixel_units"][()]

dataset = Dataset4dstem.from_array(
array=data,
Expand All @@ -191,3 +192,113 @@ def read_emdfile_to_4dstem(
dataset.file_path = file_path

return dataset


def read_abtem(url: str | PathLike):
    """
    Read canonical abTEM Zarr file(s) into quantem Dataset(s).

    Parameters
    ----------
    url : str or PathLike
        Path to a canonical abTEM Zarr store: either a ``.zarr``
        directory or a zipped store ending in ``.zip``.

    Returns
    -------
    Dataset or list[Dataset]
        A single Dataset when the store holds one measurement
        (one ``metadata0`` attribute), otherwise a list with one
        Dataset per ``metadata{i}`` / ``array{i}`` pair.

    Raises
    ------
    ValueError
        If the store is in the legacy abTEM format (``kwargs0``
        attribute present) or is not recognized as abTEM output.
    """
    # Normalize PathLike inputs up front. The suffix check below uses
    # str.endswith, which would raise AttributeError for pathlib.Path
    # arguments even though the signature advertises PathLike support.
    url = str(Path(url))

    def _open_zarr(url):
        # Local import keeps zarr an optional dependency of this module.
        import zarr

        if url.endswith(".zip"):
            # NOTE(review): the ZipStore is never explicitly closed; the
            # returned arrays may read lazily from it, so closing here
            # would break downstream access — confirm intended lifetime.
            store = zarr.storage.ZipStore(url, mode="r")  # type: ignore
            return zarr.open(store=store, mode="r")
        return zarr.open(url, mode="r")

    def _validate_canonical_format(root):
        # Canonical abTEM output marks each measurement with metadata{i}.
        if "metadata0" in root.attrs:
            return

        # Legacy (pre-1.1.0) abTEM files used kwargs{i} attributes.
        if "kwargs0" in root.attrs:
            raise ValueError(
                "Legacy abTEM Zarr format detected.\n\n"
                "quantem supports only canonical abTEM Zarr format.\n"
                "Re-save using abtem>=1.1.0:\n\n"
                "    measurement = abtem.from_zarr(<legacy_path>)\n"
                "    measurement.to_zarr(<new_path>)"
            )

        raise ValueError("Unrecognized Zarr format.")

    def _iter_metadata_indices(root):
        # Yield 0, 1, 2, ... while a matching metadata{i} attribute exists.
        i = 0
        while f"metadata{i}" in root.attrs:
            yield i
            i += 1

    def _decode_types(obj) -> Any:
        # Recursively undo abTEM's JSON encoding of tuples, which are
        # stored as {"_type": "tuple", "_value": [...]}.
        if isinstance(obj, dict):
            if obj.get("_type") == "tuple":
                return tuple(_decode_types(v) for v in obj["_value"])
            return {k: _decode_types(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [_decode_types(v) for v in obj]
        return obj

    def _normalize_unit(unit):
        # Map abTEM unit spellings onto quantem's ASCII conventions;
        # unknown units pass through unchanged.
        if unit is None:
            return "pixels"

        unit = unit.strip()

        UNIT_MAP = {
            "Å": "A",
            "Ångström": "A",
            "Angstrom": "A",
            "1/Å": "A^-1",
            "Å^-1": "A^-1",
            "1/A": "A^-1",
        }

        return UNIT_MAP.get(unit, unit)

    def _convert_axes(axes_dict):
        # Build per-axis (origin, sampling, units) tuples, ordered by the
        # numeric suffix of the axis keys (e.g. "axis_0", "axis_1", ...).
        sampling = []
        origin = []
        units = []

        for key in sorted(axes_dict, key=lambda x: int(x.split("_")[1])):
            axis = axes_dict[key]

            sampling.append(axis.get("sampling", 1.0))
            units.append(_normalize_unit(axis.get("units", None)))
            origin.append(0.0)  # deliberate design choice

        return tuple(origin), tuple(sampling), tuple(units)

    def _read_single_dataset(root, index):
        # Copy so the pops below don't mutate the cached attrs dict.
        metadata = _decode_types(root.attrs[f"metadata{index}"]).copy()

        axes_dict = metadata.pop("axes")
        dataset_type = metadata.pop("type")
        metadata.pop("data_origin", None)

        origin, sampling, units = _convert_axes(axes_dict)

        array = root[f"array{index}"]
        # Signal (value) units are distinct from the per-axis units above.
        signal_units = metadata.get("units", "arb. units")

        dataset = Dataset.from_array(
            array=array,
            name=dataset_type,
            origin=origin,
            sampling=sampling,
            units=units,
            signal_units=signal_units,
        )

        # Preserve the remaining abTEM metadata on the Dataset.
        dataset._metadata = metadata
        return dataset

    root = _open_zarr(url)
    _validate_canonical_format(root)

    datasets = [_read_single_dataset(root, i) for i in _iter_metadata_indices(root)]

    return datasets[0] if len(datasets) == 1 else datasets