1 change: 1 addition & 0 deletions .gitignore
@@ -52,3 +52,4 @@ node_modules/

.mypy_cache
.ruff_cache
uv.lock
2 changes: 2 additions & 0 deletions src/spatialdata/__init__.py
@@ -40,6 +40,7 @@
"deepcopy",
"sanitize_table",
"sanitize_name",
"settings",
]

from spatialdata import dataloader, datasets, models, transformations
@@ -70,3 +71,4 @@
from spatialdata._io.format import SpatialDataFormatType
from spatialdata._io.io_zarr import read_zarr
from spatialdata._utils import get_pyramid_levels, unpad_raster
from spatialdata.config import settings
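
A minimal sketch of what this re-export enables (assuming a build of this branch): the `settings` object is now reachable both from the package root and from `spatialdata.config`, and both names point at the same instance.

```python
import spatialdata
from spatialdata.config import settings

# Both names refer to the same Settings instance after the re-export above.
assert spatialdata.settings is settings
print(settings.shapes_geometry_encoding)  # "WKB" by default
```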
19 changes: 18 additions & 1 deletion src/spatialdata/_core/spatialdata.py
@@ -1110,6 +1110,7 @@ def write(
consolidate_metadata: bool = True,
update_sdata_path: bool = True,
sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None,
shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None,
) -> None:
"""
Write the `SpatialData` object to a Zarr store.
@@ -1154,6 +1155,9 @@
unspecified, the element formats will be set to the latest element format compatible with the specified
SpatialData container format. All the formats and relationships between them are defined in
`spatialdata._io.format.py`.
shapes_geometry_encoding
The geometry encoding to use for GeoParquet, either `"WKB"` or `"geoarrow"`. See
:meth:`geopandas.GeoDataFrame.to_parquet` for details. If `None`, uses the value from
:attr:`spatialdata.settings.shapes_geometry_encoding`.
"""
from spatialdata._io._utils import _resolve_zarr_store
from spatialdata._io.format import _parse_formats
@@ -1179,6 +1183,7 @@
element_name=element_name,
overwrite=False,
parsed_formats=parsed,
shapes_geometry_encoding=shapes_geometry_encoding,
)

if self.path != file_path and update_sdata_path:
@@ -1195,6 +1200,7 @@ def _write_element(
element_name: str,
overwrite: bool,
parsed_formats: dict[str, SpatialDataFormatType] | None = None,
shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None,
) -> None:
from spatialdata._io.io_zarr import _get_groups_for_element

@@ -1247,6 +1253,7 @@
shapes=element,
group=element_group,
element_format=parsed_formats["shapes"],
geometry_encoding=shapes_geometry_encoding,
)
elif element_type == "tables":
write_table(
@@ -1263,6 +1270,7 @@ def write_element(
element_name: str | list[str],
overwrite: bool = False,
sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None,
shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None,
) -> None:
"""
Write a single element, or a list of elements, to the Zarr store used for backing.
@@ -1278,6 +1286,9 @@
sdata_formats
It is recommended to leave this parameter equal to `None`. See more details in the documentation of
`SpatialData.write()`.
shapes_geometry_encoding
The geometry encoding to use for GeoParquet, either `"WKB"` or `"geoarrow"`. See
:meth:`geopandas.GeoDataFrame.to_parquet` for details. If `None`, uses the value from
:attr:`spatialdata.settings.shapes_geometry_encoding`.

Notes
-----
@@ -1291,7 +1302,12 @@
if isinstance(element_name, list):
for name in element_name:
assert isinstance(name, str)
self.write_element(name, overwrite=overwrite, sdata_formats=sdata_formats)
self.write_element(
name,
overwrite=overwrite,
sdata_formats=sdata_formats,
shapes_geometry_encoding=shapes_geometry_encoding,
)
return

check_valid_name(element_name)
@@ -1325,6 +1341,7 @@
element_name=element_name,
overwrite=overwrite,
parsed_formats=parsed_formats,
shapes_geometry_encoding=shapes_geometry_encoding,
)
# After every write, metadata should be consolidated; otherwise this can lead to IO problems, e.g. when deleting elements.
if self.has_consolidated_metadata():
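
A short usage sketch of the new keyword threaded through `write()` and `write_element()` above. The `blobs()` demo dataset, the `blobs_polygons` element name, and the output path are assumptions taken from the library's example data, not from this diff.

```python
import spatialdata as sd

sdata = sd.datasets.blobs()

# Per-call override: this write uses geoarrow regardless of the global default.
sdata.write("blobs.zarr", shapes_geometry_encoding="geoarrow")

# write_element() forwards the same keyword down to _write_element().
sdata.write_element("blobs_polygons", overwrite=True, shapes_geometry_encoding="geoarrow")
```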
22 changes: 18 additions & 4 deletions src/spatialdata/_io/io_shapes.py
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Any
from typing import Any, Literal

import numpy as np
import zarr
@@ -70,6 +70,7 @@ def write_shapes(
group: zarr.Group,
group_type: str = "ngff:shapes",
element_format: Format = CurrentShapesFormat(),
geometry_encoding: Literal["WKB", "geoarrow"] | None = None,
) -> None:
"""Write shapes to spatialdata zarr store.

@@ -86,15 +87,23 @@
The type of the element.
element_format
The format of the shapes element used to store it.
geometry_encoding
The geometry encoding to use for GeoParquet, either `"WKB"` or `"geoarrow"`. See
:meth:`geopandas.GeoDataFrame.to_parquet` for details. If `None`, uses the value from
:attr:`spatialdata.settings.shapes_geometry_encoding`.
"""
from spatialdata.config import settings

if geometry_encoding is None:
geometry_encoding = settings.shapes_geometry_encoding

axes = get_axes_names(shapes)
transformations = _get_transformations(shapes)
if transformations is None:
raise ValueError(f"{group.basename} does not have any transformations and can therefore not be written.")
if isinstance(element_format, ShapesFormatV01):
attrs = _write_shapes_v01(shapes, group, element_format)
elif isinstance(element_format, ShapesFormatV02 | ShapesFormatV03):
attrs = _write_shapes_v02_v03(shapes, group, element_format)
attrs = _write_shapes_v02_v03(shapes, group, element_format, geometry_encoding=geometry_encoding)
else:
raise ValueError(f"Unsupported format version {element_format.version}. Please update the spatialdata library.")

@@ -139,7 +148,9 @@ def _write_shapes_v01(shapes: GeoDataFrame, group: zarr.Group, element_format: F
return attrs


def _write_shapes_v02_v03(shapes: GeoDataFrame, group: zarr.Group, element_format: Format) -> Any:
def _write_shapes_v02_v03(
shapes: GeoDataFrame, group: zarr.Group, element_format: Format, geometry_encoding: Literal["WKB", "geoarrow"]
) -> Any:
"""Write shapes to spatialdata zarr store using format ShapesFormatV02 or ShapesFormatV03.

Parameters
@@ -150,6 +161,9 @@ def _write_shapes_v02_v03(shapes: GeoDataFrame, group: zarr.Group, element_forma
The zarr group in the 'shapes' zarr group to write the shapes element to.
element_format
The format of the shapes element used to store it.
geometry_encoding
The geometry encoding to use for GeoParquet, either `"WKB"` or `"geoarrow"`. See
:meth:`geopandas.GeoDataFrame.to_parquet` for details.
"""
from spatialdata.models._utils import TRANSFORM_KEY

@@ -159,7 +173,7 @@ def _write_shapes_v02_v03(shapes: GeoDataFrame, group: zarr.Group, element_forma
# Temporarily remove transformations from attrs to avoid serialization issues
transforms = shapes.attrs[TRANSFORM_KEY]
del shapes.attrs[TRANSFORM_KEY]
shapes.to_parquet(path)
shapes.to_parquet(path, geometry_encoding=geometry_encoding)
shapes.attrs[TRANSFORM_KEY] = transforms

attrs = element_format.attrs_to_dict(shapes.attrs)
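
The fallback added to `write_shapes()` boils down to a two-branch resolution. A standalone sketch of the same logic, with a hypothetical helper name:

```python
from typing import Literal

from spatialdata.config import settings


def resolve_geometry_encoding(
    geometry_encoding: Literal["WKB", "geoarrow"] | None,
) -> Literal["WKB", "geoarrow"]:
    # An explicit argument always wins; None falls back to the global setting.
    return settings.shapes_geometry_encoding if geometry_encoding is None else geometry_encoding
```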
32 changes: 28 additions & 4 deletions src/spatialdata/config.py
@@ -1,4 +1,28 @@
# chunk sizes bigger than this value (bytes) can trigger a compression error
# https://github.com/scverse/spatialdata/issues/812#issuecomment-2559380276
# so if we detect this during parsing/validation we raise a warning
LARGE_CHUNK_THRESHOLD_BYTES = 2147483647
from dataclasses import dataclass
from typing import Literal


@dataclass
class Settings:
"""Global settings for spatialdata.

Attributes
----------
shapes_geometry_encoding
Default geometry encoding for GeoParquet files when writing shapes.
Can be "WKB" (Well-Known Binary) or "geoarrow".
See :meth:`geopandas.GeoDataFrame.to_parquet` for details.
large_chunk_threshold_bytes
Chunk sizes bigger than this value (bytes) can trigger a compression error.
See https://github.com/scverse/spatialdata/issues/812#issuecomment-2559380276
If detected during parsing/validation, a warning is raised.
"""

shapes_geometry_encoding: Literal["WKB", "geoarrow"] = "WKB"
large_chunk_threshold_bytes: int = 2147483647


settings = Settings()

# Backwards-compatibility alias (a snapshot taken at import time, not a live view of the field)
LARGE_CHUNK_THRESHOLD_BYTES = settings.large_chunk_threshold_bytes
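
Changing a default is a plain attribute assignment on the dataclass instance. One subtlety worth noting, a sketch under the code as written above: since the alias copies the integer at import time, mutating `settings.large_chunk_threshold_bytes` afterwards does not update `LARGE_CHUNK_THRESHOLD_BYTES`.

```python
import spatialdata
from spatialdata.config import LARGE_CHUNK_THRESHOLD_BYTES

# Subsequent writes that leave shapes_geometry_encoding=None pick this up.
spatialdata.settings.shapes_geometry_encoding = "geoarrow"

# The alias is a one-time snapshot of the dataclass field.
spatialdata.settings.large_chunk_threshold_bytes = 2**32
assert LARGE_CHUNK_THRESHOLD_BYTES == 2147483647
```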
8 changes: 4 additions & 4 deletions src/spatialdata/models/models.py
@@ -35,7 +35,7 @@
from spatialdata._logging import logger
from spatialdata._types import ArrayLike
from spatialdata._utils import _check_match_length_channels_c_dim
from spatialdata.config import LARGE_CHUNK_THRESHOLD_BYTES
from spatialdata.config import settings
from spatialdata.models import C, X, Y, Z, get_axes_names
from spatialdata.models._utils import (
DEFAULT_COORDINATE_SYSTEM,
@@ -315,9 +315,9 @@ def _check_chunk_size_not_too_large(self, data: DataArray | DataTree) -> None:
return
n_elems = np.array(list(max_per_dimension.values())).prod().item()
usage = n_elems * data.dtype.itemsize
if usage > LARGE_CHUNK_THRESHOLD_BYTES:
if usage > settings.large_chunk_threshold_bytes:
warnings.warn(
f"Detected chunks larger than: {usage} > {LARGE_CHUNK_THRESHOLD_BYTES} bytes. "
f"Detected chunks larger than: {usage} > {settings.large_chunk_threshold_bytes} bytes. "
"This can lead to low "
"performance and memory issues downstream, and sometimes cause compression errors when writing "
"(https://github.com/scverse/spatialdata/issues/812#issuecomment-2575983527). Please consider using"
@@ -327,7 +327,7 @@ def _check_chunk_size_not_too_large(self, data: DataArray | DataTree) -> None:
"2) Multiscale representations can be achieved by using the `scale_factors` argument in the "
"`parse()` function.\n"
"You can suppress this warning by increasing the value of "
"`spatialdata.config.LARGE_CHUNK_THRESHOLD_BYTES`.",
"`spatialdata.settings.large_chunk_threshold_bytes`.",
UserWarning,
stacklevel=2,
)
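
Because the check now reads the setting at call time rather than a module constant, the warning can be silenced as the updated message suggests. A sketch, with an arbitrary larger threshold:

```python
import spatialdata

# The chunk-size check in models.py reads this value each time it runs,
# so raising it here silences the warning for subsequent parses.
spatialdata.settings.large_chunk_threshold_bytes = 2**32
```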
80 changes: 77 additions & 3 deletions tests/io/test_readwrite.py
@@ -3,17 +3,21 @@
import tempfile
from collections.abc import Callable
from pathlib import Path
from typing import Any
from typing import Any, Literal

import dask.dataframe as dd
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import pytest
import zarr
from anndata import AnnData
from numpy.random import default_rng
from shapely import MultiPolygon, Polygon
from upath import UPath
from zarr.errors import GroupNotFoundError

import spatialdata.config
from spatialdata import SpatialData, deepcopy, read_zarr
from spatialdata._core.validation import ValidationError
from spatialdata._io._utils import _are_directories_identical, get_dask_backing_files
@@ -74,20 +78,90 @@ def test_labels(
sdata = SpatialData.read(tmpdir)
assert_spatial_data_objects_are_identical(labels, sdata)

@pytest.mark.parametrize("geometry_encoding", ["WKB", "geoarrow"])
def test_shapes(
self,
tmp_path: str,
shapes: SpatialData,
sdata_container_format: SpatialDataContainerFormatType,
geometry_encoding: Literal["WKB", "geoarrow"],
) -> None:
tmpdir = Path(tmp_path) / "tmp.zarr"

# check the index is correctly written and then read
shapes["circles"].index = np.arange(1, len(shapes["circles"]) + 1)

shapes.write(tmpdir, sdata_formats=sdata_container_format)
# add a mixed Polygon + MultiPolygon element
shapes["mixed"] = pd.concat([shapes["poly"], shapes["multipoly"]])

shapes.write(tmpdir, sdata_formats=sdata_container_format, shapes_geometry_encoding=geometry_encoding)
sdata = SpatialData.read(tmpdir)
assert_spatial_data_objects_are_identical(shapes, sdata)

if geometry_encoding == "WKB":
assert_spatial_data_objects_are_identical(shapes, sdata)
else:
# convert each Polygon to a MultiPolygon
mixed_multipolygon = shapes["mixed"].assign(
geometry=lambda df: df.geometry.apply(lambda g: MultiPolygon([g]) if isinstance(g, Polygon) else g)
)
assert sdata["mixed"].equals(mixed_multipolygon)
assert not sdata["mixed"].equals(shapes["mixed"])

del shapes["mixed"]
del sdata["mixed"]
assert_spatial_data_objects_are_identical(shapes, sdata)

@pytest.mark.parametrize("geometry_encoding", ["WKB", "geoarrow"])
def test_shapes_geometry_encoding_write_element(
self,
tmp_path: str,
shapes: SpatialData,
sdata_container_format: SpatialDataContainerFormatType,
geometry_encoding: Literal["WKB", "geoarrow"],
) -> None:
"""Test shapes geometry encoding with write_element() and global settings."""
tmpdir = Path(tmp_path) / "tmp.zarr"

# First write an empty SpatialData to create the zarr store
empty_sdata = SpatialData()
empty_sdata.write(tmpdir, sdata_formats=sdata_container_format)

shapes["mixed"] = pd.concat([shapes["poly"], shapes["multipoly"]])

# Add shapes to the empty sdata
for shape_name in shapes.shapes:
empty_sdata[shape_name] = shapes[shape_name]

# Store original setting and set global encoding
original_encoding = spatialdata.config.settings.shapes_geometry_encoding
try:
spatialdata.config.settings.shapes_geometry_encoding = geometry_encoding

# Write each shape element - should use global setting
for shape_name in shapes.shapes:
empty_sdata.write_element(shape_name, sdata_formats=sdata_container_format)

# Verify the encoding metadata in the parquet file
parquet_file = tmpdir / "shapes" / shape_name / "shapes.parquet"
with pq.ParquetFile(parquet_file) as pf:
md = pf.metadata
d = json.loads(md.metadata[b"geo"].decode("utf-8"))
found_encoding = d["columns"]["geometry"]["encoding"]
if geometry_encoding == "WKB":
expected_encoding = "WKB"
elif shape_name == "circles":
expected_encoding = "point"
elif shape_name == "poly":
expected_encoding = "polygon"
elif shape_name in ["multipoly", "mixed"]:
expected_encoding = "multipolygon"
else:
raise ValueError(
f"Uncovered case for shape_name: {shape_name}, found encoding: {found_encoding}."
)
assert found_encoding == expected_encoding
finally:
spatialdata.config.settings.shapes_geometry_encoding = original_encoding

def test_points(
self,
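
The geoarrow branch of `test_shapes` encodes a real geopandas behavior: a column mixing `Polygon` and `MultiPolygon` is written with the `multipolygon` GeoArrow encoding, so single-part polygons come back promoted. A standalone sketch of that round trip, assuming geopandas with pyarrow-backed Parquet support; the file name is arbitrary:

```python
import geopandas as gpd
from shapely import MultiPolygon, Polygon

tri = Polygon([(0, 0), (1, 0), (1, 1)])
mixed = gpd.GeoDataFrame(geometry=[tri, MultiPolygon([tri])])

mixed.to_parquet("mixed.parquet", geometry_encoding="geoarrow")
roundtrip = gpd.read_parquet("mixed.parquet")

# Every geometry comes back as a MultiPolygon, matching the test's assertion.
assert all(isinstance(g, MultiPolygon) for g in roundtrip.geometry)
```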