From 906d7d5abced7d588a46e996c114e1f9b5f28f07 Mon Sep 17 00:00:00 2001 From: Christopher Kotthoff Date: Wed, 16 Jul 2025 18:10:16 +0100 Subject: [PATCH 01/17] rgb contrast and unsafe access fix --- detectree2/preprocessing/tiling.py | 54 +++++++++++++++++++----------- 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/detectree2/preprocessing/tiling.py b/detectree2/preprocessing/tiling.py index 54ee7202..a376f75d 100644 --- a/detectree2/preprocessing/tiling.py +++ b/detectree2/preprocessing/tiling.py @@ -134,7 +134,8 @@ def process_tile(img_path: str, additional_nodata: List[Any] = [], image_statistics: List[Dict[str, float]] = None, ignore_bands_indices: List[int] = [], - use_convex_mask: bool = True): + use_convex_mask: bool = True, + enhance_rgb_contrast: bool = True): """Process a single tile for making predictions. Args: @@ -199,7 +200,7 @@ def process_tile(img_path: str, unioned_crowns = overlapping_crowns.union_all() else: unioned_crowns = overlapping_crowns.unary_union - convex_mask_tif = rasterio.features.geometry_mask([unioned_crowns.convex_hull.buffer(5)], + convex_mask_tif = rasterio.features.geometry_mask([unioned_crowns.convex_hull.buffer(3)], transform=out_transform, invert=True, out_shape=(out_img.shape[1], out_img.shape[2])) @@ -224,9 +225,19 @@ def process_tile(img_path: str, ) return None + if enhance_rgb_contrast: + # rescale image to 1-255 (0 is reserved for nodata) + min_vals, max_vals = np.percentile( + out_img.reshape(3, -1)[:, ~nan_mask.reshape(-1).astype(bool)], [0.2, 99.8]) + + out_img = (out_img - min_vals) / (max_vals - min_vals) * 254 + 1 + # Apply nan mask out_img[np.broadcast_to((nan_mask == 1)[None, :, :], out_img.shape)] = 0 + if enhance_rgb_contrast: + out_img = np.clip(out_img, 0, 255) + dtype, nodata = dtype_map.get(out_img.dtype, (None, None)) if dtype is None: logger.exception(f"Unsupported dtype: {out_img.dtype}") @@ -249,20 +260,20 @@ def process_tile(img_path: str, r, g, b = out_img[0], out_img[1], out_img[2] rgb = np.dstack((b, g, r)) # Reorder for cv2 (BGRA) - # Rescale to 0-255 if necessary - if np.nanmax(g) > 255: - rgb_rescaled = rgb / 65535 * 255 - else: - rgb_rescaled = rgb - - np.clip(rgb_rescaled, 0, 255, out=rgb_rescaled) + if not enhance_rgb_contrast: + # If not enhancing contrast, ensure the dtype is uint8 + if dtype_bool: + rgb = rgb.astype(np.uint8) + else: + rgb = rgb.astype(np.float32) + np.clip(rgb, 0, 255, out=rgb) - cv2.imwrite(str(out_path_root.with_suffix(".png").resolve()), rgb_rescaled.astype(np.uint8)) + cv2.imwrite(str(out_path_root.with_suffix(".png").resolve()), rgb.astype(np.uint8)) if overlapping_crowns is not None: - return data, out_path_root, overlapping_crowns, minx, miny, buffer + return out_transform, out_path_root, overlapping_crowns, minx, miny, buffer - return data, out_path_root, None, minx, miny, buffer + return out_transform, out_path_root, None, minx, miny, buffer except RasterioIOError as e: logger.error(f"RasterioIOError while applying mask {coords}: {e}") @@ -421,9 +432,9 @@ def process_tile_ms(img_path: str, # cv2.imwrite(str(out_path_root.with_suffix(".png").resolve()), rgb) if overlapping_crowns is not None: - return data, out_path_root, overlapping_crowns, minx, miny, buffer + return out_transform, out_path_root, overlapping_crowns, minx, miny, buffer - return data, out_path_root, None, minx, miny, buffer + return out_transform, out_path_root, None, minx, miny, buffer except RasterioIOError as e: logger.error(f"RasterioIOError while applying mask {coords}: {e}") @@ -453,7 +464,9 @@ def 
process_tile_train( additional_nodata: List[Any] = [], image_statistics: List[Dict[str, float]] = None, ignore_bands_indices: List[int] = [], - use_convex_mask: bool = True) -> None: + use_convex_mask: bool = True, + enhance_rgb_contrast: bool = True + ) -> None: """Process a single tile for training data. Args: @@ -477,7 +490,7 @@ def process_tile_train( if mode == "rgb": result = process_tile(img_path, out_dir, buffer, tile_width, tile_height, dtype_bool, minx, miny, crs, tilename, crowns, threshold, nan_threshold, mask_gdf, additional_nodata, image_statistics, - ignore_bands_indices, use_convex_mask) + ignore_bands_indices, use_convex_mask, enhance_rgb_contrast) elif mode == "ms": result = process_tile_ms(img_path, out_dir, buffer, tile_width, tile_height, dtype_bool, minx, miny, crs, tilename, crowns, threshold, nan_threshold, mask_gdf, additional_nodata, @@ -487,13 +500,13 @@ def process_tile_train( # logger.warning(f"Skipping tile at ({minx}, {miny}) due to insufficient data.") return - data, out_path_root, overlapping_crowns, minx, miny, buffer = result + out_transform, out_path_root, overlapping_crowns, minx, miny, buffer = result if overlapping_crowns is not None and not overlapping_crowns.empty: overlapping_crowns = overlapping_crowns.explode(index_parts=True) moved = overlapping_crowns.translate(-minx + buffer, -miny + buffer) - scalingx = 1 / (data.transform[0]) - scalingy = -1 / (data.transform[4]) + scalingx = 1 / (out_transform[0]) + scalingy = -1 / (out_transform[4]) moved_scaled = moved.scale(scalingx, scalingy, origin=(0, 0)) if mode == "rgb": @@ -766,6 +779,7 @@ def tile_data( overlapping_tiles: bool = False, ignore_bands_indices: List[int] = [], use_convex_mask: bool = True, + enhance_rgb_contrast: bool = True, ) -> None: """Tiles up orthomosaic and corresponding crowns (if supplied) into training/prediction tiles. 
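For reference, the enhance_rgb_contrast branch this patch adds to process_tile amounts to the following standalone sketch. The function name and array shapes are illustrative; the 0.2/99.8 percentiles and the 1-255 target range (with 0 reserved for nodata) are the values used in the hunk above.

import numpy as np

def stretch_rgb_to_1_255(out_img, nan_mask):
    """Percentile-stretch a (3, H, W) tile to 1..255, keeping 0 for nodata."""
    # Percentiles are computed only over pixels not flagged in the nodata mask
    valid = out_img.reshape(3, -1)[:, ~nan_mask.reshape(-1).astype(bool)]
    min_vals, max_vals = np.percentile(valid, [0.2, 99.8])  # robust range over valid pixels
    out = (out_img - min_vals) / (max_vals - min_vals) * 254 + 1  # map [min, max] -> [1, 255]
    out[np.broadcast_to((nan_mask == 1)[None, :, :], out.shape)] = 0  # nodata stays 0
    return np.clip(out, 0, 255)
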
@@ -813,7 +827,7 @@ def tile_data( tile_args = [ (img_path, out_dir, buffer, tile_width, tile_height, dtype_bool, minx, miny, crs, tilename, crowns, threshold, nan_threshold, mode, class_column, mask_gdf, additional_nodata, image_statistics, ignore_bands_indices, - use_convex_mask) for minx, miny in tile_coordinates + use_convex_mask, enhance_rgb_contrast) for minx, miny in tile_coordinates if mask_path is None or (mask_path is not None and mask_gdf.intersects( box(minx, miny, minx + tile_width, miny + tile_height) #TODO maybe add to_crs here ).any()) From 42752d04cffedcf069461289029ce7b9fb3260e3 Mon Sep 17 00:00:00 2001 From: Christopher Kotthoff Date: Sun, 20 Jul 2025 19:22:18 +0100 Subject: [PATCH 02/17] experimental numpy fix --- .github/workflows/python-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml index a8fec3ca..3f5f7f4d 100644 --- a/.github/workflows/python-ci.yml +++ b/.github/workflows/python-ci.yml @@ -59,6 +59,7 @@ jobs: flake8 detectree2 --count --exit-zero --max-complexity=10 --statistics - name: pytest checks run: | + pip install --upgrade --force-reinstall "numpy>=1.20,<2.0" pip install pytest-order pytest - name: mypy checks From 1aad2754d115f9b2f5d61a5b91ae3623b41c4484 Mon Sep 17 00:00:00 2001 From: Christopher Kotthoff Date: Sun, 20 Jul 2025 19:45:18 +0100 Subject: [PATCH 03/17] fixes a lot of accumulated mypy problem --- detectree2/models/predict.py | 7 +-- detectree2/models/train.py | 14 ++++-- detectree2/preprocessing/tiling.py | 68 +++++++++++++++--------------- 3 files changed, 49 insertions(+), 40 deletions(-) diff --git a/detectree2/models/predict.py b/detectree2/models/predict.py index 18bd01ad..b87e4422 100644 --- a/detectree2/models/predict.py +++ b/detectree2/models/predict.py @@ -66,16 +66,17 @@ def predict_on_data( file_ext = os.path.splitext(file_name)[1].lower() if file_ext == ".png": # RGB image, read with cv2 - img = cv2.imread(file_name) - if img is None: + cv_img = cv2.imread(file_name) + if cv_img is None: print(f"Failed to read image {file_name} with cv2.") continue + img = np.array(cv_img) # Explicitly convert to numpy array elif file_ext == ".tif": # Multispectral image, read with rasterio with rasterio.open(file_name) as src: img = src.read() # Transpose to match expected format (H, W, C) - img = np.transpose(img, (1, 2, 0)) + img = img.transpose(1, 2, 0) else: print(f"Unsupported file extension {file_ext} for file {file_name}") continue diff --git a/detectree2/models/train.py b/detectree2/models/train.py index c39c9bf8..66b8e519 100644 --- a/detectree2/models/train.py +++ b/detectree2/models/train.py @@ -725,7 +725,10 @@ def get_tree_dicts(directory: str, class_mapping: Optional[Dict[str, int]] = Non # Make sure we have the correct height and width # If image path ends in .png use cv2 to get height and width else if image path ends in .tif use rasterio if filename.endswith(".png"): - height, width = cv2.imread(filename).shape[:2] + img = cv2.imread(filename) + if img is None: + continue + height, width = img.shape[:2] elif filename.endswith(".tif"): with rasterio.open(filename) as src: height, width = src.shape @@ -744,7 +747,7 @@ def get_tree_dicts(directory: str, class_mapping: Optional[Dict[str, int]] = Non print("Skipping annotation of type", anno["type"], "in file", filename) continue px = [a[0] for a in anno["coordinates"][0]] - py = [np.array(height) - a[1] for a in anno["coordinates"][0]] + py = [height - a[1] for a in anno["coordinates"][0]] poly = [(x, y) for 
x, y in zip(px, py)] poly = [p for x in poly for p in x] @@ -755,7 +758,12 @@ def get_tree_dicts(directory: str, class_mapping: Optional[Dict[str, int]] = Non category_id = 0 # Default to "tree" if no class mapping is provided obj = { - "bbox": [np.min(px), np.min(py), np.max(px), np.max(py)], + "bbox": [ + np.min(np.array(px)), + np.min(np.array(py)), + np.max(np.array(px)), + np.max(np.array(py)), + ], "bbox_mode": BoxMode.XYXY_ABS, "segmentation": [poly], "category_id": category_id, diff --git a/detectree2/preprocessing/tiling.py b/detectree2/preprocessing/tiling.py index a376f75d..2fb63974 100644 --- a/detectree2/preprocessing/tiling.py +++ b/detectree2/preprocessing/tiling.py @@ -233,7 +233,7 @@ def process_tile(img_path: str, out_img = (out_img - min_vals) / (max_vals - min_vals) * 254 + 1 # Apply nan mask - out_img[np.broadcast_to((nan_mask == 1)[None, :, :], out_img.shape)] = 0 + out_img[np.broadcast_to((nan_mask == 1)[None, :, :], out_img.shape)] = 0 # type: ignore[attr-defined] if enhance_rgb_contrast: out_img = np.clip(out_img, 0, 255) @@ -258,7 +258,7 @@ def process_tile(img_path: str, dest.write(out_img) r, g, b = out_img[0], out_img[1], out_img[2] - rgb = np.dstack((b, g, r)) # Reorder for cv2 (BGRA) + rgb = np.dstack((b, g, r)) # type: ignore[attr-defined] # Reorder for cv2 (BGRA) if not enhance_rgb_contrast: # If not enhancing contrast, ensure the dtype is uint8 @@ -266,7 +266,7 @@ def process_tile(img_path: str, rgb = rgb.astype(np.uint8) else: rgb = rgb.astype(np.float32) - np.clip(rgb, 0, 255, out=rgb) + np.clip(rgb, 0, 255, out=rgb) # type: ignore[call-arg] cv2.imwrite(str(out_path_root.with_suffix(".png").resolve()), rgb.astype(np.uint8)) @@ -405,7 +405,7 @@ def process_tile_ms(img_path: str, out_img = np.clip(out_img.astype(np.float32), 1.0, 255.0) # Apply nan mask - out_img[np.broadcast_to((nan_mask == 1)[None, :, :], out_img.shape)] = 0.0 + out_img[np.broadcast_to((nan_mask == 1)[None, :, :], out_img.shape)] = 0.0 # type: ignore[attr-defined] dtype, nodata = dtype_map.get(out_img.dtype, (None, None)) if dtype is None: @@ -558,21 +558,22 @@ def _calculate_tile_placements( overlapping_tiles: bool = False, ) -> List[Tuple[int, int]]: """Internal method for calculating the placement of tiles""" - + coordinates: List[Tuple[int, int]] = [] if tile_placement == "grid": with rasterio.open(img_path) as data: - coordinates = [ - (minx, miny) for minx in np.arange( - math.ceil(data.bounds[0]) + buffer, data.bounds[2] - tile_width - buffer, tile_width, int) + grid_coords = [ + (int(minx), int(miny)) for minx in np.arange( + int(math.ceil(data.bounds[0])) + buffer, int(data.bounds[2] - tile_width - buffer), tile_width) for miny in np.arange( - math.ceil(data.bounds[1]) + buffer, data.bounds[3] - tile_height - buffer, tile_height, int) + int(math.ceil(data.bounds[1])) + buffer, int(data.bounds[3] - tile_height - buffer), tile_height) ] if overlapping_tiles: - coordinates.extend([(minx, miny) for minx in np.arange( - math.ceil(data.bounds[0]) + buffer + tile_width // 2, data.bounds[2] - tile_width - buffer - - tile_width // 2, tile_width, int) for miny in np.arange( - math.ceil(data.bounds[1]) + buffer + tile_height // 2, data.bounds[3] - tile_height - buffer - - tile_height // 2, tile_height, int)]) + grid_coords.extend([(int(minx), int(miny)) for minx in np.arange( + int(math.ceil(data.bounds[0])) + buffer + tile_width // 2, int(data.bounds[2] - tile_width - buffer - + tile_width // 2), tile_width) for miny in np.arange( + int(math.ceil(data.bounds[1])) + buffer + tile_height // 
2, int(data.bounds[3] - tile_height - buffer - + tile_height // 2), tile_height)]) + coordinates = grid_coords elif tile_placement == "adaptive": if crowns is None: @@ -598,7 +599,6 @@ def _calculate_tile_placements( y_offset = (combined_tiles_height - area_height) / 2 logger.info("Starting Tile Placement Generation") - coordinates = [] for row in range(required_tiles_y): bar = gpd.GeoSeries([ box(crowns.total_bounds[0] - x_offset, crowns.total_bounds[1] - y_offset + row * tile_height, @@ -669,17 +669,17 @@ def calc_on_everything(): min_val, max_val = np.percentile(valid_data, [1, 99]) stats = { - "mean": np.mean(valid_data), - "min": min_val, - "max": max_val, - "std_dev": np.std(valid_data), + "mean": float(np.mean(valid_data)), + "min": float(min_val), + "max": float(max_val), + "std_dev": float(np.std(valid_data)), } else: stats = { - "mean": None, - "min": None, - "max": None, - "std_dev": None, + "mean": np.nan, + "min": np.nan, + "max": np.nan, + "std_dev": np.nan, } band_stats.append(stats) return band_stats @@ -743,17 +743,17 @@ def calc_on_everything(): if valid_data.size > 0: min_val, max_val = np.percentile(valid_data, [1, 99]) stats = { - "mean": np.mean(valid_data), - "min": min_val, - "max": max_val, - "std_dev": np.std(valid_data), + "mean": float(np.mean(valid_data)), + "min": float(min_val), + "max": float(max_val), + "std_dev": float(np.std(valid_data)), } else: stats = { - "mean": None, - "min": None, - "max": None, - "std_dev": None, + "mean": np.nan, + "min": np.nan, + "max": np.nan, + "std_dev": np.nan, } band_stats.append(stats) return band_stats @@ -1026,7 +1026,7 @@ def create_RGB_from_MS(tile_folder_path: Union[str, Path], # Write the PNG (we must convert shape to (H, W, 3) and then to uint8) output_png = out_path / f"{tif_file.stem}.png" - png_ready = np.moveaxis(transformed, 0, -1).astype(np.uint8) # (H, W, 3) + png_ready = np.moveaxis(transformed, 0, -1).astype(np.uint8) # type: ignore[attr-defined] # (H, W, 3) cv2.imwrite(str(output_png), cv2.cvtColor(png_ready, cv2.COLOR_RGB2BGR)) elif conversion == "first-three": @@ -1065,7 +1065,7 @@ def create_RGB_from_MS(tile_folder_path: Union[str, Path], # Write out the PNG (shape must be (H, W, 3)) output_png = out_path / f"{tif_file.stem}.png" # Move axis from (bands, H, W) -> (H, W, bands) - png_ready = np.moveaxis(data, 0, -1).astype(np.uint8) + png_ready = np.moveaxis(data, 0, -1).astype(np.uint8) # type: ignore[attr-defined] # We expect the order to be [band1, band2, band3], so interpret as R,G,B cv2.imwrite(str(output_png), cv2.cvtColor(png_ready, cv2.COLOR_RGB2BGR)) @@ -1315,7 +1315,7 @@ def to_traintest_folders( # noqa: C901 # random.shuffle(indices) num = list(range(0, len(file_roots))) random.shuffle(num) - ind_split = np.array_split(file_roots, folds) + ind_split = np.array_split(np.array(file_roots), folds) for i in range(0, folds): Path(out_dir / f"train/fold_{i + 1}").mkdir(parents=True, exist_ok=True) From 8a88b91527deea9f4714841724b5dfcffa189648 Mon Sep 17 00:00:00 2001 From: James Ball Date: Mon, 1 Sep 2025 15:04:07 +0100 Subject: [PATCH 04/17] author --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0958d783..3f49275a 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ A tutorial on how to prepare data, train models and make predictions is availabl Detectree2是一个基于Mask 
R-CNN的自动树冠检测与分割的Python包。您可以在[`model_garden`](https://github.com/PatBall1/detectree2/tree/master/model_garden)中选择预训练模型。[这里](https://patball1.github.io/detectree2/tutorial.html)提供了如何准备数据、训练模型和进行预测的教程。如果有任何问题,合作提案或者需要样例数据,可以邮件联系[James Ball](mailto:ball.jgc@gmail.com)。一些示例数据可以在[这里](https://doi.org/10.5281/zenodo.8136161)下载。 -| | Code developed by James Ball, Seb Hickman, Thomas Koay, Oscar Jiang, Luran Wang, Panagiotis Ioannou, James Hinton and Matthew Archer in the [Forest Ecology and Conservation Group](https://coomeslab.org/) at the University of Cambridge. The Forest Ecology and Conservation Group is led by Professor David Coomes and is part of the University of Cambridge [Conservation Research Institute](https://www.conservation.cam.ac.uk/). | +| | Code developed by James Ball, Seb Hickman, Christopher Kotthoff, Thomas Koay, Oscar Jiang, Luran Wang, Panagiotis Ioannou, James Hinton and Matthew Archer in the [Forest Ecology and Conservation Group](https://coomeslab.org/) at the University of Cambridge. The Forest Ecology and Conservation Group is led by Professor David Coomes and is part of the University of Cambridge [Conservation Research Institute](https://www.conservation.cam.ac.uk/). | | :---: | :--- | From ea7a4ef3c8d119d44810d07f70108d35579c9274 Mon Sep 17 00:00:00 2001 From: James Ball Date: Mon, 1 Sep 2025 17:12:29 +0100 Subject: [PATCH 05/17] set up fixes --- detectree2/data_loading/gdrive.py | 582 ------------------------ detectree2/data_loading/gdrivePull.sh | 15 - detectree2/data_loading/gdrive_load.py | 0 detectree2/data_loading/gee_download.py | 158 ------- detectree2/data_loading/quickstart.py | 50 -- detectree2/models/train.py | 29 +- setup.py | 34 +- 7 files changed, 46 insertions(+), 822 deletions(-) delete mode 100644 detectree2/data_loading/gdrive.py delete mode 100644 detectree2/data_loading/gdrivePull.sh delete mode 100644 detectree2/data_loading/gdrive_load.py delete mode 100644 detectree2/data_loading/gee_download.py delete mode 100644 detectree2/data_loading/quickstart.py diff --git a/detectree2/data_loading/gdrive.py b/detectree2/data_loading/gdrive.py deleted file mode 100644 index cb915218..00000000 --- a/detectree2/data_loading/gdrive.py +++ /dev/null @@ -1,582 +0,0 @@ -"""Wrapper around Google Drive API (v3) to download files from GDrive""" - -import io -import os.path -import pathlib -import pickle -import shutil -from collections import deque -from os import PathLike -from typing import Dict, List, Optional, Union - -from google.auth.transport.requests import Request -from google_auth_oauthlib.flow import InstalledAppFlow -from googleapiclient.discovery import build -from googleapiclient.http import MediaFileUpload, MediaIoBaseDownload -from src.constants import PROJECT_PATH -from src.utils.logging import get_logger -from tqdm.autonotebook import tqdm - -logger = get_logger(__file__) -SECRETS_PATH = PROJECT_PATH / "secrets" -DriveFileJson = Dict[str, str] - - -class DriveAPI: - """ - Python wrapper around the google drive v3 API. - Handels OAuth connection, file browsing, meta data retrieval, file download and - file upload. 
- """ - - # Define the scopes - SCOPES = [ - "https://www.googleapis.com/auth/drive", # For reading and writing - "https://www.googleapis.com/auth/drive.readonly", # For reading only (download) - ] - # Define GDrive types - GDRIVE_FOLDER = "application/vnd.google-apps.folder" - - def __init__( - self, - credentials_path: Union[str, PathLike] = SECRETS_PATH / "credentials.json", - ): - - # Variable self.creds will store the user access token. - # If no valid token found we will create one. - self.creds = None - self.credentials_path = pathlib.Path(credentials_path) - self._user_data = None - - # Authenticate - self._authenticate() - - # Connect to the API service - self.service = build("drive", "v3", credentials=self.creds) - - def _authenticate(self) -> None: - """ - Authenticate user with user token from google OAuth 2.0. - """ - # The file token.pickle stores the user's access and refresh tokens. It is - # created automatically when the authorization flow completes for the first - # time. - - # Check if file token.pickle exists - if os.path.exists(SECRETS_PATH / "token.pickle"): - - # Read the token from the file and - # store it in the variable self.creds - with open(SECRETS_PATH / "token.pickle", "rb") as token: - self.creds = pickle.load(token) - - # If no valid credentials are available, - # request the user to log in. - if not self.creds or not self.creds.valid: - - # If token is expired, it will be refreshed, - # else, we will request a new one. - if self.creds and self.creds.expired and self.creds.refresh_token: - self.creds.refresh(Request()) - else: - self._perform_oauth() - - # Save the access token in token.pickle - # file for future usage - with open(SECRETS_PATH / "token.pickle", "wb") as token: - pickle.dump(self.creds, token) - - def _perform_oauth(self) -> None: - """ - Perform google OAuth 2.0 flow to authenticate user. - """ - flow = InstalledAppFlow.from_client_secrets_file(self.credentials_path, DriveAPI.SCOPES) - self.creds = flow.run_local_server(port=0) - - @property - def user_data(self) -> dict: - """Returns metadata of currently logged in user""" - if self._user_data is None: - # fetch user data - about = self.service.about() - self._user_data = about.get(fields="user").execute()["user"] - return self._user_data - - @property - def user_email(self) -> str: - """Returns email address of currently logged in user""" - return self.user_data["emailAddress"] - - @property - def username(self) -> str: - """Returns user name of currently logged in user""" - return self.user_data["displayName"] - - def file_download( - self, - file_id: str, - save_path: str, - chunksize: int = 200 * 1024 * 1024, - verbose: bool = False, - ) -> bool: - """ - Download file with given `file_id` and save in `save_path`. - Raises an error if the download fails. - Args: - file_id (str): id of the file to download - save_path (str): path where the file will be saved - chunksize (int, optional): size of the chunks of data to request with - each http request. If the download is slow, try increasing the chunksize - as google limits the number of http requests we can pose per second. - Defaults to 200*1024*1024 (= 200 MB). - verbose (bool): Iff true, show download progress for each file. Defaults to - False. - Returns: - bool: True, iff the file was downloaded successfully. 
- """ - request = self.service.files().get_media(fileId=file_id) - file_handle = io.BytesIO() - - # Initialise a downloader object to download the file - downloader = MediaIoBaseDownload(file_handle, request, chunksize=chunksize) - done = False - - if verbose: - print("Starting file download") - progress_bar = tqdm(total=100, disable=not verbose) - while not done: - status, done = downloader.next_chunk() - if status and verbose: - progress_bar.update(n=status.progress() * 100) - progress_bar.close() - - file_handle.seek(0) - - # Write the received data to the file - with open(save_path, "wb") as f: - shutil.copyfileobj(file_handle, f) - - if verbose: - print("File Downloaded") - # Return True if file Downloaded successfully - return True - - def get_mimetype(self, file_id: str) -> str: - """ - Returns mime type of the given file - Args: - file_id (str): id of the file - Returns: - str: mime type of the given file - """ - - query = self.service.files().get( - fileId=file_id, - fields="mimeType", - supportsAllDrives=True, - ) - mime_type = query.execute()["mimeType"] - - return mime_type - - def is_mimetype(self, file_id: str, target_mime_type: str) -> bool: - """ - Check mime type of a given file against target mime type - Args: - file_id (str): id of the file - target_mime_type (str): target mime type to check against - Returns: - bool: True, iff the mime type of the given file matches the target mime - type - """ - - return self.get_mimetype(file_id) == target_mime_type - - def is_folder(self, file_id: str) -> bool: - """ - Checks if a given file is a gdrive folder - Args: - file_id (str): id of the file - Returns: - bool: True, iff file is a gdrive folder - """ - - return self.is_mimetype(file_id=file_id, target_mime_type=DriveAPI.GDRIVE_FOLDER) - - def is_tif(self, file_id: str) -> bool: - """ - Checks if a given file is a .tiff file. - Args: - file_id (str): id of the file - Returns: - bool: True, iff file is of type .tiff - """ - - return self.is_mimetype(file_id, target_mime_type="image/tiff") - - def is_kml(self, file_id: str) -> bool: - """ - Checks if a given file is a .kml file. - Args: - file_id (str): id of the file - Returns: - bool: True, iff file is of type .kml - """ - - return self.is_mimetype(file_id, target_mime_type="application/vnd.google-earth.kml+xml") - - def get_folder(self, folder_name: str, all_drives: bool = True, trashed_ok: bool = False) -> DriveFileJson: - """ - Return metadata of gdrive folder with the given `folder_name` - Raises an error if `folder_name` does not identify a unique folder (or does - not exist). - Args: - folder_name (str): The name of the folder for which to obtain metadata' - all_drives (bool): Whether to search TeamDrives. Defaults to True. - trashed_ok (bool): Whether to include bin in search. Defaults to False. - Returns: - DriveFileJson: The metadata of the requested folder as python dict. 
- """ - - file_browser = self.service.files() - file_metadata = {"name": folder_name, "mimeType": self.GDRIVE_FOLDER} - # Fomulate http request - query = file_browser.list( - q=self._metadata_to_query_string(file_metadata=file_metadata, trashed_ok=trashed_ok), - pageSize=1000, - supportsAllDrives=all_drives, - includeItemsFromAllDrives=all_drives, - ) - # Send http request - result = query.execute()["files"] - - if len(result) == 0: - raise UserWarning("No folder with this name exists") - elif len(result) > 1: - results = "\n".join([str(elem) for elem in result]) - raise UserWarning(f"Multiple folders with this name exist: \n\n{results}") - - return result[0] - - def get_folder_id(self, folder_name: str, all_drives: bool = True) -> str: - """ - Return id of a folder with the given name. - Raises an error if folder does not exist or if multiple folder share the - same name. - Args: - folder_name (str): The folder whose id should be returned - all_drives (bool): Whether to search TeamDrives. Defaults to True. - Returns: - str: id of the folder with the given foldername. - """ - - folder = self.get_folder(folder_name, all_drives=all_drives) - - return folder["id"] - - def get_file_name(self, file_id: str, all_drives: bool = True) -> str: - """ - Get name of a file by id - Args: - file_id (str): The id of the file whose name should be returned - all_drives (bool): Whether to search TeamDrives. Defaults to True. - Returns: - str: The filename - """ - - query = self.service.files().get(fileId=file_id, fields="name", supportsAllDrives=all_drives) - return query.execute()["name"] - - def list_all_files(self, all_drives=True) -> List[DriveFileJson]: - """ - Lists all files which are not folders in gdrive - Returns: - List[DriveFileJson]: A list of all files in the given gdrive account - """ - - file_browser = self.service.files() - query = file_browser.list( - q=f"mimeType!='{self.GDRIVE_FOLDER}'", - pageSize=1000, - supportsAllDrives=all_drives, - includeItemsFromAllDrives=all_drives, - ) - return query.execute()["files"] - - def list_all_folders(self, all_drives=True) -> List[DriveFileJson]: - """ - List all folders in gdrive - Returns: - List[dict]: List of all folders and their id's - """ - - file_browser = self.service.files() - query = file_browser.list( - q=f"mimeType='{self.GDRIVE_FOLDER}'", - supportsAllDrives=all_drives, - includeItemsFromAllDrives=all_drives, - ) - return query.execute()["files"] - - def list_all_drives(self) -> List[DriveFileJson]: - """ - List all drives - Returns: - List[dict]: List of all drives and their id's - """ - query = self.service.drives().list() - return query.execute()["drives"] - - def list_files_in_folder( - self, - folder_id: str, - fields: str = "files (id, name)", - all_drives: bool = True, - **kwargs, - ) -> List[DriveFileJson]: - """ - List all files in a gdrive folder with given `folder_id`. - Args: - folder_id (str): The id of the folder - fields (str, optional): The fields to list. Possible values can be taken - from the gdrive api v3 documentation. Defaults to "files (id, name)". - Returns: - List[dict]: A list of all the files in the given folder. 
- """ - - file_browser = self.service.files() - - assert self.is_folder(folder_id), "Selected file is not a folder" - - query = file_browser.list( - q=f"'{folder_id}' in parents", - fields=fields, - pageSize=1000, - supportsAllDrives=all_drives, - includeItemsFromAllDrives=all_drives, - **kwargs, - ) - return query.execute()["files"] - - @staticmethod - def _metadata_to_query_string(file_metadata: DriveFileJson, trashed_ok: bool = False) -> str: - """ - Turns file metadata into query string to be used in GDrive API file queries. - - Args: - file_metadata (DriveFileJson): The metadata to turn into a query - trashed_ok (bool, optional): Whether to allow documents in trash in query. - Defaults to False. - - Returns: - str: The query string to find the file with `file_metadata` on GDrive - """ - query_str = f"name='{file_metadata['name']}'" - if "parents" in file_metadata: - query_str += f" and '{file_metadata['parents'][0]}' in parents" - if "mimeType" in file_metadata: - query_str += f" and mimeType='{file_metadata['mimeType']}'" - query_str += f" and trashed={'true' if trashed_ok else 'false'}" - return query_str - - def get_file( - self, - file_metadata: DriveFileJson, - trashed_ok: bool = False, - all_drives: bool = True, - ) -> List[DriveFileJson]: - """ - Returns full metadata for all files which match the given `file_metadata`. - - Args: - file_metadata (DriveFileJson): The file metadata values to use as query - parameters. - trashed_ok (bool, optional): Whether to allow documents in trash in query. - Defaults to False. - all_drives (bool, optional): Whether to include all drives in the search, - (i.e. also TeamDrives). Defaults to True. - - Returns: - List[DriveFileJson]: A list with metadata of all files which match the - values in `file_metadata` - """ - query_str = self._metadata_to_query_string(file_metadata, trashed_ok) - print(query_str) - return (self.service.files().list( - q=query_str, - supportsAllDrives=all_drives, - includeItemsFromAllDrives=all_drives, - pageSize=1000, - ).execute()["files"]) - - def exists( - self, - file_metadata: DriveFileJson, - trashed_ok: bool = False, - all_drives: bool = True, - ) -> bool: - """Returns True, iff a file with the given file_metadata exists on GDrive.""" - return len(self.get_file(file_metadata, trashed_ok, all_drives)) > 0 - - @staticmethod - def _add_parent_to_metadata(file_metadata: DriveFileJson, parent: DriveFileJson) -> DriveFileJson: - """ - Adds parent information to a given metadata template. - - If parent is part of a team drive, team drive information will be passed on - as well. - - Args: - file_metadata (DriveFileJson): metadata to modify such that it will include - parent as its parent - parent (DriveFileJson): metadata of the parent - - Returns: - DriveFileJson: The modified file_metadata with `parent` as a parent. - """ - file_metadata["parents"] = [parent["id"]] - if "driveId" in parent.keys(): - file_metadata["driveId"] = parent["driveId"] - if "teamDriveId" in parent.keys(): - file_metadata["teamDriveId"] = parent["teamDriveId"] - return file_metadata - - def create_folder( - self, - folder_name: str, - parent: Optional[DriveFileJson] = None, - exists_ok: bool = True, - ) -> bool: - """ - Creates a new folder under parent on GDrive. - - Args: - folder_name (str): Name of the folder to - parent (Optional[DriveFileJson], optional): The created parent will be a - subfolder of parent. Defaults to None. - exists_ok (bool, optional): Whether to ignore existing files. - If False, existing folder may be overwritten. 
Defaults to True. - - Returns: - bool: True iff folder creation succeeded. - """ - # Write metadata for creating a folder - file_metadata = {"name": folder_name, "mimeType": self.GDRIVE_FOLDER} - # Check if folder already exists - if exists_ok and self.exists(file_metadata): - print("Folder exists already") - # Else, create - else: - if parent is not None: - self._add_parent_to_metadata(file_metadata, parent) - # Execute folder creation request - self.service.files().create(body=file_metadata, fields="id", supportsAllDrives=True).execute() - return True - - def upload_file( - self, - file_to_upload: pathlib.Path, - parent: Optional[DriveFileJson] = None, - exists_ok: bool = True, - chunksize: int = 5 * 1024 * 1024, - ) -> bool: - """ - Uploads the file `file_to_upload` to GDrive. - - Args: - file_to_upload (pathlib.Path): Path to the file to upload - parent (Optional[DriveFileJson], optional): GDrive folder under which to - store the uploaded file. Defaults to None. - exists_ok (bool, optional): If exists_ok, existing files in with same parent - and name will not be overwritten. Defaults to True. - chunksize (int, optional): The chunksize to use for uploads (default: 5MB). - Defaults to 5*1024*1024. - - Returns: - bool: True, iff upload was successful. - """ - # Define metadata and media to upload - assert file_to_upload.is_file() - file_metadata = {"name": file_to_upload.name} - if parent is not None: - self._add_parent_to_metadata(file_metadata, parent) - # Check if file already exists - if self.exists(file_metadata) and exists_ok: - print("File exists already.") - return True - # If not, upload - else: - media = MediaFileUpload(file_to_upload, chunksize=chunksize, resumable=True) - # Set up http request to upload file - file = self.service.files().create( - body=file_metadata, - media_body=media, - fields="id", - supportsAllDrives=True, - ) - # Upload file in chunks of the given size - response = None - progress_bar = tqdm(desc=file_metadata["name"], total=100) - while response is None: - status, response = file.next_chunk() - if status: - progress_bar.update(status.progress()) - logger.debug("Upload of %s complete!", file_to_upload.name) - return True - - def upload_folder( - self, - folder_to_upload: pathlib.Path, - parent: Optional[DriveFileJson] = None, - chunksize: int = 5 * 1024 * 1024, - ) -> bool: - """ - Uploads a folder incl. all subfolders and files to GDrive. - - Note: The folder structure is replicated on GDrive. - - Args: - folder_to_upload (pathlib.Path): Path to the folder to upload. - parent (Optional[DriveFileJson], optional): Parent folder under which to put - the uploaded folder. The uploaded folder will be a subfolder of parent. - Defaults to None. - chunksize (int, optional): The chunksize to use for uploads (default: 5 MB). - Defaults to 5*1024*1024. 
- - Raises: - RuntimeError: If a file is neither a directory nor a file - - Returns: - bool: True, iff the download succeeds - """ - - # Load folder into queue - assert folder_to_upload.is_dir() - queue = deque([folder_to_upload]) - - # Perform breadth first search traversal until queue is empty - while len(queue) > 0: - current_element = queue.popleft() - logger.debug("Visiting %s", (current_element)) - - # Check if current node is a file - if current_element.is_file(): - parent_folder_name = current_element.absolute().parent.name - logger.debug("Parent folder: %s", {parent_folder_name}) - parent = self.get_folder(folder_name=parent_folder_name) - self.upload_file(current_element, parent=parent, exists_ok=True, chunksize=chunksize) - # Else it is current node is a folder and we have to look at all children - elif current_element.is_dir(): - # Note: In first iteration, current_element will be folder_to_upload - # and thus parent is the parent given in the function signature in - # that case. - if current_element != folder_to_upload: - parent_folder_name = current_element.absolute().parent.name - parent = self.get_folder(folder_name=parent_folder_name) - self.create_folder(current_element.name, parent=parent, exists_ok=True) - folder_contents = list(current_element.glob("*")) - queue.extendleft(folder_contents) - else: - raise RuntimeError - - return True diff --git a/detectree2/data_loading/gdrivePull.sh b/detectree2/data_loading/gdrivePull.sh deleted file mode 100644 index b16508df..00000000 --- a/detectree2/data_loading/gdrivePull.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash -#SBATCH --partition=high-mem -#SBATCH -o %j.out -#SBATCH -e %j.err -#SBATCH --time=48:00:00 -#SBATCH --mem=128GB - - -conda activate gediee -#pip install -e /gws/nopw/j04/forecol/jgcb3/gedi/gedi_ee/ - -DRIVE_FOLDER=Manuscripts/forecasting -DOWNLOAD_FOLDER=/home/users/patball/forecol/jgcb3/forecasting - -python -u /gws/nopw/j04/forecol/jgcb3/gedi/gedi_ee/src/data/gee_download.py $DRIVE_FOLDER $DOWNLOAD_FOLDER > dwnld_gedi_${SLURM_JOB_ID}.txt \ No newline at end of file diff --git a/detectree2/data_loading/gdrive_load.py b/detectree2/data_loading/gdrive_load.py deleted file mode 100644 index e69de29b..00000000 diff --git a/detectree2/data_loading/gee_download.py b/detectree2/data_loading/gee_download.py deleted file mode 100644 index e6e69557..00000000 --- a/detectree2/data_loading/gee_download.py +++ /dev/null @@ -1,158 +0,0 @@ -"""Download script for GEE .tif files from Google Drive""" - -import argparse -import logging -import pathlib - -from src.constants import DATA_PATH -from src.data.gdrive import DriveAPI -from tqdm.autonotebook import tqdm - -DEFAULT_SAVE_PATH = DATA_PATH / "gdrive" - - -def download_files( - folder_id: str, - save_path: str, - recursive: bool = True, - overwrite: bool = False, - logger: logging.Logger = logging.getLogger(), -) -> None: - """ - Download all .tif files in a given directory. Can work recursively. - Args: - folder_id (str): gdrive id of the folder to download - save_path (pathlib.Path): path to save the downloaded data at - recursive (bool, optional): If True, downloads all subfolders recursively. - Defaults to True. - overwrite (bool): Whether to overwrite existing files. Defaults to False. - logger (logging.Logger): Optional argument to pass a logger to the download - script for debugging. Defaults to logging.getLogger(). 
- """ - - elements = gdrive.list_files_in_folder(folder_id) - save_path.mkdir(exist_ok=True) - - progress_bar = tqdm( - elements, - position=0, - leave=True, - ) - for element in progress_bar: - file_id = element["id"] - file_name = element["name"] - progress_bar.set_description(f"Working on {file_name}") - - # If element is a directory, recursively clone: - if recursive and gdrive.is_folder(file_id): - subdir_path = save_path / file_name - logger.debug("Creating path at %s", subdir_path) - - # Create subfolder and start recursive download - subdir_path.mkdir(mode=0o777, exist_ok=True) - download_files(file_id, subdir_path) - - # Download tifs - elif gdrive.is_tif(file_id): - file_path = save_path / file_name - - if file_path.exists() and not overwrite: - logger.info("File %s already exists. Not overwriting.", file_path) - else: - logger.debug("Saving file at %s", file_path) - # Save file - gdrive.file_download( - file_id, - save_path=file_path, - chunksize=800 * 1024 * 1024, # 800 MB - verbose=False, - ) - file_path.chmod(0o664) - # Download tifs - elif gdrive.is_kml(file_id): - file_path = save_path / file_name - - if file_path.exists() and not overwrite: - logger.info("File %s already exists. Not overwriting.", file_path) - else: - logger.debug("Saving file at %s", file_path) - # Save file - gdrive.file_download( - file_id, - save_path=file_path, - chunksize=800 * 1024 * 1024, # 800 MB - verbose=False, - ) - file_path.chmod(0o664) - - # Print warnings for other files - else: - # logger.warning("Unknown file of type %s", gdrive.get_mime_type(file_id)) - logger.warning("Unknown file of type") - file_path = save_path / file_name - - if file_path.exists() and not overwrite: - logger.info("File %s already exists. Not overwriting.", file_path) - else: - logger.debug("Saving file at %s", file_path) - # Save file - gdrive.file_download( - file_id, - save_path=file_path, - chunksize=800 * 1024 * 1024, # 800 MB - verbose=False, - ) - file_path.chmod(0o664) - - -if __name__ == "__main__": - - # Parse command line arguments - parser = argparse.ArgumentParser(description="Google Drive TIF download script") - parser.add_argument( - "gdrive_folder_name", - help="The name of the gdrive folder from which to download the TIF files from.", - type=str, - ) - parser.add_argument( - "save_path", - help="The folder in which to save the downloaded files.", - type=str, - default=DEFAULT_SAVE_PATH, - nargs="?", # Argument is optional - ) - parser.add_argument( - "-r", - "--recursive", - help="If True, copy content of gdrive folder recursively. Defaults to True.", - type=bool, - default=True, - nargs="?", # Argument is optional - ) - parser.add_argument( - "-o", - "--overwrite", - help=("If True, existing files are downloaded again and overwritten. " - "Defaults to False."), - type=bool, - default=False, - nargs="?", # Argument is optional - ) - - args = parser.parse_args() - - # Connect to google drive and set folders to download and save path - gdrive = DriveAPI() - print("Signing in to Google") - print(f"Signed in as {gdrive.username} ({gdrive.user_email})") - gdrive_folder_id = gdrive.get_folder_id(args.gdrive_folder_name) - local_save_path = pathlib.Path(args.save_path) - - # Start download - print(f"Starting download of {args.gdrive_folder_name}. 
Saving to {local_save_path}") - download_files( - folder_id=gdrive_folder_id, - save_path=local_save_path, - recursive=args.recursive, - overwrite=args.overwrite, - ) diff --git a/detectree2/data_loading/quickstart.py b/detectree2/data_loading/quickstart.py deleted file mode 100644 index 9ecb63ce..00000000 --- a/detectree2/data_loading/quickstart.py +++ /dev/null @@ -1,50 +0,0 @@ -from __future__ import print_function - -import os.path - -from google.auth.transport.requests import Request -from google.oauth2.credentials import Credentials -from google_auth_oauthlib.flow import InstalledAppFlow -from googleapiclient.discovery import build - -# If modifying these scopes, delete the file token.json. -SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"] - - -def main(): - """Shows basic usage of the Drive v3 API. - Prints the names and ids of the first 10 files the user has access to. - """ - creds = None - # The file token.json stores the user's access and refresh tokens, and is - # created automatically when the authorization flow completes for the first - # time. - if os.path.exists("token.json"): - creds = Credentials.from_authorized_user_file("token.json", SCOPES) - # If there are no (valid) credentials available, let the user log in. - if not creds or not creds.valid: - if creds and creds.expired and creds.refresh_token: - creds.refresh(Request()) - else: - flow = InstalledAppFlow.from_client_secrets_file("credentials.json", SCOPES) - creds = flow.run_local_server(port=0) - # Save the credentials for the next run - with open("token.json", "w") as token: - token.write(creds.to_json()) - - service = build("drive", "v3", credentials=creds) - - # Call the Drive v3 API - results = service.files().list(pageSize=10, fields="nextPageToken, files(id, name)").execute() - items = results.get("files", []) - - if not items: - print("No files found.") - else: - print("Files:") - for item in items: - print("{0} ({1})".format(item["name"], item["id"])) - - -if __name__ == "__main__": - main() diff --git a/detectree2/models/train.py b/detectree2/models/train.py index c39c9bf8..c5dc6535 100644 --- a/detectree2/models/train.py +++ b/detectree2/models/train.py @@ -120,6 +120,14 @@ def __call__(self, dataset_dict): # Transpose image dimensions to match expected format (H, W, C) img = np.transpose(img, (1, 2, 0)).astype("float32") + # Basic band-count guard for multispectral inputs + expected_bands = int(getattr(self.cfg.INPUT, "NUM_IN_CHANNELS", img.shape[2])) + if img.shape[2] != expected_bands: + self.logger.warning( + f"Loaded image has {img.shape[2]} bands but cfg expects {expected_bands}. " + "This may cause a model input mismatch." + ) + # Size check similar to utils.check_image_size if img.shape[:2] != (dataset_dict.get("height"), dataset_dict.get("width")): self.logger.warning( @@ -492,6 +500,21 @@ def resume_or_load(self, resume=True): # at the next iteration self.start_iter = self.iter + 1 + # Early guard for MS: expected channels vs model conv1 + try: + desired_channels = int(getattr(self.cfg.INPUT, "NUM_IN_CHANNELS", 3)) + model_in_channels = int(self.model.backbone.bottom_up.stem.conv1.weight.shape[1]) + except Exception: + desired_channels = 3 + model_in_channels = 3 + + if self.cfg.IMGMODE == "ms" and desired_channels != model_in_channels: + raise RuntimeError( + f"Input channel mismatch: cfg expects {desired_channels} bands for multispectral input, " + f"but model conv1 expects {model_in_channels}. 
Please adapt the backbone's first conv to " + f"{desired_channels} channels or use 3-band inputs." + ) + if self.cfg.MODEL.WEIGHTS: device = self.model.backbone.bottom_up.stem.conv1.weight.device req_grad = self.model.backbone.bottom_up.stem.conv1.weight.requires_grad @@ -528,7 +551,7 @@ def resume_or_load(self, resume=True): ) with torch.no_grad(): self.model.backbone.bottom_up.stem.conv1.weight[:, :3] = checkpoint[:, :3] - multiply_conv1_weights(self.model) + # Do not silently expand conv1 here; rely on explicit model adaptation for MS self.model.backbone.bottom_up.stem.conv1.weight.to(device) self.model.backbone.bottom_up.stem.conv1.weight.requires_grad = req_grad @@ -1042,13 +1065,15 @@ def setup_cfg( cfg.RESIZE = resize cfg.INPUT.MIN_SIZE_TRAIN = 1000 cfg.IMGMODE = imgmode # "rgb" or "ms" (multispectral) + # Track intended input channels for early validation + cfg.INPUT.NUM_IN_CHANNELS = num_bands if num_bands > 3: # Adjust PIXEL_MEAN and PIXEL_STD for the number of bands default_pixel_mean = cfg.MODEL.PIXEL_MEAN default_pixel_std = cfg.MODEL.PIXEL_STD # Extend or truncate the PIXEL_MEAN and PIXEL_STD based on num_bands cfg.MODEL.PIXEL_MEAN = (default_pixel_mean * (num_bands // len(default_pixel_mean)) + - default_pixel_mean[:num_bands % len(default_pixel_mean)]) + default_pixel_mean[:num_bands % len(default_pixel_mean)]) cfg.MODEL.PIXEL_STD = (default_pixel_std * (num_bands // len(default_pixel_std)) + default_pixel_std[:num_bands % len(default_pixel_std)]) if visualize_training: diff --git a/setup.py b/setup.py index eace5294..8d797cb7 100644 --- a/setup.py +++ b/setup.py @@ -7,24 +7,28 @@ author_email="ball.jgc@gmail.com", description="Detectree packaging", url="https://github.com/PatBall1/detectree2", - # package_dir={"": "detectree2"}, packages=find_packages(), test_suite="detectree2.tests.test_all.suite", + python_requires=">=3.8", install_requires=[ + # Core "pyyaml>=5.1", - "GDAL>=1.11", - "numpy", - "rtree", - "proj", - "geos", - "pypng", - "pygeos", - "shapely", - "geopandas", - "rasterio==1.3a3", - "fiona==1.9.6", - "pycrs", - "descartes", - "detectron2@git+https://github.com/facebookresearch/detectron2.git", + "numpy>=1.20", + "pandas>=1.3", + "tqdm>=4.60", + "opencv-python>=4.5", + # Geospatial — prefer conda for GDAL stack; avoid alpha pins + "shapely>=1.8,<2.0", + "geopandas>=0.12", + "rasterio>=1.2,<1.4", + "fiona>=1.8,<1.10", + "rtree>=0.9", + # Evaluation utils + "pycocotools>=2.0.4", + ], + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", ], ) From 8e9a198de251aa5e7f303dccd8c377ef6d50457a Mon Sep 17 00:00:00 2001 From: James Ball Date: Mon, 1 Sep 2025 17:20:17 +0100 Subject: [PATCH 06/17] sphinx fix --- docs/source/api.rst | 16 ++++++++++++++++ docs/source/conf.py | 43 +++++++++++++++++++++++++++++++++++++++---- docs/source/index.rst | 1 + 3 files changed, 56 insertions(+), 4 deletions(-) create mode 100644 docs/source/api.rst diff --git a/docs/source/api.rst b/docs/source/api.rst new file mode 100644 index 00000000..13624dbb --- /dev/null +++ b/docs/source/api.rst @@ -0,0 +1,16 @@ +API Reference +============= + +This section provides the API reference for the main modules in ``detectree2``. + +.. 
autosummary:: + :toctree: api/ + :recursive: + + detectree2.preprocessing.tiling + detectree2.models.train + detectree2.models.predict + detectree2.models.outputs + detectree2.models.evaluation + detectree2.data_loading.custom + diff --git a/docs/source/conf.py b/docs/source/conf.py index 0b20e113..07013847 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -13,7 +13,8 @@ import os import sys -sys.path.insert(0, os.path.abspath("../detectree2")) +# Add repository root to sys.path so autodoc can import the package +sys.path.insert(0, os.path.abspath("../../")) # -- Project information ----------------------------------------------------- @@ -31,9 +32,43 @@ # ones. extensions = [ "sphinx.ext.autodoc", - "sphinx.ext.autosectionlabel", # cannot get this to work - "sphinx.ext.todo", # see contributing guide - "nbsphinx", # enables *.ipynb to be rendered in the docs as pages / notebooks + "sphinx.ext.autosummary", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinx.ext.autosectionlabel", # section refs + "sphinx.ext.todo", # see contributing guide + "nbsphinx", # render notebooks +] + +# Autodoc / autosummary settings +autosummary_generate = True +autodoc_default_options = { + "members": True, + "undoc-members": False, + "inherited-members": False, + "show-inheritance": True, +} + +# Mock heavy optional dependencies to avoid import failures on docs build +autodoc_mock_imports = [ + "torch", + "detectron2", + "detectron2.engine", + "detectron2.config", + "detectron2.data", + "detectron2.layers", + "detectron2.structures", + "detectron2.utils", + "detectron2.evaluation", + "detectron2.checkpoint", + "detectron2.model_zoo", + "cv2", + "rasterio", + "geopandas", + "shapely", + "fiona", + "rtree", + "pycocotools", ] autosectionlabel_prefix_document = True todo_include_todos = True diff --git a/docs/source/index.rst b/docs/source/index.rst index 0cfa9d1b..ce30c26a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -46,6 +46,7 @@ Accurate delineation of individual tree crowns in tropical forests from aerial R cluster contributing using-git + api .. _notebooks/contributing_guide .. _notebooks/trainingJB From f94481cdd12d0185e9810a398e359a8e49c7bcd8 Mon Sep 17 00:00:00 2001 From: James Ball Date: Mon, 1 Sep 2025 17:29:08 +0100 Subject: [PATCH 07/17] shapely version --- requirements/test-requirements.txt | 4 +++- setup.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/requirements/test-requirements.txt b/requirements/test-requirements.txt index 434092fd..aa598df0 100644 --- a/requirements/test-requirements.txt +++ b/requirements/test-requirements.txt @@ -6,8 +6,10 @@ virtualenv -e . 
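For context on the version floor: the evaluation module repairs invalid geometries with make_valid, which is why setup.py below pins shapely>=2.0. A minimal illustration with a hypothetical geometry, assuming shapely 2.x as pinned:

from shapely.geometry import Polygon
from shapely.validation import make_valid

# Self-intersecting "bowtie", as can arise from predicted crown outlines
bowtie = Polygon([(0, 0), (2, 2), (2, 0), (0, 2)])
assert not bowtie.is_valid
repaired = make_valid(bowtie)  # valid (Multi)Polygon covering the same area
assert repaired.is_valid
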
# external requirements +pytest +pytest-dependency click # BSD-3-Clause sphinx # BSD-3-Clause coverage # MIT flake8 # MIT -python-dotenv # BSD-3-Clause \ No newline at end of file +python-dotenv # BSD-3-Clause diff --git a/setup.py b/setup.py index 8d797cb7..07cbf34f 100644 --- a/setup.py +++ b/setup.py @@ -17,9 +17,9 @@ "pandas>=1.3", "tqdm>=4.60", "opencv-python>=4.5", - # Geospatial — prefer conda for GDAL stack; avoid alpha pins - "shapely>=1.8,<2.0", - "geopandas>=0.12", + # Geospatial — shapely 2.x required by evaluation module (make_valid) + "shapely>=2.0", + "geopandas>=0.13", "rasterio>=1.2,<1.4", "fiona>=1.8,<1.10", "rtree>=0.9", From 03c95538ec3cf0950fa289fd54cc6db832fa449b Mon Sep 17 00:00:00 2001 From: James Ball Date: Mon, 1 Sep 2025 17:41:22 +0100 Subject: [PATCH 08/17] mypy fixes --- detectree2/models/predict.py | 2 +- detectree2/models/train.py | 9 ++-- detectree2/preprocessing/tiling.py | 82 ++++++++++++++++-------------- 3 files changed, 50 insertions(+), 43 deletions(-) diff --git a/detectree2/models/predict.py b/detectree2/models/predict.py index 18bd01ad..8194bc2c 100644 --- a/detectree2/models/predict.py +++ b/detectree2/models/predict.py @@ -75,7 +75,7 @@ def predict_on_data( with rasterio.open(file_name) as src: img = src.read() # Transpose to match expected format (H, W, C) - img = np.transpose(img, (1, 2, 0)) + img = img.transpose(1, 2, 0) else: print(f"Unsupported file extension {file_ext} for file {file_name}") continue diff --git a/detectree2/models/train.py b/detectree2/models/train.py index c5dc6535..4ab8add9 100644 --- a/detectree2/models/train.py +++ b/detectree2/models/train.py @@ -748,7 +748,10 @@ def get_tree_dicts(directory: str, class_mapping: Optional[Dict[str, int]] = Non # Make sure we have the correct height and width # If image path ends in .png use cv2 to get height and width else if image path ends in .tif use rasterio if filename.endswith(".png"): - height, width = cv2.imread(filename).shape[:2] + img_arr = cv2.imread(filename) + if img_arr is None: + raise FileNotFoundError(f"Failed to read image at path: {filename}") + height, width = img_arr.shape[:2] elif filename.endswith(".tif"): with rasterio.open(filename) as src: height, width = src.shape @@ -767,7 +770,7 @@ def get_tree_dicts(directory: str, class_mapping: Optional[Dict[str, int]] = Non print("Skipping annotation of type", anno["type"], "in file", filename) continue px = [a[0] for a in anno["coordinates"][0]] - py = [np.array(height) - a[1] for a in anno["coordinates"][0]] + py = [height - a[1] for a in anno["coordinates"][0]] poly = [(x, y) for x, y in zip(px, py)] poly = [p for x in poly for p in x] @@ -778,7 +781,7 @@ def get_tree_dicts(directory: str, class_mapping: Optional[Dict[str, int]] = Non category_id = 0 # Default to "tree" if no class mapping is provided obj = { - "bbox": [np.min(px), np.min(py), np.max(px), np.max(py)], + "bbox": [min(px), min(py), max(px), max(py)], "bbox_mode": BoxMode.XYXY_ABS, "segmentation": [poly], "category_id": category_id, diff --git a/detectree2/preprocessing/tiling.py b/detectree2/preprocessing/tiling.py index 54ee7202..0be30d45 100644 --- a/detectree2/preprocessing/tiling.py +++ b/detectree2/preprocessing/tiling.py @@ -224,8 +224,9 @@ def process_tile(img_path: str, ) return None - # Apply nan mask - out_img[np.broadcast_to((nan_mask == 1)[None, :, :], out_img.shape)] = 0 + # Apply nan mask across all bands (3D mask) without using np.broadcast_to (typing-friendly) + band_mask = np.stack([nan_mask == 1] * out_img.shape[0], axis=0) + 
out_img[band_mask] = 0 dtype, nodata = dtype_map.get(out_img.dtype, (None, None)) if dtype is None: @@ -246,8 +247,9 @@ def process_tile(img_path: str, with rasterio.open(out_tif, "w", **out_meta) as dest: dest.write(out_img) - r, g, b = out_img[0], out_img[1], out_img[2] - rgb = np.dstack((b, g, r)) # Reorder for cv2 (BGRA) + r, g, b = out_img[0], out_img[1], out_img[2] + # Reorder channels to B, G, R for OpenCV + rgb = np.stack((b, g, r), axis=2) # Rescale to 0-255 if necessary if np.nanmax(g) > 255: @@ -255,7 +257,7 @@ def process_tile(img_path: str, else: rgb_rescaled = rgb - np.clip(rgb_rescaled, 0, 255, out=rgb_rescaled) + rgb_rescaled = np.clip(rgb_rescaled, 0, 255) cv2.imwrite(str(out_path_root.with_suffix(".png").resolve()), rgb_rescaled.astype(np.uint8)) @@ -393,8 +395,9 @@ def process_tile_ms(img_path: str, # additional clip to make sure out_img = np.clip(out_img.astype(np.float32), 1.0, 255.0) - # Apply nan mask - out_img[np.broadcast_to((nan_mask == 1)[None, :, :], out_img.shape)] = 0.0 + # Apply nan mask across all bands (3D mask) without using np.broadcast_to + band_mask = np.stack([nan_mask == 1] * out_img.shape[0], axis=0) + out_img[band_mask] = 0.0 dtype, nodata = dtype_map.get(out_img.dtype, (None, None)) if dtype is None: @@ -548,18 +551,18 @@ def _calculate_tile_placements( if tile_placement == "grid": with rasterio.open(img_path) as data: - coordinates = [ - (minx, miny) for minx in np.arange( - math.ceil(data.bounds[0]) + buffer, data.bounds[2] - tile_width - buffer, tile_width, int) - for miny in np.arange( - math.ceil(data.bounds[1]) + buffer, data.bounds[3] - tile_height - buffer, tile_height, int) - ] - if overlapping_tiles: - coordinates.extend([(minx, miny) for minx in np.arange( - math.ceil(data.bounds[0]) + buffer + tile_width // 2, data.bounds[2] - tile_width - buffer - - tile_width // 2, tile_width, int) for miny in np.arange( - math.ceil(data.bounds[1]) + buffer + tile_height // 2, data.bounds[3] - tile_height - buffer - - tile_height // 2, tile_height, int)]) + coordinates = [ + (minx, miny) for minx in np.arange( + math.ceil(data.bounds[0]) + buffer, data.bounds[2] - tile_width - buffer, tile_width, dtype=int) + for miny in np.arange( + math.ceil(data.bounds[1]) + buffer, data.bounds[3] - tile_height - buffer, tile_height, dtype=int) + ] + if overlapping_tiles: + coordinates.extend([(minx, miny) for minx in np.arange( + math.ceil(data.bounds[0]) + buffer + tile_width // 2, data.bounds[2] - tile_width - buffer - + tile_width // 2, tile_width, dtype=int) for miny in np.arange( + math.ceil(data.bounds[1]) + buffer + tile_height // 2, data.bounds[3] - tile_height - buffer - + tile_height // 2, tile_height, dtype=int)]) elif tile_placement == "adaptive": if crowns is None: @@ -725,23 +728,23 @@ def calc_on_everything(): # Compute statistics for each band band_stats = [] - for band_idx in range(1, len(band_aggregates) + 1) if mode == "ms" else range(1, 4): - valid_data = np.array(band_aggregates[band_idx]) - if valid_data.size > 0: - min_val, max_val = np.percentile(valid_data, [1, 99]) - stats = { - "mean": np.mean(valid_data), - "min": min_val, - "max": max_val, - "std_dev": np.std(valid_data), - } - else: - stats = { - "mean": None, - "min": None, - "max": None, - "std_dev": None, - } + for band_idx in range(1, len(band_aggregates) + 1) if mode == "ms" else range(1, 4): + valid_data = np.array(band_aggregates[band_idx]) + if valid_data.size > 0: + min_val, max_val = np.percentile(valid_data, [1, 99]) + stats = { + "mean": np.mean(valid_data), + "min": 
min_val, + "max": max_val, + "std_dev": np.std(valid_data), + } + else: + stats = { + "mean": float("nan"), + "min": float("nan"), + "max": float("nan"), + "std_dev": float("nan"), + } band_stats.append(stats) return band_stats @@ -1012,7 +1015,8 @@ def create_RGB_from_MS(tile_folder_path: Union[str, Path], # Write the PNG (we must convert shape to (H, W, 3) and then to uint8) output_png = out_path / f"{tif_file.stem}.png" - png_ready = np.moveaxis(transformed, 0, -1).astype(np.uint8) # (H, W, 3) + # Move axis from (bands, H, W) -> (H, W, bands) + png_ready = transformed.transpose(1, 2, 0).astype(np.uint8) cv2.imwrite(str(output_png), cv2.cvtColor(png_ready, cv2.COLOR_RGB2BGR)) elif conversion == "first-three": @@ -1051,7 +1055,7 @@ def create_RGB_from_MS(tile_folder_path: Union[str, Path], # Write out the PNG (shape must be (H, W, 3)) output_png = out_path / f"{tif_file.stem}.png" # Move axis from (bands, H, W) -> (H, W, bands) - png_ready = np.moveaxis(data, 0, -1).astype(np.uint8) + png_ready = data.transpose(1, 2, 0).astype(np.uint8) # We expect the order to be [band1, band2, band3], so interpret as R,G,B cv2.imwrite(str(output_png), cv2.cvtColor(png_ready, cv2.COLOR_RGB2BGR)) @@ -1301,7 +1305,7 @@ def to_traintest_folders( # noqa: C901 # random.shuffle(indices) num = list(range(0, len(file_roots))) random.shuffle(num) - ind_split = np.array_split(file_roots, folds) + ind_split = np.array_split(np.array(file_roots), folds) for i in range(0, folds): Path(out_dir / f"train/fold_{i + 1}").mkdir(parents=True, exist_ok=True) From d482f8d34cbe32a14d9101297bb09cfb0bb58bce Mon Sep 17 00:00:00 2001 From: James Ball Date: Mon, 1 Sep 2025 17:49:07 +0100 Subject: [PATCH 09/17] mypy fixes --- detectree2/preprocessing/tiling.py | 39 ++++++++++++++++++------------ 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/detectree2/preprocessing/tiling.py b/detectree2/preprocessing/tiling.py index 0be30d45..115f3a17 100644 --- a/detectree2/preprocessing/tiling.py +++ b/detectree2/preprocessing/tiling.py @@ -248,8 +248,8 @@ def process_tile(img_path: str, dest.write(out_img) r, g, b = out_img[0], out_img[1], out_img[2] - # Reorder channels to B, G, R for OpenCV - rgb = np.stack((b, g, r), axis=2) + # Reorder channels to B, G, R for OpenCV (use list for mypy-friendly typing) + rgb = np.stack([b, g, r], axis=2) # Rescale to 0-255 if necessary if np.nanmax(g) > 255: @@ -257,7 +257,7 @@ def process_tile(img_path: str, else: rgb_rescaled = rgb - rgb_rescaled = np.clip(rgb_rescaled, 0, 255) + rgb_rescaled = np.clip(rgb_rescaled.astype(np.float32), 0.0, 255.0) cv2.imwrite(str(out_path_root.with_suffix(".png").resolve()), rgb_rescaled.astype(np.uint8)) @@ -551,18 +551,25 @@ def _calculate_tile_placements( if tile_placement == "grid": with rasterio.open(img_path) as data: + start_x = int(math.ceil(data.bounds[0]) + buffer) + stop_x = int(data.bounds[2] - tile_width - buffer) + start_y = int(math.ceil(data.bounds[1]) + buffer) + stop_y = int(data.bounds[3] - tile_height - buffer) coordinates = [ - (minx, miny) for minx in np.arange( - math.ceil(data.bounds[0]) + buffer, data.bounds[2] - tile_width - buffer, tile_width, dtype=int) - for miny in np.arange( - math.ceil(data.bounds[1]) + buffer, data.bounds[3] - tile_height - buffer, tile_height, dtype=int) + (minx, miny) + for minx in range(start_x, stop_x, tile_width) + for miny in range(start_y, stop_y, tile_height) ] if overlapping_tiles: - coordinates.extend([(minx, miny) for minx in np.arange( - math.ceil(data.bounds[0]) + buffer + tile_width // 2, 
data.bounds[2] - tile_width - buffer - - tile_width // 2, tile_width, dtype=int) for miny in np.arange( - math.ceil(data.bounds[1]) + buffer + tile_height // 2, data.bounds[3] - tile_height - buffer - - tile_height // 2, tile_height, dtype=int)]) + start_x2 = int(math.ceil(data.bounds[0]) + buffer + tile_width // 2) + stop_x2 = int(data.bounds[2] - tile_width - buffer - tile_width // 2) + start_y2 = int(math.ceil(data.bounds[1]) + buffer + tile_height // 2) + stop_y2 = int(data.bounds[3] - tile_height - buffer - tile_height // 2) + coordinates.extend([ + (minx, miny) + for minx in range(start_x2, stop_x2, tile_width) + for miny in range(start_y2, stop_y2, tile_height) + ]) elif tile_placement == "adaptive": if crowns is None: @@ -733,10 +740,10 @@ def calc_on_everything(): if valid_data.size > 0: min_val, max_val = np.percentile(valid_data, [1, 99]) stats = { - "mean": np.mean(valid_data), - "min": min_val, - "max": max_val, - "std_dev": np.std(valid_data), + "mean": float(np.mean(valid_data)), + "min": float(min_val), + "max": float(max_val), + "std_dev": float(np.std(valid_data)), } else: stats = { From 575cdaef730ef2cd5e4a394671d498f3ed73e54f Mon Sep 17 00:00:00 2001 From: James Ball Date: Mon, 1 Sep 2025 17:55:46 +0100 Subject: [PATCH 10/17] mypy fixes --- detectree2/preprocessing/tiling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/detectree2/preprocessing/tiling.py b/detectree2/preprocessing/tiling.py index 115f3a17..41edcbae 100644 --- a/detectree2/preprocessing/tiling.py +++ b/detectree2/preprocessing/tiling.py @@ -257,7 +257,9 @@ def process_tile(img_path: str, else: rgb_rescaled = rgb - rgb_rescaled = np.clip(rgb_rescaled.astype(np.float32), 0.0, 255.0) + rgb_rescaled = np.clip( + rgb_rescaled.astype(np.float32), np.float32(0.0), np.float32(255.0) + ) cv2.imwrite(str(out_path_root.with_suffix(".png").resolve()), rgb_rescaled.astype(np.uint8)) From 3e011e5fe37dc294440d5f5b1d7a6c9eb8bf0b10 Mon Sep 17 00:00:00 2001 From: James Ball Date: Mon, 1 Sep 2025 21:12:17 +0100 Subject: [PATCH 11/17] flake8 --- detectree2/models/outputs.py | 28 +++++++++----- detectree2/models/predict.py | 1 - detectree2/preprocessing/tiling.py | 62 +++++++++++++++++++----------- 3 files changed, 57 insertions(+), 34 deletions(-) diff --git a/detectree2/models/outputs.py b/detectree2/models/outputs.py index 017fdb24..d81d6e35 100644 --- a/detectree2/models/outputs.py +++ b/detectree2/models/outputs.py @@ -19,7 +19,7 @@ import rasterio from rasterio.crs import CRS from shapely.affinity import scale -from shapely.geometry import Polygon, box, shape +from shapely.geometry import Polygon, box from shapely.ops import orient from tqdm import tqdm @@ -344,12 +344,14 @@ def calc_iou(shape1, shape2): return iou -def clean_crowns(crowns, - iou_threshold= 0.7, - confidence= 0.2, - area_threshold = 2, - field= "Confidence_score", - verbose= True) -> gpd.GeoDataFrame: +def clean_crowns( + crowns, + iou_threshold=0.7, + confidence=0.2, + area_threshold=2, + field="Confidence_score", + verbose=True, +) -> gpd.GeoDataFrame: """ Clean overlapping crowns by first identifying all candidate overlapping pairs via a spatial join, then clustering crowns into connected components (where an edge is added if two crowns have IoU @@ -400,9 +402,15 @@ def union(x, y): parent[ry] = rx # 4. For each candidate pair, compute IoU and, if it exceeds the threshold, merge the groups. 
- for idx, row in tqdm(join.iterrows(), total=len(join), desc="clean_crowns: Processing candidate pairs", smoothing=0, disable=not verbose): - i = row.name # index from left table (crowns) - j = row["index_right"] # index from right table (crowns) + for idx, row in tqdm( + join.iterrows(), + total=len(join), + desc="clean_crowns: Processing candidate pairs", + smoothing=0, + disable=not verbose, + ): + i = row.name # index from left table (crowns) + j = row["index_right"] # index from right table (crowns) # To avoid duplicate work, skip if i and j are already in the same group. if find(i) == find(j): continue diff --git a/detectree2/models/predict.py b/detectree2/models/predict.py index 8194bc2c..2e389995 100644 --- a/detectree2/models/predict.py +++ b/detectree2/models/predict.py @@ -7,7 +7,6 @@ from pathlib import Path import cv2 -import numpy as np import rasterio from detectron2.engine import DefaultPredictor from detectron2.evaluation.coco_evaluation import instances_to_coco_json diff --git a/detectree2/preprocessing/tiling.py b/detectree2/preprocessing/tiling.py index 41edcbae..7994b332 100644 --- a/detectree2/preprocessing/tiling.py +++ b/detectree2/preprocessing/tiling.py @@ -183,9 +183,9 @@ def process_tile(img_path: str, out_img, out_transform = mask(data, shapes=coords, nodata=nodata, crop=True, indexes=[1, 2, 3]) mask_tif = None - if mask_gdf is not None: - #if mask_gdf.crs != data.crs: - # mask_gdf = mask_gdf.to_crs(data.crs) #TODO is this necessary? + if mask_gdf is not None: + # if mask_gdf.crs != data.crs: + # mask_gdf = mask_gdf.to_crs(data.crs) # TODO is this necessary? mask_tif = rasterio.features.geometry_mask([geom for geom in mask_gdf.geometry], transform=out_transform, @@ -218,10 +218,17 @@ def process_tile(img_path: str, # If the tile is mostly empty or mostly nan, don't save it invalid = (zero_mask | nan_mask).sum() - if invalid > nan_threshold * totalpix: - logger.warning( - f"Skipping tile at ({minx}, {miny}) due to being over nodata threshold. Threshold: {nan_threshold}, nodata ration: {invalid / totalpix}" - ) + if invalid > nan_threshold * totalpix: + logger.warning( + "Skipping tile at (%s, %s) due to being over nodata threshold.", + minx, + miny, + ) + logger.warning( + "Threshold: %s, nodata ratio: %s", + nan_threshold, + invalid / totalpix, + ) return None # Apply nan mask across all bands (3D mask) without using np.broadcast_to (typing-friendly) @@ -343,9 +350,9 @@ def process_tile_ms(img_path: str, out_img, out_transform = mask(data, shapes=coords, nodata=nodata, crop=True, indexes=bands_to_read) mask_tif = None - if mask_gdf is not None: - #if mask_gdf.crs != data.crs: - # mask_gdf = mask_gdf.to_crs(data.crs) #TODO is this necessary? + if mask_gdf is not None: + # if mask_gdf.crs != data.crs: + # mask_gdf = mask_gdf.to_crs(data.crs) # TODO is this necessary? mask_tif = rasterio.features.geometry_mask([geom for geom in mask_gdf.geometry], transform=out_transform, @@ -377,16 +384,23 @@ def process_tile_ms(img_path: str, # If the tile is mostly empty or mostly nan, don't save it invalid = (zero_mask | nan_mask).sum() - if invalid > nan_threshold * totalpix: - logger.warning( - f"Skipping tile at ({minx}, {miny}) due to being over nodata threshold. 
Threshold: {nan_threshold}, nodata ration: {invalid / totalpix}" - ) + if invalid > nan_threshold * totalpix: + logger.warning( + "Skipping tile at (%s, %s) due to being over nodata threshold.", + minx, + miny, + ) + logger.warning( + "Threshold: %s, nodata ratio: %s", + nan_threshold, + invalid / totalpix, + ) return None - # rescale image to 1-255 (0 is reserved for nodata) - assert image_statistics is not None, "image_statistics must be provided for multispectral data" - min_vals = np.array([stats['min'] for stats in image_statistics]).reshape(-1, 1, 1) - max_vals = np.array([stats['max'] for stats in image_statistics]).reshape(-1, 1, 1) + # rescale image to 1-255 (0 is reserved for nodata) + assert image_statistics is not None, "image_statistics must be provided for multispectral data" + min_vals = np.array([stats['min'] for stats in image_statistics]).reshape(-1, 1, 1) + max_vals = np.array([stats['max'] for stats in image_statistics]).reshape(-1, 1, 1) # making it a bit safer for small numbers if max_vals.min() > 1: @@ -394,12 +408,14 @@ def process_tile_ms(img_path: str, else: out_img = (out_img - min_vals) * 254 / (max_vals - min_vals) + 1 - # additional clip to make sure - out_img = np.clip(out_img.astype(np.float32), 1.0, 255.0) + # additional clip to make sure (use float32 bounds for mypy compatibility) + out_img = np.clip( + out_img.astype(np.float32), np.float32(1.0), np.float32(255.0) + ) # Apply nan mask across all bands (3D mask) without using np.broadcast_to band_mask = np.stack([nan_mask == 1] * out_img.shape[0], axis=0) - out_img[band_mask] = 0.0 + out_img[band_mask] = np.float32(0.0) dtype, nodata = dtype_map.get(out_img.dtype, (None, None)) if dtype is None: @@ -585,7 +601,7 @@ def _calculate_tile_placements( unioned_crowns = crowns.union_all() else: unioned_crowns = crowns.unary_union - logger.info(f"Finished Union of Crowns") + logger.info("Finished Union of Crowns") area_width = crowns.total_bounds[2] - crowns.total_bounds[0] area_height = crowns.total_bounds[3] - crowns.total_bounds[1] @@ -621,7 +637,7 @@ def _calculate_tile_placements( coordinates.append( (int(intersection.total_bounds[0] - x_intersection_offset) + col * tile_width + tile_width // 2, int(crowns.total_bounds[1] - y_offset) + row * tile_height + tile_height // 2)) - logger.info(f"Finished Tile Placement Generation") + logger.info("Finished Tile Placement Generation") else: raise ValueError('Unsupported tile_placement method. 
Must be "grid" or "adaptive"') From 4eb401ffb1848c59ce194cde4ac9e90c9df79b53 Mon Sep 17 00:00:00 2001 From: James Ball Date: Mon, 1 Sep 2025 21:39:35 +0100 Subject: [PATCH 12/17] sphinx requirements --- docs/requirements.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/requirements.txt b/docs/requirements.txt index 28c3b7cb..baaee949 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1 +1,5 @@ jinja2==3.0.3 +# Core Sphinx and extensions used by the docs build +sphinx +sphinx_rtd_theme +nbsphinx From 5724652920f5c9055eeb9e13edac7ad2308b491c Mon Sep 17 00:00:00 2001 From: James Ball Date: Mon, 1 Sep 2025 22:05:32 +0100 Subject: [PATCH 13/17] updated tutorial --- docs/source/tutorial.rst | 43 +++++++++++++++++++++++++++++----- docs/source/tutorial_multi.rst | 6 ++--- 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index 3699d950..52eadf92 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -22,8 +22,8 @@ Before getting started ensure ``detectree2`` is installed through (.venv) $pip install git+https://github.com/PatBall1/detectree2.git -To train a model you will need an orthomosaic (as ``.tif``) and -corresponding tree crown polgons that are readable by Geopandas +To train a model you will need an orthomosaic (as ``.tif``) and +corresponding tree crown polygons that are readable by Geopandas (e.g. ``.gpkg``, ``.shp``). For the best results, manual crowns should be supplied as dense clusters rather than sparsely scattered across in the landscape. The method is designed to make @@ -141,7 +141,7 @@ type systems or urban environments). tile_data(img_path, out_dir, buffer, tile_width, tile_height, crowns, threshold, mode="rgb") .. warning:: - If tiles are outputing as blank images set ``dtype_bool = True`` in the ``tile_data`` function. This is a bug + If tiles are outputting as blank images set ``dtype_bool = True`` in the ``tile_data`` function. This is a bug and we are working on fixing it. Supplying crown polygons will cause the function to tile for training (as opposed to landscape prediction which is described below). @@ -151,6 +151,32 @@ type systems or urban environments). closed canopy forests so some of the default assumptions will reflect that and parameters will need to be adjusted for different systems. +Advanced tiling options +----------------------- + +The ``tile_data`` function exposes a few knobs to better control how tiles are created, especially helpful for large +rasters and multispectral data: + +- ``tile_placement``: choose how tile origins are generated. + - ``"grid"`` (default): lays tiles on a fixed grid across the image bounds. Fast and predictable. + - ``"adaptive"``: concentrates tiles where crowns exist by scanning rows that intersect the union of crowns. Requires + supplying ``crowns``; if ``crowns`` is ``None``, it falls back to ``"grid"`` with a warning. +- ``overlapping_tiles``: when ``True``, adds a second set of tiles shifted by half a tile in X and Y (checkerboard + offset). Useful to reduce edge artifacts in predictions, or to capture crowns straddling tile boundaries. It + increases the number of tiles roughly 2x. +- ``ignore_bands_indices``: zero-based indices of bands to skip (multispectral only). These bands are ignored both when + computing image statistics and when writing the output tiles. For example, to exclude band 0 and band 4 in a 5-band + raster, pass ``ignore_bands_indices=[0, 4]``. 
+ +Practical tips: + +- For training with ``crowns``, ``tile_placement="adaptive"`` can reduce I/O by avoiding empty regions while keeping + good coverage. For full-image prediction, stick with ``"grid"``. +- When running prediction, consider ``overlapping_tiles=True`` to reduce seam artifacts; you can later post-process + overlaps (e.g., discard detections near tile borders). +- ``ignore_bands_indices`` is zero-based; Rasterio band numbering is one-based internally, but the function accounts for + this. RGB mode ignores this parameter. + Send geojsons to train folder (with sub-folders for k-fold cross validation) and a test folder. .. code-block:: python @@ -159,7 +185,7 @@ Send geojsons to train folder (with sub-folders for k-fold cross validation) and to_traintest_folders(data_folder, out_dir, test_frac=0.15, strict=False, folds=5) .. note:: - If ``strict=True``, the ``to_traintest_folders`` function will automatically removes training/validation geojsons + If ``strict=True``, the ``to_traintest_folders`` function will automatically remove training/validation geojsons that have any overlap with test tiles (including the buffers), ensuring strict spatial separation of the test data. However, this can remove a significant proportion of the data available to train on so if validation accuracy is a sufficient test of model performance ``test_frac`` can be set to ``0`` or set ``strict=False`` (which allows for @@ -291,7 +317,7 @@ steps to expose the model to the full range of available training data. Register register_train_data(train_location, "Paracou", val_fold=5) The data will be registered as ``_train`` and ``_val`` (or ``Paracou_train`` and ``Paracou_val`` in the -above example). It will be necessary to supply these registation names below... +above example). It will be necessary to supply these registration names below... We must supply a ``base_model`` from Detectron2's ``model_zoo``. This loads a backbone that has been pre-trained which saves us the pain of training a model from scratch. We are effectively transferring this model and (re)training it on @@ -311,6 +337,11 @@ datasets should be tuples containing strings. If just a single site is being use cfg = setup_cfg(base_model, trains, tests, workers = 4, eval_period=100, max_iter=3000, out_dir=out_dir) # update_model arg can be used to load in trained model +.. note:: + ``tile_data`` also supports ``tile_placement`` ("grid" or "adaptive") and options such as + ``overlapping_tiles`` and ``ignore_bands_indices``. The defaults match prior behavior, so existing + examples continue to work, but you can use these parameters to better control tiling when needed. + Alternatively, it is possible to train from one of ``detectree2``'s pre-trained models. This is normally recommended and especially useful if you only have limited training data available. To retrieve the model from the repo's @@ -691,7 +722,7 @@ can discard partial the crowns predicted at the edge of tiles. tile_data(img_path, tiles_path, buffer, tile_width, tile_height, dtype_bool = True) .. warning:: - If tiles are outputing as blank images set ``dtype_bool = True`` in the ``tile_data`` function. This is a bug + If tiles are outputting as blank images set ``dtype_bool = True`` in the ``tile_data`` function. This is a bug and we are working on fixing it. Avoid supplying crown polygons otherwise the function will run as if it is tiling for training. 
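+
+For landscape prediction the same advanced options can be combined with the call above. A hedged
+sketch (reusing the variables already defined in this section; argument values are illustrative):
+
+.. code-block:: python
+
+   tile_data(
+       img_path,
+       tiles_path,
+       buffer,
+       tile_width,
+       tile_height,
+       dtype_bool=True,
+       overlapping_tiles=True,      # half-tile-shifted second pass reduces seam artifacts
+       enhance_rgb_contrast=True,   # rescale RGB bands to 1-255 (0 stays reserved for nodata)
+   )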
diff --git a/docs/source/tutorial_multi.rst b/docs/source/tutorial_multi.rst index 91548842..9bc30d12 100644 --- a/docs/source/tutorial_multi.rst +++ b/docs/source/tutorial_multi.rst @@ -6,8 +6,8 @@ delineation (e.g. species mapping, disease mapping). A guide to single class prediction is available `here `_ - this covers more detail on the fundamentals of training and should be reviewed before this -tutorial. The multiclassprocess is slightly more intricate than single class -prediction as the classes need to be correctly encoded and caried throughout the pipeline. +tutorial. The multi-class process is slightly more intricate than single class +prediction as the classes need to be correctly encoded and carried throughout the pipeline. The key steps are: @@ -173,4 +173,4 @@ the model. By passing the class mapping file to the configuration set up, the Landscape predictions --------------------- -COMING SOON \ No newline at end of file +COMING SOON From 454ff4e89168b58c93c00cf6c2bd5d9a01c0949d Mon Sep 17 00:00:00 2001 From: Christopher Kotthoff Date: Wed, 16 Jul 2025 18:10:16 +0100 Subject: [PATCH 14/17] rgb contrast and unsafe access fix --- detectree2/preprocessing/tiling.py | 229 +++++++++++++++-------------- 1 file changed, 120 insertions(+), 109 deletions(-) diff --git a/detectree2/preprocessing/tiling.py b/detectree2/preprocessing/tiling.py index 7994b332..0f059d80 100644 --- a/detectree2/preprocessing/tiling.py +++ b/detectree2/preprocessing/tiling.py @@ -134,7 +134,8 @@ def process_tile(img_path: str, additional_nodata: List[Any] = [], image_statistics: List[Dict[str, float]] = None, ignore_bands_indices: List[int] = [], - use_convex_mask: bool = True): + use_convex_mask: bool = True, + enhance_rgb_contrast: bool = True): """Process a single tile for making predictions. Args: @@ -183,9 +184,9 @@ def process_tile(img_path: str, out_img, out_transform = mask(data, shapes=coords, nodata=nodata, crop=True, indexes=[1, 2, 3]) mask_tif = None - if mask_gdf is not None: - # if mask_gdf.crs != data.crs: - # mask_gdf = mask_gdf.to_crs(data.crs) # TODO is this necessary? + if mask_gdf is not None: + # if mask_gdf.crs != data.crs: + # mask_gdf = mask_gdf.to_crs(data.crs) # TODO is this necessary? 
mask_tif = rasterio.features.geometry_mask([geom for geom in mask_gdf.geometry], transform=out_transform, @@ -199,7 +200,7 @@ def process_tile(img_path: str, unioned_crowns = overlapping_crowns.union_all() else: unioned_crowns = overlapping_crowns.unary_union - convex_mask_tif = rasterio.features.geometry_mask([unioned_crowns.convex_hull.buffer(5)], + convex_mask_tif = rasterio.features.geometry_mask([unioned_crowns.convex_hull.buffer(3)], transform=out_transform, invert=True, out_shape=(out_img.shape[1], out_img.shape[2])) @@ -218,22 +219,31 @@ def process_tile(img_path: str, # If the tile is mostly empty or mostly nan, don't save it invalid = (zero_mask | nan_mask).sum() - if invalid > nan_threshold * totalpix: - logger.warning( - "Skipping tile at (%s, %s) due to being over nodata threshold.", - minx, - miny, - ) - logger.warning( - "Threshold: %s, nodata ratio: %s", - nan_threshold, - invalid / totalpix, - ) + if invalid > nan_threshold * totalpix: + logger.warning( + "Skipping tile at (%s, %s) due to being over nodata threshold.", + minx, + miny, + ) + logger.warning( + "Threshold: %s, nodata ratio: %s", + nan_threshold, + invalid / totalpix, + ) return None - # Apply nan mask across all bands (3D mask) without using np.broadcast_to (typing-friendly) - band_mask = np.stack([nan_mask == 1] * out_img.shape[0], axis=0) - out_img[band_mask] = 0 + if enhance_rgb_contrast: + # rescale image to 1-255 (0 is reserved for nodata) + min_vals, max_vals = np.percentile( + out_img.reshape(3, -1)[:, ~nan_mask.reshape(-1).astype(bool)], [0.2, 99.8]) + + out_img = (out_img - min_vals) / (max_vals - min_vals) * 254 + 1 + + # Apply nan mask + out_img[np.broadcast_to((nan_mask == 1)[None, :, :], out_img.shape)] = 0 + + if enhance_rgb_contrast: + out_img = np.clip(out_img, 0, 255) dtype, nodata = dtype_map.get(out_img.dtype, (None, None)) if dtype is None: @@ -254,26 +264,24 @@ def process_tile(img_path: str, with rasterio.open(out_tif, "w", **out_meta) as dest: dest.write(out_img) - r, g, b = out_img[0], out_img[1], out_img[2] - # Reorder channels to B, G, R for OpenCV (use list for mypy-friendly typing) - rgb = np.stack([b, g, r], axis=2) - - # Rescale to 0-255 if necessary - if np.nanmax(g) > 255: - rgb_rescaled = rgb / 65535 * 255 - else: - rgb_rescaled = rgb + r, g, b = out_img[0], out_img[1], out_img[2] + # Reorder channels to B, G, R for OpenCV (use list for mypy-friendly typing) + rgb = np.stack([b, g, r], axis=2) - rgb_rescaled = np.clip( - rgb_rescaled.astype(np.float32), np.float32(0.0), np.float32(255.0) - ) + if not enhance_rgb_contrast: + # If not enhancing contrast, ensure the dtype is uint8 + if dtype_bool: + rgb = rgb.astype(np.uint8) + else: + rgb = rgb.astype(np.float32) + np.clip(rgb, 0, 255, out=rgb) - cv2.imwrite(str(out_path_root.with_suffix(".png").resolve()), rgb_rescaled.astype(np.uint8)) + cv2.imwrite(str(out_path_root.with_suffix(".png").resolve()), rgb.astype(np.uint8)) if overlapping_crowns is not None: - return data, out_path_root, overlapping_crowns, minx, miny, buffer + return out_transform, out_path_root, overlapping_crowns, minx, miny, buffer - return data, out_path_root, None, minx, miny, buffer + return out_transform, out_path_root, None, minx, miny, buffer except RasterioIOError as e: logger.error(f"RasterioIOError while applying mask {coords}: {e}") @@ -350,9 +358,9 @@ def process_tile_ms(img_path: str, out_img, out_transform = mask(data, shapes=coords, nodata=nodata, crop=True, indexes=bands_to_read) mask_tif = None - if mask_gdf is not None: - # if mask_gdf.crs != 
data.crs: - # mask_gdf = mask_gdf.to_crs(data.crs) # TODO is this necessary? + if mask_gdf is not None: + # if mask_gdf.crs != data.crs: + # mask_gdf = mask_gdf.to_crs(data.crs) # TODO is this necessary? mask_tif = rasterio.features.geometry_mask([geom for geom in mask_gdf.geometry], transform=out_transform, @@ -384,23 +392,23 @@ def process_tile_ms(img_path: str, # If the tile is mostly empty or mostly nan, don't save it invalid = (zero_mask | nan_mask).sum() - if invalid > nan_threshold * totalpix: - logger.warning( - "Skipping tile at (%s, %s) due to being over nodata threshold.", - minx, - miny, - ) - logger.warning( - "Threshold: %s, nodata ratio: %s", - nan_threshold, - invalid / totalpix, - ) + if invalid > nan_threshold * totalpix: + logger.warning( + "Skipping tile at (%s, %s) due to being over nodata threshold.", + minx, + miny, + ) + logger.warning( + "Threshold: %s, nodata ratio: %s", + nan_threshold, + invalid / totalpix, + ) return None - # rescale image to 1-255 (0 is reserved for nodata) - assert image_statistics is not None, "image_statistics must be provided for multispectral data" - min_vals = np.array([stats['min'] for stats in image_statistics]).reshape(-1, 1, 1) - max_vals = np.array([stats['max'] for stats in image_statistics]).reshape(-1, 1, 1) + # rescale image to 1-255 (0 is reserved for nodata) + assert image_statistics is not None, "image_statistics must be provided for multispectral data" + min_vals = np.array([stats['min'] for stats in image_statistics]).reshape(-1, 1, 1) + max_vals = np.array([stats['max'] for stats in image_statistics]).reshape(-1, 1, 1) # making it a bit safer for small numbers if max_vals.min() > 1: @@ -408,14 +416,14 @@ def process_tile_ms(img_path: str, else: out_img = (out_img - min_vals) * 254 / (max_vals - min_vals) + 1 - # additional clip to make sure (use float32 bounds for mypy compatibility) - out_img = np.clip( - out_img.astype(np.float32), np.float32(1.0), np.float32(255.0) - ) + # additional clip to make sure (use float32 bounds for mypy compatibility) + out_img = np.clip( + out_img.astype(np.float32), np.float32(1.0), np.float32(255.0) + ) - # Apply nan mask across all bands (3D mask) without using np.broadcast_to - band_mask = np.stack([nan_mask == 1] * out_img.shape[0], axis=0) - out_img[band_mask] = np.float32(0.0) + # Apply nan mask across all bands (3D mask) without using np.broadcast_to + band_mask = np.stack([nan_mask == 1] * out_img.shape[0], axis=0) + out_img[band_mask] = np.float32(0.0) dtype, nodata = dtype_map.get(out_img.dtype, (None, None)) if dtype is None: @@ -442,9 +450,9 @@ def process_tile_ms(img_path: str, # cv2.imwrite(str(out_path_root.with_suffix(".png").resolve()), rgb) if overlapping_crowns is not None: - return data, out_path_root, overlapping_crowns, minx, miny, buffer + return out_transform, out_path_root, overlapping_crowns, minx, miny, buffer - return data, out_path_root, None, minx, miny, buffer + return out_transform, out_path_root, None, minx, miny, buffer except RasterioIOError as e: logger.error(f"RasterioIOError while applying mask {coords}: {e}") @@ -474,7 +482,9 @@ def process_tile_train( additional_nodata: List[Any] = [], image_statistics: List[Dict[str, float]] = None, ignore_bands_indices: List[int] = [], - use_convex_mask: bool = True) -> None: + use_convex_mask: bool = True, + enhance_rgb_contrast: bool = True + ) -> None: """Process a single tile for training data. 
Args: @@ -498,7 +508,7 @@ def process_tile_train( if mode == "rgb": result = process_tile(img_path, out_dir, buffer, tile_width, tile_height, dtype_bool, minx, miny, crs, tilename, crowns, threshold, nan_threshold, mask_gdf, additional_nodata, image_statistics, - ignore_bands_indices, use_convex_mask) + ignore_bands_indices, use_convex_mask, enhance_rgb_contrast) elif mode == "ms": result = process_tile_ms(img_path, out_dir, buffer, tile_width, tile_height, dtype_bool, minx, miny, crs, tilename, crowns, threshold, nan_threshold, mask_gdf, additional_nodata, @@ -508,13 +518,13 @@ def process_tile_train( # logger.warning(f"Skipping tile at ({minx}, {miny}) due to insufficient data.") return - data, out_path_root, overlapping_crowns, minx, miny, buffer = result + out_transform, out_path_root, overlapping_crowns, minx, miny, buffer = result if overlapping_crowns is not None and not overlapping_crowns.empty: overlapping_crowns = overlapping_crowns.explode(index_parts=True) moved = overlapping_crowns.translate(-minx + buffer, -miny + buffer) - scalingx = 1 / (data.transform[0]) - scalingy = -1 / (data.transform[4]) + scalingx = 1 / (out_transform[0]) + scalingy = -1 / (out_transform[4]) moved_scaled = moved.scale(scalingx, scalingy, origin=(0, 0)) if mode == "rgb": @@ -569,25 +579,25 @@ def _calculate_tile_placements( if tile_placement == "grid": with rasterio.open(img_path) as data: - start_x = int(math.ceil(data.bounds[0]) + buffer) - stop_x = int(data.bounds[2] - tile_width - buffer) - start_y = int(math.ceil(data.bounds[1]) + buffer) - stop_y = int(data.bounds[3] - tile_height - buffer) - coordinates = [ - (minx, miny) - for minx in range(start_x, stop_x, tile_width) - for miny in range(start_y, stop_y, tile_height) - ] - if overlapping_tiles: - start_x2 = int(math.ceil(data.bounds[0]) + buffer + tile_width // 2) - stop_x2 = int(data.bounds[2] - tile_width - buffer - tile_width // 2) - start_y2 = int(math.ceil(data.bounds[1]) + buffer + tile_height // 2) - stop_y2 = int(data.bounds[3] - tile_height - buffer - tile_height // 2) - coordinates.extend([ - (minx, miny) - for minx in range(start_x2, stop_x2, tile_width) - for miny in range(start_y2, stop_y2, tile_height) - ]) + start_x = int(math.ceil(data.bounds[0]) + buffer) + stop_x = int(data.bounds[2] - tile_width - buffer) + start_y = int(math.ceil(data.bounds[1]) + buffer) + stop_y = int(data.bounds[3] - tile_height - buffer) + coordinates = [ + (minx, miny) + for minx in range(start_x, stop_x, tile_width) + for miny in range(start_y, stop_y, tile_height) + ] + if overlapping_tiles: + start_x2 = int(math.ceil(data.bounds[0]) + buffer + tile_width // 2) + stop_x2 = int(data.bounds[2] - tile_width - buffer - tile_width // 2) + start_y2 = int(math.ceil(data.bounds[1]) + buffer + tile_height // 2) + stop_y2 = int(data.bounds[3] - tile_height - buffer - tile_height // 2) + coordinates.extend([ + (minx, miny) + for minx in range(start_x2, stop_x2, tile_width) + for miny in range(start_y2, stop_y2, tile_height) + ]) elif tile_placement == "adaptive": if crowns is None: @@ -601,7 +611,7 @@ def _calculate_tile_placements( unioned_crowns = crowns.union_all() else: unioned_crowns = crowns.unary_union - logger.info("Finished Union of Crowns") + logger.info("Finished Union of Crowns") area_width = crowns.total_bounds[2] - crowns.total_bounds[0] area_height = crowns.total_bounds[3] - crowns.total_bounds[1] @@ -637,7 +647,7 @@ def _calculate_tile_placements( coordinates.append( (int(intersection.total_bounds[0] - x_intersection_offset) + col * tile_width 
+ tile_width // 2, int(crowns.total_bounds[1] - y_offset) + row * tile_height + tile_height // 2)) - logger.info("Finished Tile Placement Generation") + logger.info("Finished Tile Placement Generation") else: raise ValueError('Unsupported tile_placement method. Must be "grid" or "adaptive"') @@ -753,23 +763,23 @@ def calc_on_everything(): # Compute statistics for each band band_stats = [] - for band_idx in range(1, len(band_aggregates) + 1) if mode == "ms" else range(1, 4): - valid_data = np.array(band_aggregates[band_idx]) - if valid_data.size > 0: - min_val, max_val = np.percentile(valid_data, [1, 99]) - stats = { - "mean": float(np.mean(valid_data)), - "min": float(min_val), - "max": float(max_val), - "std_dev": float(np.std(valid_data)), - } - else: - stats = { - "mean": float("nan"), - "min": float("nan"), - "max": float("nan"), - "std_dev": float("nan"), - } + for band_idx in range(1, len(band_aggregates) + 1) if mode == "ms" else range(1, 4): + valid_data = np.array(band_aggregates[band_idx]) + if valid_data.size > 0: + min_val, max_val = np.percentile(valid_data, [1, 99]) + stats = { + "mean": float(np.mean(valid_data)), + "min": float(min_val), + "max": float(max_val), + "std_dev": float(np.std(valid_data)), + } + else: + stats = { + "mean": float("nan"), + "min": float("nan"), + "max": float("nan"), + "std_dev": float("nan"), + } band_stats.append(stats) return band_stats @@ -794,6 +804,7 @@ def tile_data( overlapping_tiles: bool = False, ignore_bands_indices: List[int] = [], use_convex_mask: bool = True, + enhance_rgb_contrast: bool = True, ) -> None: """Tiles up orthomosaic and corresponding crowns (if supplied) into training/prediction tiles. @@ -841,7 +852,7 @@ def tile_data( tile_args = [ (img_path, out_dir, buffer, tile_width, tile_height, dtype_bool, minx, miny, crs, tilename, crowns, threshold, nan_threshold, mode, class_column, mask_gdf, additional_nodata, image_statistics, ignore_bands_indices, - use_convex_mask) for minx, miny in tile_coordinates + use_convex_mask, enhance_rgb_contrast) for minx, miny in tile_coordinates if mask_path is None or (mask_path is not None and mask_gdf.intersects( box(minx, miny, minx + tile_width, miny + tile_height) #TODO maybe add to_crs here ).any()) @@ -1040,8 +1051,8 @@ def create_RGB_from_MS(tile_folder_path: Union[str, Path], # Write the PNG (we must convert shape to (H, W, 3) and then to uint8) output_png = out_path / f"{tif_file.stem}.png" - # Move axis from (bands, H, W) -> (H, W, bands) - png_ready = transformed.transpose(1, 2, 0).astype(np.uint8) + # Move axis from (bands, H, W) -> (H, W, bands) + png_ready = transformed.transpose(1, 2, 0).astype(np.uint8) cv2.imwrite(str(output_png), cv2.cvtColor(png_ready, cv2.COLOR_RGB2BGR)) elif conversion == "first-three": @@ -1080,7 +1091,7 @@ def create_RGB_from_MS(tile_folder_path: Union[str, Path], # Write out the PNG (shape must be (H, W, 3)) output_png = out_path / f"{tif_file.stem}.png" # Move axis from (bands, H, W) -> (H, W, bands) - png_ready = data.transpose(1, 2, 0).astype(np.uint8) + png_ready = data.transpose(1, 2, 0).astype(np.uint8) # We expect the order to be [band1, band2, band3], so interpret as R,G,B cv2.imwrite(str(output_png), cv2.cvtColor(png_ready, cv2.COLOR_RGB2BGR)) @@ -1330,7 +1341,7 @@ def to_traintest_folders( # noqa: C901 # random.shuffle(indices) num = list(range(0, len(file_roots))) random.shuffle(num) - ind_split = np.array_split(np.array(file_roots), folds) + ind_split = np.array_split(np.array(file_roots), folds) for i in range(0, folds): Path(out_dir / 
f"train/fold_{i + 1}").mkdir(parents=True, exist_ok=True) From ebfd0362d7989eb4da3f9da43f795fd3e50e1374 Mon Sep 17 00:00:00 2001 From: Christopher Kotthoff Date: Sun, 20 Jul 2025 19:22:18 +0100 Subject: [PATCH 15/17] experimental numpy fix --- .github/workflows/python-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml index a8fec3ca..3f5f7f4d 100644 --- a/.github/workflows/python-ci.yml +++ b/.github/workflows/python-ci.yml @@ -59,6 +59,7 @@ jobs: flake8 detectree2 --count --exit-zero --max-complexity=10 --statistics - name: pytest checks run: | + pip install --upgrade --force-reinstall "numpy>=1.20,<2.0" pip install pytest-order pytest - name: mypy checks From ab957ba4f8ff50c2dda6c4abc3a583f9fb7859fd Mon Sep 17 00:00:00 2001 From: Christopher Kotthoff Date: Sun, 20 Jul 2025 19:45:18 +0100 Subject: [PATCH 16/17] fixes a lot of accumulated mypy problem --- detectree2/models/predict.py | 5 +- detectree2/models/train.py | 15 ++++-- detectree2/preprocessing/tiling.py | 76 ++++++++++++++---------------- 3 files changed, 48 insertions(+), 48 deletions(-) diff --git a/detectree2/models/predict.py b/detectree2/models/predict.py index 2e389995..db9dd154 100644 --- a/detectree2/models/predict.py +++ b/detectree2/models/predict.py @@ -65,10 +65,11 @@ def predict_on_data( file_ext = os.path.splitext(file_name)[1].lower() if file_ext == ".png": # RGB image, read with cv2 - img = cv2.imread(file_name) - if img is None: + cv_img = cv2.imread(file_name) + if cv_img is None: print(f"Failed to read image {file_name} with cv2.") continue + img = np.array(cv_img) # Explicitly convert to numpy array elif file_ext == ".tif": # Multispectral image, read with rasterio with rasterio.open(file_name) as src: diff --git a/detectree2/models/train.py b/detectree2/models/train.py index 4ab8add9..f2edda17 100644 --- a/detectree2/models/train.py +++ b/detectree2/models/train.py @@ -748,10 +748,10 @@ def get_tree_dicts(directory: str, class_mapping: Optional[Dict[str, int]] = Non # Make sure we have the correct height and width # If image path ends in .png use cv2 to get height and width else if image path ends in .tif use rasterio if filename.endswith(".png"): - img_arr = cv2.imread(filename) - if img_arr is None: - raise FileNotFoundError(f"Failed to read image at path: {filename}") - height, width = img_arr.shape[:2] + img = cv2.imread(filename) + if img is None: + continue + height, width = img.shape[:2] elif filename.endswith(".tif"): with rasterio.open(filename) as src: height, width = src.shape @@ -781,7 +781,12 @@ def get_tree_dicts(directory: str, class_mapping: Optional[Dict[str, int]] = Non category_id = 0 # Default to "tree" if no class mapping is provided obj = { - "bbox": [min(px), min(py), max(px), max(py)], + "bbox": [ + np.min(np.array(px)), + np.min(np.array(py)), + np.max(np.array(px)), + np.max(np.array(py)), + ], "bbox_mode": BoxMode.XYXY_ABS, "segmentation": [poly], "category_id": category_id, diff --git a/detectree2/preprocessing/tiling.py b/detectree2/preprocessing/tiling.py index 0f059d80..f1316b3c 100644 --- a/detectree2/preprocessing/tiling.py +++ b/detectree2/preprocessing/tiling.py @@ -240,7 +240,7 @@ def process_tile(img_path: str, out_img = (out_img - min_vals) / (max_vals - min_vals) * 254 + 1 # Apply nan mask - out_img[np.broadcast_to((nan_mask == 1)[None, :, :], out_img.shape)] = 0 + out_img[np.broadcast_to((nan_mask == 1)[None, :, :], out_img.shape)] = 0 # type: ignore[attr-defined] if enhance_rgb_contrast: out_img = 
np.clip(out_img, 0, 255)
 
@@ -265,8 +265,7 @@ def process_tile(img_path: str,
             dest.write(out_img)
 
         r, g, b = out_img[0], out_img[1], out_img[2]
-        # Reorder channels to B, G, R for OpenCV (use list for mypy-friendly typing)
-        rgb = np.stack([b, g, r], axis=2)
+        rgb = np.dstack((b, g, r))  # type: ignore[attr-defined]  # Reorder for cv2 (BGR)
 
         if not enhance_rgb_contrast:
             # If not enhancing contrast, ensure the dtype is uint8
@@ -274,7 +273,7 @@ def process_tile(img_path: str,
                 rgb = rgb.astype(np.uint8)
             else:
                 rgb = rgb.astype(np.float32)
-                np.clip(rgb, 0, 255, out=rgb)
+                np.clip(rgb, 0, 255, out=rgb)  # type: ignore[call-arg]
 
         cv2.imwrite(str(out_path_root.with_suffix(".png").resolve()), rgb.astype(np.uint8))
 
@@ -421,9 +420,8 @@ def process_tile_ms(img_path: str,
                 out_img.astype(np.float32), np.float32(1.0), np.float32(255.0)
             )
 
-            # Apply nan mask across all bands (3D mask) without using np.broadcast_to
-            band_mask = np.stack([nan_mask == 1] * out_img.shape[0], axis=0)
-            out_img[band_mask] = np.float32(0.0)
+            # Apply nan mask
+            out_img[np.broadcast_to((nan_mask == 1)[None, :, :], out_img.shape)] = 0.0  # type: ignore[attr-defined]
 
             dtype, nodata = dtype_map.get(out_img.dtype, (None, None))
             if dtype is None:
@@ -576,28 +574,22 @@ def _calculate_tile_placements(
     overlapping_tiles: bool = False,
 ) -> List[Tuple[int, int]]:
     """Internal method for calculating the placement of tiles"""
-
+    coordinates: List[Tuple[int, int]] = []
     if tile_placement == "grid":
         with rasterio.open(img_path) as data:
-            start_x = int(math.ceil(data.bounds[0]) + buffer)
-            stop_x = int(data.bounds[2] - tile_width - buffer)
-            start_y = int(math.ceil(data.bounds[1]) + buffer)
-            stop_y = int(data.bounds[3] - tile_height - buffer)
-            coordinates = [
-                (minx, miny)
-                for minx in range(start_x, stop_x, tile_width)
-                for miny in range(start_y, stop_y, tile_height)
+            grid_coords = [
+                (int(minx), int(miny)) for minx in np.arange(
+                    int(math.ceil(data.bounds[0])) + buffer, int(data.bounds[2] - tile_width - buffer), tile_width)
+                for miny in np.arange(
+                    int(math.ceil(data.bounds[1])) + buffer, int(data.bounds[3] - tile_height - buffer), tile_height)
             ]
             if overlapping_tiles:
-                start_x2 = int(math.ceil(data.bounds[0]) + buffer + tile_width // 2)
-                stop_x2 = int(data.bounds[2] - tile_width - buffer - tile_width // 2)
-                start_y2 = int(math.ceil(data.bounds[1]) + buffer + tile_height // 2)
-                stop_y2 = int(data.bounds[3] - tile_height - buffer - tile_height // 2)
-                coordinates.extend([
-                    (minx, miny)
-                    for minx in range(start_x2, stop_x2, tile_width)
-                    for miny in range(start_y2, stop_y2, tile_height)
-                ])
+                grid_coords.extend([(int(minx), int(miny)) for minx in np.arange(
+                    int(math.ceil(data.bounds[0])) + buffer + tile_width // 2, int(data.bounds[2] - tile_width - buffer -
+                    tile_width // 2), tile_width) for miny in np.arange(
+                    int(math.ceil(data.bounds[1])) + buffer + tile_height // 2, int(data.bounds[3] - tile_height - buffer -
+                    tile_height // 2), tile_height)])
+            coordinates = grid_coords
 
     elif tile_placement == "adaptive":
         if crowns is None:
@@ -623,7 +615,6 @@ def _calculate_tile_placements(
         y_offset = (combined_tiles_height - area_height) / 2
 
         logger.info("Starting Tile Placement Generation")
-        coordinates = []
         for row in range(required_tiles_y):
             bar = gpd.GeoSeries([
                 box(crowns.total_bounds[0] - x_offset, crowns.total_bounds[1] - y_offset + row * tile_height,
@@ -694,17 +685,17 @@ def calc_on_everything():
                     min_val, max_val = np.percentile(valid_data, [1, 99])
                     stats = {
-                        "mean": np.mean(valid_data),
-                        "min": min_val,
-                        "max": max_val,
-                        "std_dev": np.std(valid_data),
+                        "mean": float(np.mean(valid_data)),
+                        "min": float(min_val),
+                        "max": float(max_val),
+                        "std_dev": float(np.std(valid_data)),
                     }
                 else:
                     stats = {
-                        "mean": None,
-                        "min": None,
-                        "max": None,
-                        "std_dev": None,
+                        "mean": np.nan,
+                        "min": np.nan,
+                        "max": np.nan,
+                        "std_dev": np.nan,
                     }
                 band_stats.append(stats)
             return band_stats
@@ -772,13 +763,13 @@ def calc_on_everything():
                     "min": float(min_val),
                     "max": float(max_val),
                     "std_dev": float(np.std(valid_data)),
                 }
             else:
                 stats = {
-                    "mean": float("nan"),
-                    "min": float("nan"),
-                    "max": float("nan"),
-                    "std_dev": float("nan"),
+                    "mean": np.nan,
+                    "min": np.nan,
+                    "max": np.nan,
+                    "std_dev": np.nan,
                 }
             band_stats.append(stats)
             return band_stats
@@ -1051,8 +1046,7 @@ def create_RGB_from_MS(tile_folder_path: Union[str, Path],
 
             # Write the PNG (we must convert shape to (H, W, 3) and then to uint8)
             output_png = out_path / f"{tif_file.stem}.png"
-            # Move axis from (bands, H, W) -> (H, W, bands)
-            png_ready = transformed.transpose(1, 2, 0).astype(np.uint8)
+            png_ready = np.moveaxis(transformed, 0, -1).astype(np.uint8)  # type: ignore[attr-defined]  # (H, W, 3)
             cv2.imwrite(str(output_png), cv2.cvtColor(png_ready, cv2.COLOR_RGB2BGR))
 
         elif conversion == "first-three":
@@ -1091,7 +1085,7 @@ def create_RGB_from_MS(tile_folder_path: Union[str, Path],
             # Write out the PNG (shape must be (H, W, 3))
             output_png = out_path / f"{tif_file.stem}.png"
             # Move axis from (bands, H, W) -> (H, W, bands)
-            png_ready = data.transpose(1, 2, 0).astype(np.uint8)
+            png_ready = np.moveaxis(data, 0, -1).astype(np.uint8)  # type: ignore[attr-defined]
             # We expect the order to be [band1, band2, band3], so interpret as R,G,B
             cv2.imwrite(str(output_png), cv2.cvtColor(png_ready, cv2.COLOR_RGB2BGR))

From 6d20c779afded69d7dcd1c7cfcf774f9dbe9155b Mon Sep 17 00:00:00 2001
From: ChristopherKotthoff
Date: Mon, 6 Oct 2025 15:02:32 +0200
Subject: [PATCH 17/17] cleanup fix

---
 detectree2/models/predict.py       | 1 +
 detectree2/preprocessing/tiling.py | 2 --
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/detectree2/models/predict.py b/detectree2/models/predict.py
index db9dd154..b87e4422 100644
--- a/detectree2/models/predict.py
+++ b/detectree2/models/predict.py
@@ -7,6 +7,7 @@
 from pathlib import Path
 
 import cv2
+import numpy as np
 import rasterio
 from detectron2.engine import DefaultPredictor
 from detectron2.evaluation.coco_evaluation import instances_to_coco_json
diff --git a/detectree2/preprocessing/tiling.py b/detectree2/preprocessing/tiling.py
index ee282ed3..46df1982 100644
--- a/detectree2/preprocessing/tiling.py
+++ b/detectree2/preprocessing/tiling.py
@@ -136,8 +136,6 @@ def process_tile(img_path: str,
                  ignore_bands_indices: List[int] = [],
                  use_convex_mask: bool = True,
                  enhance_rgb_contrast: bool = True):
-                 use_convex_mask: bool = True,
-                 enhance_rgb_contrast: bool = True):
     """Process a single tile for making predictions.

    Args: