Skip to content

Commit ac7e93f

Browse files
committed
create custom class to close #1042, refactored to allow multiple image files to pass through
1 parent 3c2c6cd commit ac7e93f

File tree

4 files changed

+147
-81
lines changed

4 files changed

+147
-81
lines changed

docs/user_guide/01_Reading_data.md

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22

33
## The DeepForest data model
44

5-
The DeepForest data model has three components
5+
The DeepForest data model has four components:
66

7-
1. Annotations are stored as dataframes. Each row is an annotation with a single geometry and label. Each annotation dataframe must contain a 'image_path', which is the relative, not full path to the image, and a 'label' column.
7+
1. Annotations are stored as dataframes. Each row is an annotation with a single geometry and label. Each annotation dataframe must contain a 'image_path', which is the basename, not full path to the image, and a 'label' column.
88
2. Annotation geometry is stored as a shapely object, allowing the easy movement among Point, Polygon and Box representations.
99
3. Annotations are expressed in image coordinates, not geographic coordinates. There are utilities to convert geospatial data (.shp, .gpkg) to DeepForest data formats.
10+
4. A root_dir attribute that specifies where the images are stored. A Dee
1011

1112
## The read_file function
1213
DeepForest has collated many use cases into a single `read_file` function that will read many common data formats, both projected and unprojected, and create a dataframe ready for DeepForest functions that fits the DeepForest data model.
@@ -16,18 +17,19 @@ DeepForest has collated many use cases into a single `read_file` function that w
1617
```
1718
from deepforest import utilities
1819
19-
df = utilities.read_file("annotations.csv", image_path="<full path to the image>", label="Tree")
20+
df = utilities.read_file("annotations.csv", root_dir="directory containing images", image_path="relative path to the image>", label="Tree")
2021
```
2122

22-
For files that lack an `image_path` or `label` column, pass the `image_path` or `label` argument.
23+
For files that lack an `image_path` or `label` column, pass the `image_path` or `label` argument. This applies the same image_path and label for the entire file, and is not appropriate for multi-image files.
2324

2425
```python
2526
from deepforest import utilities
2627

2728
gdf = utilities.read_file(
2829
input="/path/to/annotations.shp",
29-
image_path="/path/to/OSBS_029.tif", # required if no image_path column
30-
label="Tree" # optional: used if no 'label' column in the shapefile
30+
image_path="OSBS_029.tif", # required if no image_path column
31+
root_dir="path/to/images/" # required is image_path argument is used
32+
label="Tree" # optional: used if no 'label' column in the shapefile
3133
)
3234
```
3335

src/deepforest/main.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1031,7 +1031,11 @@ def __evaluation_logs__(self, results):
10311031
"""Log metrics from evaluation results."""
10321032
# Log metrics
10331033
for key, value in results.items():
1034-
if type(value) in [pd.DataFrame, gpd.GeoDataFrame]:
1034+
if type(value) in [
1035+
pd.DataFrame,
1036+
gpd.GeoDataFrame,
1037+
utilities.DeepForest_DataFrame,
1038+
]:
10351039
pass
10361040
elif value is None:
10371041
pass

src/deepforest/utilities.py

Lines changed: 113 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,22 @@ def update_to(self, b=1, bsize=1, tsize=None):
7979
self.update(b * bsize - self.n)
8080

8181

82+
class DeepForest_DataFrame(gpd.GeoDataFrame):
83+
"""Custom GeoDataFrame that preserves a root_dir attribute if present."""
84+
85+
_metadata = ["root_dir"]
86+
87+
def __init__(self, *args, **kwargs):
88+
root_dir = getattr(args[0], "root_dir", None) if args else None
89+
super().__init__(*args, **kwargs)
90+
if root_dir is not None:
91+
self.root_dir = root_dir
92+
93+
@property
94+
def _constructor(self):
95+
return DeepForest_DataFrame
96+
97+
8298
def read_pascal_voc(xml_path):
8399
"""Load annotations from xml format (e.g. RectLabel editor) and convert
84100
them into retinanet annotations format.
@@ -174,8 +190,8 @@ def convert_point_to_bbox(gdf: gpd.GeoDataFrame, buffer_size: float) -> gpd.GeoD
174190

175191
def shapefile_to_annotations(
176192
shapefile: str | gpd.GeoDataFrame,
177-
root_dir: str | None = None,
178193
rgb: str | None = None,
194+
root_dir: str | None = None,
179195
buffer_size: float | None = None,
180196
convert_point: bool = False,
181197
label: str | None = None,
@@ -189,50 +205,49 @@ def shapefile_to_annotations(
189205
"buffer_size argument is deprecated, use convert_point_to_bbox instead"
190206
)
191207

192-
image_path = root_dir + rgb
193-
return __shapefile_to_annotations__(shapefile, image_path, label)
208+
return __shapefile_to_annotations__(shapefile)
194209

195210

196-
def __check_image_path__(
197-
df: pd.DataFrame | gpd.GeoDataFrame,
198-
image_path: str | None = None,
199-
root_dir: str | None = None,
200-
):
211+
def __assign_image_path__(gdf, image_path: str) -> str:
201212
if image_path is None:
202-
if "image_path" not in df.columns:
213+
if "image_path" not in gdf.columns:
203214
raise ValueError(
204-
"No image_path column found in dataframe and image_path argument not specified, please specify full path to image file in image_path argument: read_file(input=df, image_path='/path/to/image.tif', ...)"
215+
"No image_path column found in GeoDataframe and image_path argument not specified, please specify the root_dir and image_path arguements: read_file(input=df, root_dir='path/to/images/', image_path='image.tif', ...)"
205216
)
206217
else:
207-
full_image_path = os.path.join(root_dir, df["image_path"].unique()[0])
218+
# Image Path columns exists, leave it unchanged.
219+
pass
208220
else:
209-
full_image_path = image_path
210-
211-
if not os.path.exists(full_image_path):
212-
raise FileNotFoundError(
213-
f"Image file {full_image_path} not found, please check the image_path argument, it should be the full path: read_file(input=df, image_path='/path/to/image.tif', ...)"
214-
)
221+
if "image_path" in gdf.columns:
222+
existing_image_path = gdf.image_path.unique()[0]
223+
if len(existing_image_path) > 1:
224+
warnings.warn(
225+
f"Multiple image_paths found in dataframe: {existing_image_path}, overriding and assigning {image_path} to all rows!",
226+
stacklevel=2,
227+
)
228+
if existing_image_path != image_path:
229+
warnings.warn(
230+
f"Image path {existing_image_path} found in dataframe, overriding and assigning {image_path} to all rows!",
231+
stacklevel=2,
232+
)
233+
gdf["image_path"] = image_path
234+
else:
235+
gdf["image_path"] = image_path
215236

216-
return full_image_path
237+
return gdf
217238

218239

219240
def __shapefile_to_annotations__(
220-
gdf: str | gpd.GeoDataFrame,
221-
root_dir: str | None = None,
222-
image_path: str | None = None,
241+
gdf: gpd.GeoDataFrame,
223242
) -> gpd.GeoDataFrame:
224243
"""Convert geospatial annotations to DeepForest format.
225244
226245
Args:
227-
gdf: A GeoDataFrame with a geometry column and an image_path column. If the image_path column is not present, it will be added using the image_path argument.
228-
image_path: Full path to the image file.
229-
root_dir: Root directory of the image files. If not provided, it will be inferred from the image_path column.
246+
gdf: A GeoDataFrame with a geometry column and an image_path column.
230247
231248
Returns:
232249
GeoDataFrame with annotations in DeepForest format.
233250
"""
234-
image_path = __check_image_path__(gdf, image_path=image_path, root_dir=root_dir)
235-
236251
# Determine geometry type and report to user
237252
if gdf.geometry.type.unique().shape[0] > 1:
238253
raise ValueError(
@@ -244,7 +259,8 @@ def __shapefile_to_annotations__(
244259
print(f"Geometry type of shapefile is {geometry_type}")
245260

246261
# raster bounds
247-
with rasterio.open(image_path) as src:
262+
full_image_path = os.path.join(gdf.root_dir, gdf.image_path.unique()[0])
263+
with rasterio.open(full_image_path) as src:
248264
raster_crs = src.crs
249265

250266
if gdf.crs:
@@ -270,9 +286,6 @@ def __shapefile_to_annotations__(
270286
print(f"CRS of image is {raster_crs}")
271287
gdf = geo_to_image_coordinates(gdf, src.bounds, src.res[0])
272288

273-
# add filename
274-
gdf["image_path"] = os.path.basename(image_path)
275-
276289
return gdf
277290

278291

@@ -283,7 +296,7 @@ def determine_geometry_type(df):
283296
Returns:
284297
geometry_type: a string of the geometry type
285298
"""
286-
if type(df) in [pd.DataFrame, gpd.GeoDataFrame]:
299+
if type(df) in [pd.DataFrame, gpd.GeoDataFrame, DeepForest_DataFrame]:
287300
columns = df.columns
288301
if "geometry" in columns:
289302
df = gpd.GeoDataFrame(geometry=df["geometry"])
@@ -447,43 +460,64 @@ def __pandas_to_geodataframe__(df: pd.DataFrame):
447460
]
448461
)
449462
gdf = gpd.GeoDataFrame(df, geometry="geometry")
463+
gdf = DeepForest_DataFrame(gdf)
450464

451465
return gdf
452466

453467

454-
def __check_label__(df: pd.DataFrame | gpd.GeoDataFrame, label: str | None = None):
468+
def __check_and_assign_label__(
469+
df: pd.DataFrame | gpd.GeoDataFrame, label: str | None = None
470+
):
455471
if label is None:
456472
if "label" not in df.columns:
457473
raise ValueError(
458474
"No label specified and no label column found in dataframe, please specify label in label argument: read_file(input=df, label='YourLabel', ...)"
459475
)
460476
else:
461-
df["label"] = label
477+
if "label" in df.columns:
478+
existing_labels = df.label.unique()
479+
if len(existing_labels) > 1:
480+
warnings.warn(
481+
f"Multiple labels found in dataframe: {existing_labels}, the label argument in read_file will override these labels!",
482+
stacklevel=2,
483+
)
484+
if existing_labels[0] != label:
485+
warnings.warn(
486+
f"Label {existing_labels[0]} found in dataframe, overriding and assigning {label} to all rows!",
487+
stacklevel=2,
488+
)
489+
else:
490+
df["label"] = label
462491

463492
return df
464493

465494

466495
def __assign_root_dir__(
467496
input,
468497
gdf: gpd.GeoDataFrame,
469-
image_path: str | None = None,
470498
root_dir: str | None = None,
471499
):
472500
if root_dir is not None:
473501
gdf.root_dir = root_dir
474502
else:
475-
if image_path is not None:
476-
gdf.root_dir = os.path.dirname(image_path)
477-
elif isinstance(input, str):
478-
warnings.warn(
479-
f"root_dir argument not specified, defaulting the images root_dir to the same directory as the input file: {os.path.dirname(input)}",
480-
stacklevel=2,
481-
)
503+
# If the user specified a path to file, use that root_dir as default.
504+
if isinstance(input, str):
482505
gdf.root_dir = os.path.dirname(input)
483506
else:
484507
raise ValueError(
485508
"root_dir argument not specified and input is a dataframe, where are the images stored?"
486509
)
510+
511+
return gdf
512+
513+
514+
def _pandas_to_deepforest_format__(input, df, image_path, root_dir, label):
515+
df = __check_and_assign_label__(df, label=label)
516+
gdf = __pandas_to_geodataframe__(df)
517+
gdf = __assign_image_path__(gdf, image_path=image_path)
518+
gdf = __assign_root_dir__(input, gdf, root_dir=root_dir)
519+
gdf = DeepForest_DataFrame(gdf)
520+
487521
return gdf
488522

489523

@@ -498,63 +532,72 @@ def read_file(
498532
Args:
499533
input: Path to file, DataFrame, or GeoDataFrame
500534
root_dir: Root directory for image files
501-
image_path: Assign image_path column to all rows. The full path to the image file.
502-
label: Assign a single label column to all rows.
535+
image_path: Path relative to root_dir to a single image that will be assigned as the image_path column for all annotations. The full path will be constructed by joining the root_dir and the image_path. Overrides any image_path column in input.
536+
label: Single label to be assigned as the label for all annotations. Overrides any label column in input.
503537
504-
Notes:
505-
The image_path and label arguments are applied to all rows in the dataframe or shapefile and therefore should only be used in cases where all rows have the same image_path and label.
506538
Returns:
507539
GeoDataFrame with geometry, image_path, and label columns
508540
"""
541+
# Check arguments
542+
if image_path is not None and root_dir is None:
543+
raise ValueError(
544+
"root_dir argument must be specified if image_path argument is used"
545+
)
546+
509547
# read file
510548
if isinstance(input, str):
511549
if input.endswith(".csv"):
512550
df = pd.read_csv(input)
513-
df = __check_label__(df, label=label)
514-
gdf = __pandas_to_geodataframe__(df)
551+
gdf = _pandas_to_deepforest_format__(input, df, image_path, root_dir, label)
515552
elif input.endswith(".json"):
516553
df = read_coco(input)
517-
df = __check_label__(df, label=label)
518-
gdf = __pandas_to_geodataframe__(df)
519-
elif input.endswith((".shp", ".gpkg")):
520-
gdf = gpd.read_file(input)
521-
gdf = __check_label__(gdf, label=label)
522-
gdf = __shapefile_to_annotations__(
523-
gdf,
524-
root_dir=root_dir,
525-
image_path=image_path,
526-
)
554+
gdf = _pandas_to_deepforest_format__(input, df, image_path, root_dir, label)
527555
elif input.endswith(".xml"):
528556
df = read_pascal_voc(input)
529-
df = __check_label__(df, label=label)
530-
gdf = __pandas_to_geodataframe__(df)
557+
gdf = _pandas_to_deepforest_format__(input, df, image_path, root_dir, label)
558+
elif input.endswith((".shp", ".gpkg")):
559+
gdf = gpd.read_file(input)
560+
gdf = DeepForest_DataFrame(gdf)
561+
gdf = __assign_image_path__(gdf, image_path=image_path)
562+
gdf = __check_and_assign_label__(gdf, label=label)
563+
gdf = __assign_root_dir__(input=input, gdf=gdf, root_dir=root_dir)
564+
gdf = __shapefile_to_annotations__(gdf)
531565
else:
532566
raise ValueError(
533567
f"File type {input} not supported. "
534568
"DeepForest currently supports .csv, .shp, .gpkg, .xml, and .json files. "
535569
"See https://deepforest.readthedocs.io/en/latest/annotation.html "
536570
)
537571
elif isinstance(input, gpd.GeoDataFrame):
538-
input = __check_label__(input, label=label)
539-
gdf = __shapefile_to_annotations__(
540-
input,
541-
image_path=image_path,
542-
root_dir=root_dir,
543-
)
572+
gdf = input
573+
gdf = __assign_image_path__(gdf, image_path=image_path)
574+
gdf = __assign_root_dir__(input, gdf, root_dir=root_dir)
575+
gdf = DeepForest_DataFrame(gdf)
576+
gdf_list = []
577+
for image_path in gdf.image_path.unique():
578+
image_annotations = gdf[gdf.image_path == image_path]
579+
gdf = __shapefile_to_annotations__(image_annotations)
580+
gdf_list.append(gdf)
581+
582+
# When concat, need to reform GeoPandas GeoDataFrame
583+
gdf = pd.concat(gdf_list)
584+
gdf = gpd.GeoDataFrame(gdf)
585+
gdf = DeepForest_DataFrame(gdf)
586+
gdf = __check_and_assign_label__(gdf, label=label)
587+
544588
elif isinstance(input, pd.DataFrame):
545589
if input.empty:
546590
raise ValueError("No annotations in dataframe")
547-
gdf = __check_label__(input, label=label)
548591
gdf = __pandas_to_geodataframe__(input)
549-
image_path = __check_image_path__(gdf, image_path=image_path, root_dir=root_dir)
550-
gdf["image_path"] = os.path.basename(image_path)
592+
gdf = __assign_image_path__(gdf, image_path=image_path)
593+
gdf = __assign_root_dir__(input, gdf, root_dir=root_dir)
594+
gdf = __check_and_assign_label__(gdf, label=label)
595+
gdf = DeepForest_DataFrame(gdf)
551596
else:
552597
raise ValueError(
553598
"Input must be a path to a file, geopandas or a pandas dataframe"
554599
)
555600

556-
gdf = __assign_root_dir__(input, gdf, root_dir=root_dir, image_path=image_path)
557-
558601
return gdf
559602

560603

0 commit comments

Comments
 (0)